import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
import os
import argparse
import mysql.connector
import logging

# Root logger setup for this command-line tool: emit INFO and above, each
# record prefixed with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def update_job_status(job_id, status, error=None):
    """Set the status (and error text) of one row in ``face_swap_jobs``.

    Opens a fresh MySQL connection, runs a parameterized UPDATE, commits, and
    always closes the cursor and connection, even when the query fails.

    Connection settings come from the ``MORO_DB_HOST``, ``MORO_DB_USER``,
    ``MORO_DB_PASSWORD`` and ``MORO_DB_NAME`` environment variables. If a
    variable is unset, the previously hard-coded value is used.

    Args:
        job_id: value matched against ``face_swap_jobs.id``.
        status: new value for the ``status`` column.
        error: new value for the ``error`` column (``None`` writes NULL).

    Raises:
        Whatever ``mysql.connector`` raises on connection/query failure; the
        exception is propagated after the resources are released.
    """
    # NOTE(review): the password fallback below is a secret checked into
    # source; set MORO_DB_PASSWORD in the environment and drop the default.
    conn = mysql.connector.connect(
        host=os.environ.get("MORO_DB_HOST", "localhost"),
        user=os.environ.get("MORO_DB_USER", "root"),
        password=os.environ.get("MORO_DB_PASSWORD", "abc1234$"),
        database=os.environ.get("MORO_DB_NAME", "moro"),
    )
    try:
        cursor = conn.cursor()
        try:
            # Parameterized query: values are never interpolated into the SQL.
            cursor.execute(
                "UPDATE face_swap_jobs SET status = %s, error = %s WHERE id = %s",
                (status, error, job_id),
            )
            conn.commit()
        finally:
            cursor.close()
    finally:
        # Previously skipped whenever execute()/commit() raised.
        conn.close()

def face_swap(job_id, source_path, target_path, output_path):
    """Swap the first face found in the source image onto the target image.

    Steps:
    1. Load the insightface ``buffalo_l`` analysis pack (detection, 106-point
       landmarks, recognition) and the ``inswapper_128`` model, both on CPU.
    2. Take the first detected face of each image and paste the source face
       onto the target.
    3. Denoise the result and write it to ``output_path``.

    The job row is set to "finished" on success, or to "failed" exactly once
    (with the error message) on any failure.

    Args:
        job_id: id of the ``face_swap_jobs`` row to update.
        source_path: image that supplies the face.
        target_path: image whose face is replaced.
        output_path: destination file; its extension selects the format
            written by ``cv2.imwrite``.

    Returns:
        True on success, False on failure (the error is logged and recorded).
    """
    try:
        app = FaceAnalysis(
            name='buffalo_l',
            allowed_modules=['detection', 'landmark_2d_106', 'recognition'],
            providers=['CPUExecutionProvider']
        )
        app.prepare(ctx_id=0, det_size=(640, 640))

        model_path = os.path.expanduser('~/.insightface/models/inswapper_128.onnx')
        swapper = insightface.model_zoo.get_model(model_path, providers=['CPUExecutionProvider'])

        source_img = cv2.imread(source_path)
        target_img = cv2.imread(target_path)
        # cv2.imread reports a missing/unreadable file by returning None.
        if source_img is None or target_img is None:
            raise RuntimeError("Unable to read source or target image")

        source_faces = app.get(source_img)
        target_faces = app.get(target_img)
        if len(source_faces) == 0 or len(target_faces) == 0:
            raise RuntimeError("No face found in the photo")

        # Only the first detected face of each image is used.
        source_face = source_faces[0]
        target_face = target_faces[0]

        result = swapper.get(target_img, target_face, source_face, paste_back=True)
        if result is None:
            raise RuntimeError("Face Swap failed")

        # Positional args: dst=None, h=3, hColor=3, templateWindowSize=7,
        # searchWindowSize=21.
        result = cv2.fastNlMeansDenoisingColored(result, None, 3, 3, 7, 21)
        # imwrite reports write failures (e.g. unwritable path) by returning
        # False rather than raising, so check it before declaring success.
        if not cv2.imwrite(output_path, result):
            raise RuntimeError(f"Unable to write result image: {output_path}")

        update_job_status(job_id, "finished", None)
        return True

    except Exception as e:
        # The single place where a failure is recorded, so the job row is
        # written once and always carries the original error message.
        logging.exception("Error in face_swap: %s", e)
        try:
            update_job_status(job_id, "failed", str(e))
        except Exception:
            # A DB error here must not mask the original failure or break
            # the True/False contract callers rely on.
            logging.exception("Unable to record failure for job %s", job_id)
        return False

def main():
    """Command-line entry point: validate inputs and run one face-swap job.

    Job status updates:
    - "failed" if an input file is missing or the output directory cannot be
      created.
    - "processing" just before the swap starts.
    - After that, ``face_swap`` records "finished" or "failed".

    Returns:
        Process exit status: 0 on success, 1 on any failure.
    """
    parser = argparse.ArgumentParser(description='Face Swap Tool')
    parser.add_argument('--source', required=True, help='The path to the source image')
    parser.add_argument('--target', required=True, help='The path to the destination image')
    parser.add_argument('--output', required=True, help='The path to the result image')
    parser.add_argument('--job-id', required=True, help='Job ID')
    args = parser.parse_args()

    if not os.path.exists(args.source):
        message = f"No source image file found: {args.source}"
        update_job_status(args.job_id, "failed", message)
        logging.error(message)
        return 1

    if not os.path.exists(args.target):
        message = f"No destination image file found: {args.target}"
        update_job_status(args.job_id, "failed", message)
        logging.error(message)
        return 1

    output_dir = os.path.dirname(args.output)
    # dirname() is "" for a bare filename such as "out.jpg", and
    # os.makedirs("") raises, so only create a directory when one is named.
    if output_dir:
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as e:
            message = f"Unable to create output directory {output_dir}: {e}"
            update_job_status(args.job_id, "failed", message)
            logging.error(message)
            return 1

    update_job_status(args.job_id, "processing", None)
    if face_swap(args.job_id, args.source, args.target, args.output):
        logging.info(f"Successfully processed and saved the result: {args.output}")
        return 0

    logging.error("Image processing failed.")
    return 1

if __name__ == "__main__":
    # Propagate main()'s result as the process exit status so job runners can
    # detect failures (SystemExit(None) and SystemExit(0) both exit with 0).
    raise SystemExit(main())