"""
E2E tests for Queue-specific Preview Method Override feature.

Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.

Usage:
    COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method

Note:
    These tests execute actual image generation and wait for completion.
    Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path


# Server configuration. A trailing "/" on COMFYUI_SERVER is stripped so that
# URLs built as f"{SERVER_URL}/..." never contain "//".
# NOTE: the client always talks plain http:// and ws://, so an https://
# COMFYUI_SERVER only has its scheme dropped; TLS is not supported.
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988").rstrip("/")
# host[:port] with any leading "<scheme>://" removed (left as-is if no scheme).
SERVER_HOST = SERVER_URL.split("://", 1)[-1]

# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"


def is_server_running() -> bool:
    """Probe ``GET {SERVER_URL}/system_stats`` with a 2 second timeout.

    Returns:
        True if the request succeeds; False if anything at all goes wrong
        (connection refused, timeout, HTTP error status, malformed URL, ...).
    """
    try:
        probe = urllib.request.Request(f"{SERVER_URL}/system_stats")
        with urllib.request.urlopen(probe, timeout=2.0):
            pass
    except Exception:
        # Best-effort probe: any failure simply means "not reachable".
        return False
    return True


def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
    """Return a copy of *graph* adjusted for a fast, uncached test run.

    In every node's ``inputs``:

    * ``seed`` and ``noise_seed`` (the latter used by KSamplerAdvanced) get
      fresh random values in ``[0, 2**32 - 1]``, so repeated submissions of
      the same graph are not identical.
    * ``steps`` is overwritten with *steps* to shorten sampling.

    Nodes without an ``inputs`` mapping (missing, empty or ``null``) are
    copied unchanged.

    Args:
        graph: API-format prompt mapping node id -> node dict. Not modified.
        steps: Step count written into every node that has a ``steps`` input.

    Returns:
        A deep copy of *graph* with the adjustments applied.
    """
    # JSON round-trip: a deep copy that also guarantees the result stays
    # JSON-serialisable, as it must be to be POSTed to the server.
    adapted = json.loads(json.dumps(graph))
    for node in adapted.values():
        inputs = node.get("inputs")
        if not inputs:
            continue
        for seed_key in ("seed", "noise_seed"):
            if seed_key in inputs:
                inputs[seed_key] = random.randint(0, 2**32 - 1)
        # Reduce steps for faster testing (e.g. default 20 -> 5)
        if "steps" in inputs:
            inputs["steps"] = steps
    return adapted


# Alias kept so call sites using the older name keep working.
randomize_seed = prepare_graph_for_test


class PreviewMethodClient:
    """Client for testing preview_method with WebSocket execution tracking.

    Prompts are queued over HTTP (``POST /prompt``) and followed on the
    ``/ws`` WebSocket. The same random ``client_id`` is sent in the prompt
    payload and in the WebSocket URL.
    """

    def __init__(self, server_address: str):
        """Remember ``server_address`` (``host[:port]``, no scheme).

        Call :meth:`connect` before :meth:`wait_for_execution`.
        """
        self.server_address = server_address
        self.client_id = str(uuid.uuid4())
        self.ws = None

    def connect(self):
        """Open the WebSocket for this client id."""
        self.ws = websocket.WebSocket()
        self.ws.settimeout(120)  # 2 minute timeout for sampling
        self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")

    def close(self):
        """Close the WebSocket if it is open; safe to call more than once."""
        if self.ws is not None:
            self.ws.close()
            self.ws = None

    def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
        """Queue a prompt and return the server's decoded JSON response.

        Args:
            prompt: API-format graph to execute.
            extra_data: Optional ``extra_data`` payload (e.g. ``preview_method``);
                an empty dict is sent when omitted.

        Returns:
            The JSON reply; the tests read ``prompt_id`` from it.
        """
        data = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": extra_data or {}
        }
        req = urllib.request.Request(
            f"http://{self.server_address}/prompt",
            data=json.dumps(data).encode("utf-8"),
            headers={"Content-Type": "application/json"}
        )
        # Close the HTTP response instead of leaving the socket to the GC.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())

    def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
        """
        Wait for execution of *prompt_id* to finish via the WebSocket.

        *timeout* is an overall deadline in seconds. Each ``recv`` is given only
        the time that is left, so a stream of unrelated messages cannot extend
        the wait.

        Text frames are parsed as JSON. Only frames whose ``data.prompt_id``
        equals *prompt_id* are considered:

        * ``executing`` with ``node`` None marks completion.
        * ``execution_error`` marks failure; its ``data`` dict is stored in
          ``error``.

        Every binary frame is counted as a preview image, whichever prompt it
        belongs to.

        Returns:
            dict with keys: completed, error, preview_count, execution_time.
            On timeout, ``error`` is the string "Timeout waiting for execution".
        """
        result = {
            "completed": False,
            "error": None,
            "preview_count": 0,
            "execution_time": 0.0
        }

        start_time = time.monotonic()
        deadline = start_time + timeout

        try:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    # Routed through the same handler as a recv() timeout.
                    raise websocket.WebSocketTimeoutException("Overall timeout exceeded")
                self.ws.settimeout(remaining)

                out = self.ws.recv()
                elapsed = time.monotonic() - start_time

                if isinstance(out, bytes):
                    # Binary data = preview image
                    result["preview_count"] += 1
                    continue
                if not isinstance(out, str):
                    continue

                message = json.loads(out)
                msg_type = message.get("type")
                # Guard against a missing, null or non-object "data" field.
                data = message.get("data")
                if not isinstance(data, dict):
                    continue

                if data.get("prompt_id") != prompt_id:
                    continue

                if msg_type == "executing" and data.get("node") is None:
                    # Execution complete
                    result["completed"] = True
                    result["execution_time"] = elapsed
                    break

                if msg_type == "execution_error":
                    result["error"] = data
                    result["execution_time"] = elapsed
                    break

                # Anything else (e.g. "progress" during sampling) is ignored.

        except websocket.WebSocketTimeoutException:
            result["error"] = "Timeout waiting for execution"
            result["execution_time"] = time.monotonic() - start_time

        return result


def load_graph() -> dict:
    """Load the SDXL graph fixture from GRAPH_FILE, prepared for a test run.

    The loaded graph goes through ``prepare_graph_for_test``, which gives every
    seed a fresh random value (so repeated submissions are not identical) and
    reduces ``steps``.
    """
    # Explicit encoding: JSON is UTF-8; the locale default may not be.
    with open(GRAPH_FILE, encoding="utf-8") as f:
        graph = json.load(f)
    return prepare_graph_for_test(graph)  # Avoid caching


# Module-wide markers: every test is skipped when the server cannot be
# reached, and all tests carry the preview_method / execution markers so they
# can be selected with ``-m``.
_server_unavailable = not is_server_running()
pytestmark = [
    pytest.mark.skipif(
        _server_unavailable,
        reason=f"ComfyUI server not running at {SERVER_URL}",
    ),
    pytest.mark.preview_method,
    pytest.mark.execution,
]


@pytest.fixture
def client():
    """Yield a PreviewMethodClient bound to SERVER_HOST with its WebSocket open.

    The WebSocket is closed in the fixture's teardown.
    """
    ws_client = PreviewMethodClient(SERVER_HOST)
    ws_client.connect()
    yield ws_client
    ws_client.close()


@pytest.fixture
def graph():
    """Provide the SDXL graph fixture as returned by ``load_graph``."""
    prepared = load_graph()
    return prepared


class TestPreviewMethodExecution:
    """Test actual execution with different preview methods.

    Each test submits the ``graph`` fixture with a given ``extra_data`` payload
    and waits for the run to finish. A server-reported ``execution_error``
    (e.g. a missing model) is tolerated; a timeout is not.
    """

    @staticmethod
    def _execute(client, graph, extra_data):
        """Queue *graph* with *extra_data*, wait for it, and return the result.

        Asserts that the server returned a ``prompt_id`` and that the run either
        completed or ended in an ``execution_error``.
        """
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])
        # wait_for_execution stores the execution_error payload (a dict) on a
        # server-side failure, but a plain string when it times out. A bare
        # ``error is not None`` check would therefore also accept a timeout;
        # requiring a dict keeps a hung run from passing silently.
        assert result["completed"] or isinstance(result["error"], dict), \
            f"Execution neither completed nor reported an error: {result['error']}"
        return result

    def test_execution_with_latent2rgb(self, client, graph):
        """
        Execute with preview_method=latent2rgb.
        Should complete and potentially receive preview images.
        """
        result = self._execute(client, graph, {"preview_method": "latent2rgb"})

        if result["completed"]:
            # Execution should take some time (sampling)
            assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
            # The preview count is only reported here; a non-zero count is
            # asserted in TestPreviewMethodComparison.
            print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_taesd(self, client, graph):
        """
        Execute with preview_method=taesd.
        TAESD provides higher quality previews.
        """
        result = self._execute(client, graph, {"preview_method": "taesd"})

        if result["completed"]:
            assert result["execution_time"] > 0.5
            # Preview count is reported, not asserted.
            print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_none_preview(self, client, graph):
        """
        Execute with preview_method=none.
        No preview images should be generated.
        """
        result = self._execute(client, graph, {"preview_method": "none"})

        if result["completed"]:
            # With "none", should receive no preview images
            assert result["preview_count"] == 0, \
                f"Expected no previews with 'none', got {result['preview_count']}"
            print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_default(self, client, graph):
        """
        Execute with preview_method=default.
        Should use server's CLI default setting.
        """
        result = self._execute(client, graph, {"preview_method": "default"})

        if result["completed"]:
            print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_without_preview_method(self, client, graph):
        """
        Execute without preview_method in extra_data.
        Should use server's default preview method.
        """
        result = self._execute(client, graph, {})  # No preview_method

        if result["completed"]:
            print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201


class TestPreviewMethodComparison:
    """Contrast the number of preview frames produced by different methods."""

    def test_none_vs_latent2rgb_preview_count(self, client, graph):
        """
        Compare preview counts: 'none' should have 0, others should have >0.
        This is the key verification that preview_method actually works.
        """
        results = {}

        # 'none' runs first, then 'latent2rgb'; each gets freshly randomized
        # seeds so the second run is not served from cache.
        for method in ("none", "latent2rgb"):
            run_graph = randomize_seed(graph)
            response = client.queue_prompt(run_graph, {"preview_method": method})
            results[method] = client.wait_for_execution(response["prompt_id"])

        none_run = results["none"]
        rgb_run = results["latent2rgb"]

        # Both runs must finish successfully for the comparison to mean anything.
        assert none_run["completed"], f"'none' execution failed: {none_run['error']}"
        assert rgb_run["completed"], f"'latent2rgb' execution failed: {rgb_run['error']}"

        # Disabled previews: not a single binary preview frame.
        assert none_run["preview_count"] == 0, \
            f"'none' should have 0 previews, got {none_run['preview_count']}"

        # Enabled previews: at least one frame (exact number depends on steps).
        assert rgb_run["preview_count"] > 0, \
            f"'latent2rgb' should have >0 previews, got {rgb_run['preview_count']}"

        print("\nPreview count comparison:")  # noqa: T201
        for method in ("none", "latent2rgb"):
            print(f"  {method}: {results[method]['preview_count']} previews")  # noqa: T201


class TestPreviewMethodSequential:
    """Test sequential execution with different preview methods."""

    def test_sequential_different_methods(self, client, graph):
        """
        Execute multiple prompts sequentially with different preview methods.
        Each should complete independently with correct preview behavior.
        """
        methods = ["latent2rgb", "none", "default"]
        results = []

        for method in methods:
            # Randomize seed for each execution to avoid caching
            graph_run = randomize_seed(graph)
            extra_data = {"preview_method": method}
            response = client.queue_prompt(graph_run, extra_data)
            assert "prompt_id" in response, f"Method {method}: no prompt_id in {response}"

            result = client.wait_for_execution(response["prompt_id"])
            results.append({
                "method": method,
                "completed": result["completed"],
                "preview_count": result["preview_count"],
                "execution_time": result["execution_time"],
                "error": result["error"]
            })

        # All should complete or have clear errors. An execution_error payload
        # is a dict, while a timeout leaves a plain string in "error", so
        # ``error is not None`` alone would also let a timed-out run pass.
        for r in results:
            assert r["completed"] or isinstance(r["error"], dict), \
                f"Method {r['method']} neither completed nor errored: {r['error']}"

        # "none" should have zero previews if completed
        none_result = next(r for r in results if r["method"] == "none")
        if none_result["completed"]:
            assert none_result["preview_count"] == 0, \
                f"'none' should have 0 previews, got {none_result['preview_count']}"

        print("\nSequential execution results:")  # noqa: T201
        for r in results:
            status = "✓" if r["completed"] else f"✗ ({r['error']})"
            print(f"  {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s")  # noqa: T201
