#!/usr/bin/env python3
"""
Advanced Anime Consistency Checker with Full Auto-Detection
Uses multiple advanced techniques for comprehensive visual comparison
"""

import os
import sys
from pathlib import Path
from PIL import Image
import numpy as np
import imagehash
import cv2
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2gray
from skimage import io

# Root folders for the comparison: AI-generated renders vs. canonical
# reference art (WSL-mounted Windows paths; filenames are joined onto
# these in analyze_character_auto).
GENERATED_DIR = "/mnt/c/Users/fbmor/broken-spire-comparison/generated"
REFERENCES_DIR = "/mnt/c/Users/fbmor/broken-spire-comparison/references"

def load_image_for_cv(path, size=(512, 512)):
    """Load an image with OpenCV and resize it to a uniform size.

    Args:
        path: Filesystem path of the image file.
        size: Target (width, height); defaults to the 512x512 working
            resolution used by every other metric in this module.

    Returns:
        The resized BGR image array, or None when the file is missing or
        unreadable (cv2.imread signals failure by returning None).
    """
    img = cv2.imread(path)
    if img is None:
        return None
    # A uniform size keeps downstream pixel-based metrics comparable.
    return cv2.resize(img, size)

def get_perceptual_hash(path):
    """Return the DCT-based perceptual hash of the image, or None on failure."""
    try:
        # Context manager closes the file handle promptly; the bare
        # "except:" this replaces also hid genuine bugs (e.g. NameError).
        with Image.open(path) as img:
            return imagehash.phash(img)
    except OSError:
        # Missing file or unreadable image (PIL's UnidentifiedImageError
        # is an OSError subclass).
        return None

def get_ahash(path):
    """Return the average (mean-luminance) hash of the image, or None on failure."""
    try:
        # Context manager closes the file handle promptly; the bare
        # "except:" this replaces also hid genuine bugs.
        with Image.open(path) as img:
            return imagehash.average_hash(img)
    except OSError:
        # Missing file or unreadable image format.
        return None

def get_dhash(path):
    """Return the difference (gradient) hash of the image, or None on failure."""
    try:
        # Context manager closes the file handle promptly; the bare
        # "except:" this replaces also hid genuine bugs.
        with Image.open(path) as img:
            return imagehash.dhash(img)
    except OSError:
        # Missing file or unreadable image format.
        return None

def get_whash(path):
    """Return the wavelet hash of the image, or None on failure."""
    try:
        # Context manager closes the file handle promptly; the bare
        # "except:" this replaces also hid genuine bugs.
        with Image.open(path) as img:
            return imagehash.whash(img)
    except OSError:
        # Missing file or unreadable image format.
        return None

def compare_hashes(h1, h2):
    """Compare two imagehash values; return a similarity score in [0, 1].

    The Hamming distance (``h1 - h2``) is normalized by the total number
    of bits in the hash. Using ``h1.hash.size`` instead of the original
    ``len(h1.hash) ** 2`` gives the same result for the default square
    8x8 hashes but stays correct for non-square hash shapes too.

    Args:
        h1, h2: imagehash.ImageHash values (or None on a failed load).

    Returns:
        1.0 for identical hashes, 0.0 for maximally different or when
        either hash is None.
    """
    if h1 is None or h2 is None:
        return 0.0
    similarity = 1 - (h1 - h2) / h1.hash.size
    return max(0, min(1, similarity))

def get_color_histogram(img_path, bins=32):
    """Build a normalized per-channel (B, G, R) color histogram.

    Args:
        img_path: Path to the image file.
        bins: Number of histogram bins per channel.

    Returns:
        A flat numpy array of length 3*bins (B, G, R concatenated),
        or None when the image cannot be read.
    """
    image = cv2.imread(img_path)
    if image is None:
        return None
    image = cv2.resize(image, (512, 512))

    # One histogram per BGR channel (0, 1, 2), each L2-normalized.
    channel_hists = []
    for channel in range(3):
        hist = cv2.calcHist([image], [channel], None, [bins], [0, 256])
        channel_hists.append(cv2.normalize(hist, hist).flatten())

    return np.concatenate(channel_hists)

def compare_histograms(h1, h2):
    """Compare two histograms via Pearson correlation, mapped to [0, 1].

    Args:
        h1, h2: 1-D numpy arrays of equal length, or None.

    Returns:
        (corr + 1) / 2 clamped to [0, 1]; 0.0 when either input is None
        or when the correlation is undefined.
    """
    if h1 is None or h2 is None:
        return 0.0
    corr = np.corrcoef(h1, h2)[0, 1]
    # A zero-variance (constant) histogram makes corrcoef return NaN,
    # which previously slipped through max/min clamping as a perfect
    # 1.0 score. Treat "undefined correlation" as no similarity signal.
    if np.isnan(corr):
        return 0.0
    return max(0, min(1, (corr + 1) / 2))

def get_edge_density(img_path):
    """Fraction of pixels lying on a Canny edge (a proxy for line detail).

    Returns 0 when the image cannot be read.
    """
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return 0
    gray = cv2.resize(gray, (512, 512))
    # Canny output is 0/255, so nonzero count == edge-pixel count.
    edge_map = cv2.Canny(gray, 100, 200)
    return np.count_nonzero(edge_map) / (512 * 512)

def get_ssim_score(img1_path, img2_path):
    """Structural similarity (SSIM) between two images.

    Both images are resized to the module's 512x512 working resolution
    and converted to float grayscale in [0, 1] before comparison.

    Returns 0 when either image cannot be read.
    """
    first = cv2.imread(img1_path)
    second = cv2.imread(img2_path)
    if first is None or second is None:
        return 0

    first = cv2.resize(first, (512, 512))
    second = cv2.resize(second, (512, 512))

    # BGR -> RGB -> float grayscale, matching skimage's expectations.
    gray_a = rgb2gray(cv2.cvtColor(first, cv2.COLOR_BGR2RGB))
    gray_b = rgb2gray(cv2.cvtColor(second, cv2.COLOR_BGR2RGB))

    return ssim(gray_a, gray_b, data_range=1.0)

def get_hair_color_score(ref_path, gen_path):
    """Estimate hair-color similarity from the mean color of the upper image region.

    Both images are resized to 512x512 first so the top 200 rows cover the
    same fraction of each image (the original cropped before resizing, so
    the "upper region" depended on input resolution). The mean BGR colors
    of those regions are then compared with true cosine similarity.

    Returns:
        A score in [0, 1]; 1.0 for identical average color direction,
        0 when either image cannot be read.
    """
    ref_img = cv2.imread(ref_path)
    gen_img = cv2.imread(gen_path)

    if ref_img is None or gen_img is None:
        return 0

    # Normalize resolution before cropping, consistent with every other metric.
    ref_img = cv2.resize(ref_img, (512, 512))
    gen_img = cv2.resize(gen_img, (512, 512))

    # Upper region of a portrait is usually where the hair is.
    ref_avg = ref_img[:200, :].mean(axis=(0, 1))
    gen_avg = gen_img[:200, :].mean(axis=(0, 1))

    # True cosine similarity (L2 normalization). The previous sum-based
    # normalization never reached 1.0 even for identical colors.
    denom = np.linalg.norm(ref_avg) * np.linalg.norm(gen_avg)
    if denom == 0:
        # An all-black region carries no color direction to compare.
        return 0
    similarity = np.dot(ref_avg, gen_avg) / denom
    return max(0, min(1, similarity))

def get_skin_tone_score(ref_path, gen_path):
    """Estimate skin-tone similarity via an HSV skin-color mask.

    Returns:
        A score in [0, 1]; 0.5 (neutral) when too few skin pixels are
        detected on either side, 0 when either image cannot be read.
    """
    reference = cv2.imread(ref_path)
    generated = cv2.imread(gen_path)

    if reference is None or generated is None:
        return 0

    # HSV makes the skin-hue band simple to threshold.
    lower = np.array([0, 20, 70], dtype=np.uint8)
    upper = np.array([20, 255, 255], dtype=np.uint8)

    ref_mask = cv2.inRange(cv2.cvtColor(reference, cv2.COLOR_BGR2HSV), lower, upper)
    gen_mask = cv2.inRange(cv2.cvtColor(generated, cv2.COLOR_BGR2HSV), lower, upper)

    # BGR pixels inside each skin mask (equivalent to the original
    # bitwise_and followed by mask indexing).
    ref_pixels = reference[ref_mask > 0]
    gen_pixels = generated[gen_mask > 0]

    # Require a minimum of skin evidence on both sides before comparing.
    if len(ref_pixels) > 10 and len(gen_pixels) > 10:
        mean_diff = np.abs(ref_pixels.mean(axis=0) - gen_pixels.mean(axis=0)).mean()
        return max(0, 1 - (mean_diff / 255))

    return 0.5  # Neutral if no skin detected

def get_detail_level(img_path):
    """Measure image detail/sharpness as the variance of the Laplacian.

    Returns 0 when the image cannot be read.
    """
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return 0

    gray = cv2.resize(gray, (512, 512))

    # High Laplacian variance => strong local intensity changes => sharper,
    # more detailed image.
    return cv2.Laplacian(gray, cv2.CV_64F).var()

def analyze_character_auto(char_name, ref_file, gen_file):
    """Run every comparison metric for one character and produce a verdict.

    Args:
        char_name: Display name (not used internally; part of the caller's
            interface).
        ref_file: Reference image filename inside REFERENCES_DIR.
        gen_file: Generated image filename inside GENERATED_DIR.

    Returns:
        Dict of per-metric scores plus "overall_score" and "verdict",
        or {"error": ...} when either image file is missing.
    """
    ref_path = os.path.join(REFERENCES_DIR, ref_file)
    gen_path = os.path.join(GENERATED_DIR, gen_file)

    if not os.path.exists(ref_path):
        return {"error": f"Reference not found: {ref_path}"}
    if not os.path.exists(gen_path):
        return {"error": f"Generated not found: {gen_path}"}

    results = {}

    # 1-4. Four flavors of perceptual hashing, each scored by normalized
    # Hamming distance between reference and generated image.
    hash_metrics = (
        ("perceptual_hash", get_perceptual_hash),
        ("average_hash", get_ahash),
        ("difference_hash", get_dhash),
        ("wavelet_hash", get_whash),
    )
    for metric_name, hash_fn in hash_metrics:
        results[metric_name] = compare_hashes(hash_fn(ref_path), hash_fn(gen_path))

    # 5. Color distribution similarity.
    results["color_histogram"] = compare_histograms(
        get_color_histogram(ref_path), get_color_histogram(gen_path)
    )

    # 6-8. Structure, hair-region color, skin tone.
    results["structural_similarity"] = get_ssim_score(ref_path, gen_path)
    results["hair_color"] = get_hair_color_score(ref_path, gen_path)
    results["skin_tone"] = get_skin_tone_score(ref_path, gen_path)

    # 9. Closer edge densities mean a similar amount of line detail.
    edge_gap = abs(get_edge_density(ref_path) - get_edge_density(gen_path))
    results["detail_level"] = max(0, 1 - edge_gap)

    # Weighted blend of all metrics (weights sum to 1.0; hashes dominate).
    weights = {
        "perceptual_hash": 0.20,
        "average_hash": 0.15,
        "difference_hash": 0.15,
        "wavelet_hash": 0.10,
        "color_histogram": 0.15,
        "structural_similarity": 0.10,
        "hair_color": 0.05,
        "skin_tone": 0.05,
        "detail_level": 0.05,
    }
    overall = sum(results[name] * weight for name, weight in weights.items())
    results["overall_score"] = overall

    # Verdict thresholds: >= 0.70 passes, >= 0.50 is partial, else fails.
    if overall >= 0.70:
        results["verdict"] = "✅ PASS"
    elif overall >= 0.50:
        results["verdict"] = "⚠️ PARTIAL"
    else:
        results["verdict"] = "❌ FAIL"

    return results

def run_full_analysis():
    """Run comprehensive analysis on all characters.

    Iterates a hard-coded list of (display name, reference filename,
    generated filename) triples, prints a per-character metric breakdown
    with verdicts, a pass/partial/fail summary, a "needs regeneration"
    list with score bars, and writes a short text report to disk.

    Returns:
        List of (char_name, results_dict) tuples for every character
        that analyzed without error.
    """
    
    print("=" * 80)
    print("BROKEN SPIRE - FULLY AUTOMATED CONSISTENCY ANALYSIS")
    print("=" * 80)
    
    # (display name, reference filename, generated filename) triples;
    # filenames are resolved against REFERENCES_DIR / GENERATED_DIR.
    comparisons = [
        ("Ash", "Ash.png", "02_ash_birth_00002_.png"),
        ("Far-Future Ash", "Far-Future Ash.png", "03_far_future_ash_evil_00002_.png"),
        ("Everly", "Everly.png", "05_everly_soldier_00002_.png"),
        ("Éva Moreau", "Éva Moreau.png", "06_eva_doctor_00002_.png"),
        ("Nova", "Nova Human.png", "07_nova_warrior_00002_.png"),
        ("Violet", "Violet Devil.png", "08_violet_devil_00002_.png"),
        ("Lin Weishan", "Lin Weishan.png", "09_lin_weishan_00002_.png"),
        ("TC-23", "TC-23.png", "10_tc23_esper_00002_.png"),
        ("Jonas", "Jonas.png", "11_jonas_ghost_00002_.png"),
    ]
    
    all_results = []
    
    for char_name, ref_file, gen_file in comparisons:
        print(f"\n{'='*80}")
        print(f"Character: {char_name}")
        print(f"{'='*80}")
        
        results = analyze_character_auto(char_name, ref_file, gen_file)
        
        # Missing image files produce an {"error": ...} dict; report and skip.
        if "error" in results:
            print(f"ERROR: {results['error']}")
            continue
        
        print(f"""
Reference: {ref_file}
Generated: {gen_file}

DETAILED METRICS:
-----------------
Perceptual Hash:   {results['perceptual_hash']:.1%}
Average Hash:      {results['average_hash']:.1%}
Difference Hash:   {results['difference_hash']:.1%}
Wavelet Hash:      {results['wavelet_hash']:.1%}
Color Histogram:   {results['color_histogram']:.1%}
Structural Sim:    {results['structural_similarity']:.1%}
Hair Color:        {results['hair_color']:.1%}
Skin Tone:         {results['skin_tone']:.1%}
Detail Level:      {results['detail_level']:.1%}

OVERALL SCORE: {results['overall_score']:.1%}
VERDICT: {results['verdict']}
""")
        
        all_results.append((char_name, results))
    
    # Summary
    print("\n" + "=" * 80)
    print("FINAL SUMMARY")
    print("=" * 80)
    
    # Tally verdicts by the exact verdict strings analyze_character_auto emits.
    passed = sum(1 for _, r in all_results if r.get("verdict") == "✅ PASS")
    partial = sum(1 for _, r in all_results if r.get("verdict") == "⚠️ PARTIAL")
    failed = sum(1 for _, r in all_results if r.get("verdict") == "❌ FAIL")
    total = len(all_results)
    
    print(f"\nTotal Characters: {total}")
    print(f"✅ PASS: {passed}")
    print(f"⚠️  PARTIAL: {partial}")
    print(f"❌ FAIL: {failed}")
    
    # List failing characters
    if failed > 0 or partial > 0:
        print("\n" + "=" * 80)
        print("NEEDS REGENERATION")
        print("=" * 80)
        for char_name, results in all_results:
            if results.get("verdict") in ["⚠️ PARTIAL", "❌ FAIL"]:
                print(f"  • {char_name}: {results['overall_score']:.1%} - {results['verdict']}")
                print(f"    Score breakdown:")
                # Render each metric as a 10-segment bar (filled blocks for
                # the score, light blocks for the remainder).
                for key, val in results.items():
                    if key not in ["verdict", "overall_score"]:
                        bar = "█" * int(val * 10) + "░" * (10 - int(val * 10))
                        print(f"      {key:20s} {bar} {val:.0%}")
    
    # Save results
    report_path = "/mnt/c/Users/fbmor/broken-spire-comparison/auto_analysis_results.txt"
    with open(report_path, "w") as f:
        f.write("BROKEN SPIRE - AUTOMATED CONSISTENCY ANALYSIS\n")
        f.write("=" * 60 + "\n\n")
        for char_name, results in all_results:
            f.write(f"{char_name}: {results.get('verdict', 'N/A')} ({results.get('overall_score', 0):.1%})\n")
        f.write(f"\nTotal: {passed} passed, {partial} partial, {failed} failed\n")
    
    print(f"\nReport saved to: {report_path}")
    
    return all_results

# Script entry point: run the full comparison suite when executed directly.
if __name__ == "__main__":
    run_full_analysis()