from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
import os
import logging
import mysql.connector
import redis
from dotenv import load_dotenv

# Load variables from a .env file into the process environment before anything
# reads them; get_db_connection() below picks up its DB_* settings via os.getenv.
load_dotenv()

# Root logger configuration shared by every logging.* call in this module.
logging.basicConfig(
    level=logging.DEBUG,  # verbosity switch: use logging.ERROR to quiet output, logging.DEBUG for full detail
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()

# Module-level Redis client pointing at localhost:6379, db 0, with
# decode_responses enabled. On redis.ConnectionError the error is logged and
# redis_client is set to None instead.
# NOTE(review): the ping below is commented out, and constructing redis.Redis(...)
# presumably does not open a connection on its own, so the ConnectionError
# handler (and the None fallback) may never trigger — confirm, and either
# restore the ping or drop the try/except.
# NOTE(review): redis_client is not referenced anywhere else in the visible
# part of this file.
try:
    redis_client = redis.Redis(
        host='localhost',
        port=6379,
        db=0,
        decode_responses=True
    )
    #redis_client.ping()
except redis.ConnectionError as e:
    logging.error(f"Redis connection error: {str(e)}")
    redis_client = None

# Both models are created at import time, so importing this module pays the
# full model-loading cost before the app can serve any request.

# Face analysis pipeline built from the 'buffalo_l' model pack, limited to the
# detection, 106-point landmark and recognition modules, on the CPU provider.
face_analyzer = FaceAnalysis(
    name='buffalo_l',
    allowed_modules=['detection', 'landmark_2d_106', 'recognition'],
    providers=['CPUExecutionProvider']
)
# NOTE(review): ctx_id=0 is passed although only the CPU execution provider is
# configured — confirm this is the intended device selection.
face_analyzer.prepare(ctx_id=0, det_size=(640, 640))

# Face-swap model loaded from the current user's ~/.insightface/models directory.
# NOTE(review): presumably inswapper_128.onnx must already exist at this path
# (no download is visible here) — confirm deployment provisions it.
model_path = os.path.expanduser('~/.insightface/models/inswapper_128.onnx')
face_swapper = insightface.model_zoo.get_model(model_path, providers=['CPUExecutionProvider'])


class FaceSwapRequest(BaseModel):
    """Request body of ``POST /faceswap`` describing one face-swap job.

    All paths are read or written by this process itself, so they refer to the
    filesystem of the machine running the service.
    """

    # Matched against face_swap_jobs.id when the job status is updated.
    job_id: int
    # Image file the replacement face is taken from (read with cv2.imread).
    source_path: str
    # Image file whose face gets replaced (read with cv2.imread).
    target_path: str
    # Destination file for the resulting image (written with cv2.imwrite).
    output_path: str


def get_db_connection():
    """Open and return a new MySQL connection configured from the environment.

    Settings come from ``DB_HOST``, ``DB_USERNAME``, ``DB_PASSWORD`` and
    ``DB_DATABASE``, falling back to ``localhost`` / ``root`` / empty password /
    ``moro`` when a variable is unset. The connection timeout is fixed at 5.

    Returns:
        Whatever ``mysql.connector.connect`` returns; the caller is responsible
        for closing it.
    """
    env = os.environ
    settings = {
        'host': env.get('DB_HOST', 'localhost'),
        'user': env.get('DB_USERNAME', 'root'),
        'password': env.get('DB_PASSWORD', ''),
        'database': env.get('DB_DATABASE', 'moro'),
        'connection_timeout': 5,
    }
    return mysql.connector.connect(**settings)


def update_job_status(job_id, status, error=None) -> None:
    """Write a job's status and optional error text to ``face_swap_jobs``.

    This is best-effort: any failure while connecting, executing, committing or
    cleaning up is logged and swallowed, so the function never raises.

    Args:
        job_id: Value matched against ``face_swap_jobs.id``.
        status: New value for the ``status`` column.
        error: New value for the ``error`` column; ``None`` is handed to the
            driver unchanged (presumably stored as NULL).
    """
    query = "UPDATE face_swap_jobs SET status = %s, error = %s WHERE id = %s"

    try:
        conn = get_db_connection()
        try:
            # Cursor creation sits inside the connection's try/finally so a
            # failure here can no longer leak the open connection.
            cursor = conn.cursor()
            try:
                cursor.execute(query, (status, error, job_id))
                conn.commit()
            finally:
                cursor.close()
        finally:
            # Runs even if cursor.close() itself raised.
            conn.close()

    except Exception as e:
        logging.error(f"Database error: {str(e)}")


def _run_faceswap_blocking(job_id: int, source_path: str, target_path: str, output_path: str) -> None:
    """Run the whole face-swap pipeline synchronously for one job.

    Marks the job "processing", swaps the first face detected in the source
    image onto the first face detected in the target image, denoises the
    result, writes it to ``output_path`` and marks the job "finished". Any
    failure is logged and recorded on the job as "failed" with the error text;
    nothing is raised to the caller.
    """
    try:
        update_job_status(job_id, "processing", None)

        source_img = cv2.imread(source_path)
        target_img = cv2.imread(target_path)

        # cv2.imread signals an unreadable file by returning None, not by raising.
        if source_img is None or target_img is None:
            raise Exception("Unable to read source or target image")

        source_faces = face_analyzer.get(source_img)
        target_faces = face_analyzer.get(target_img)

        if len(source_faces) == 0 or len(target_faces) == 0:
            raise Exception("No face found in the photo")

        # Only the first detected face of each image takes part in the swap.
        source_face = source_faces[0]
        target_face = target_faces[0]

        result = face_swapper.get(target_img, target_face, source_face, paste_back=True)

        if result is None:
            raise Exception("Face Swap failed")

        result = cv2.fastNlMeansDenoisingColored(result, None, 3, 3, 7, 21)

        # cv2.imwrite reports failure through its boolean return value; without
        # this check a job whose output was never written would be "finished".
        if not cv2.imwrite(output_path, result):
            raise Exception(f"Unable to write output image: {output_path}")

        update_job_status(job_id, "finished", None)

    except Exception as e:
        logging.error(f"Error in face_swap: {str(e)}")
        update_job_status(job_id, "failed", str(e))


async def process_faceswap(job_id: int, source_path: str, target_path: str, output_path: str):
    """Process one face-swap job without blocking the event loop.

    The pipeline (image decoding, model inference, denoising, database
    updates) is entirely blocking work. Because this coroutine is scheduled as
    an async background task, running that work inline would stall the event
    loop, and with it every other request, until the job ends. It is
    therefore delegated to the default thread pool executor.

    Args:
        job_id: Identifier of the job row whose status is updated.
        source_path: Image providing the replacement face.
        target_path: Image whose face is replaced.
        output_path: Where the resulting image is written.

    Errors are never raised; they are logged and stored on the job as "failed".
    """
    import asyncio  # local import: keeps this block self-contained

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None,
        _run_faceswap_blocking,
        job_id,
        source_path,
        target_path,
        output_path,
    )


@app.post("/faceswap")
async def start_faceswap(background_tasks: BackgroundTasks, request: FaceSwapRequest):
    """Queue a face-swap job as a background task and acknowledge immediately.

    Returns ``{"status": "accepted"}`` once the task is registered. If
    registering it raises, the exception text is returned in a
    ``{"status": "error", "message": ...}`` body instead of propagating.
    The job's actual outcome is reported through the job status, not here.
    """
    try:
        background_tasks.add_task(
            process_faceswap,
            request.job_id,
            request.source_path,
            request.target_path,
            request.output_path,
        )
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    else:
        return {"status": "accepted"}


@app.get("/health")
async def health_check():
    """Report that the service process is up; performs no dependency checks."""
    payload = dict(status="healthy")
    return payload


# Script entry point: serve the app with uvicorn when this file is run directly
# (not when it is imported by an external ASGI server).
if __name__ == "__main__":
    import uvicorn

    # host 0.0.0.0 accepts connections on every network interface; port 5000.
    uvicorn.run(app, host="0.0.0.0", port=5000)
