import logging
from typing import Optional

import torch
import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class BOFTAdapter(WeightAdapterBase):
    """Weight adapter for BOFT (Butterfly Orthogonal Fine-Tuning).

    BOFT stores a 4-D ``oft_blocks`` tensor of shape
    ``(boft_m, block_num, boft_b, boft_b)``.  Each of the ``boft_m`` butterfly
    stages converts its blocks into orthogonal rotations via the Cayley
    transform ``R = (I + Q)(I - Q)^-1`` (with ``Q`` skew-symmetric) and applies
    them block-diagonally, interleaved with a butterfly permutation of the
    rows/channels between stages.
    """

    # Adapter type identifier; selects the ``.oft_blocks`` key family in load().
    name = "boft"

    def __init__(self, loaded_keys: set[str], weights: tuple) -> None:
        # loaded_keys: state-dict key names this adapter consumed.
        # weights: (blocks, rescale, alpha, dora_scale) as packed by load().
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["BOFTAdapter"]:
        """Build a BOFTAdapter for key prefix ``x`` from the ``lora`` state dict.

        Returns ``None`` unless a 4-D ``{x}.oft_blocks`` tensor is present
        (non-4-D blocks are treated as not-BOFT and rejected).  Every key that
        is consumed is added to ``loaded_keys`` (created here if not given).
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            # Only the 4-D butterfly layout belongs to BOFT; anything else is
            # left for another adapter to claim.
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        # Optional per-output rescale tensor (applied after the rotations).
        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the BOFT butterfly rotation to ``weight`` and return it.

        ``strength`` linearly interpolates each rotation toward the identity;
        with a ``dora_scale`` present the diff is merged via
        ``weight_decompose`` instead of a plain add.  On any exception the
        error is logged and ``weight`` is returned as-is (best-effort).
        """
        # Unpack (blocks, rescale, alpha, dora_scale) stored by load().
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        # Move adapter tensors to the weight's device, computing in the
        # (typically float32) intermediate dtype for numerical stability.
        blocks = comfy.model_management.cast_to_device(
            blocks, weight.device, intermediate_dtype
        )
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(
                rescale, weight.device, intermediate_dtype
            )

        # boft_m: number of butterfly stages; boft_b: rotation block size.
        boft_m, block_num, boft_b, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0:  # alpha in boft/bboft is for constraint
                # Clamp ||Q|| to alpha so the rotation stays near identity.
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            # Cayley transform: R = (I + Q)(I - Q)^-1, guaranteed orthogonal.
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            # Keep the untouched weight so the diff can be formed afterwards.
            inp = org = weight

            # Half-block stride used by the stage-i butterfly permutation.
            r_b = boft_b // 2
            for i in range(boft_m):
                bi = r[i]
                g = 2
                k = 2**i * r_b
                if strength != 1:
                    # Blend rotation with identity so strength interpolates
                    # smoothly toward a no-op.
                    bi = bi * strength + (1 - strength) * I
                # Regroup rows into boft_b-sized blocks according to the
                # stage-i butterfly pattern (pairs separated by stride k).
                inp = (
                    inp.unflatten(0, (-1, g, k))
                    .transpose(1, 2)
                    .flatten(0, 2)
                    .unflatten(0, (-1, boft_b))
                )
                # Apply the block-diagonal orthogonal rotation per block.
                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                # Undo the grouping so the next stage sees the original layout.
                inp = (
                    inp.flatten(0, 1)
                    .unflatten(0, (-1, k, g))
                    .transpose(1, 2)
                    .flatten(0, 2)
                )

            # NOTE(review): rescale presumably broadcasts over the output
            # dim, e.g. shape (out_features, 1) — confirm against trainers.
            if rescale is not None:
                inp = inp * rescale

            # Difference between the rotated and the original weight.
            lora_diff = inp - org
            lora_diff = comfy.model_management.cast_to_device(
                lora_diff, weight.device, intermediate_dtype
            )
            if dora_scale is not None:
                # DoRA path: merge via magnitude/direction decomposition.
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                # Plain path: scaled in-place add in the weight's own dtype.
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def _get_orthogonal_matrices(self, device, dtype) -> tuple[torch.Tensor, int, int]:
        """Compute the orthogonal rotation matrices R from BOFT blocks.

        Returns ``(r, boft_m, boft_b)`` where ``r`` holds one stack of
        block rotations per butterfly stage.  A missing alpha is treated
        as 0 (no norm constraint).
        """
        v = self.weights
        blocks = v[0].to(device=device, dtype=dtype)
        alpha = v[2]
        if alpha is None:
            alpha = 0

        boft_m, block_num, boft_b, _ = blocks.shape
        I = torch.eye(boft_b, device=device, dtype=dtype)

        # Q = blocks - blocks^T (skew-symmetric)
        q = blocks - blocks.transpose(-1, -2)
        normed_q = q

        # Apply constraint if alpha > 0
        if alpha > 0:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > alpha:
                normed_q = q * alpha / q_norm

        # Cayley transform: R = (I + Q)(I - Q)^-1
        # NOTE(review): .float() for inverse() can leave r promoted to
        # float32 even when dtype is half — confirm callers tolerate this.
        r = (I + normed_q) @ (I - normed_q).float().inverse()
        return r, boft_m, boft_b

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation for BOFT: applies butterfly orthogonal transform.

        BOFT uses multiple stages of butterfly-structured orthogonal transforms,
        operating on the channel (last) dimension of the activations ``y``.

        Reference: LyCORIS ButterflyOFTModule._bypass_forward
        """
        v = self.weights
        rescale = v[1]

        r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
        # Half-block stride for the butterfly permutation, as in
        # calculate_weight but applied along the last (channel) axis.
        r_b = boft_b // 2

        # Apply multiplier
        multiplier = getattr(self, "multiplier", 1.0)
        I = torch.eye(boft_b, device=y.device, dtype=y.dtype)

        # Use module info from bypass injection to determine conv vs linear
        is_conv = getattr(self, "is_conv", y.dim() > 2)

        if is_conv:
            # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
            y = y.transpose(1, -1)

        # Apply butterfly transform stages
        inp = y
        for i in range(boft_m):
            bi = r[i]  # (block_num, boft_b, boft_b)
            g = 2
            k = 2**i * r_b

            # Interpolate with identity based on multiplier
            if multiplier != 1:
                bi = bi * multiplier + (1 - multiplier) * I

            # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
            inp = (
                inp.unflatten(-1, (-1, g, k))
                .transpose(-2, -1)
                .flatten(-3)
                .unflatten(-1, (-1, boft_b))
            )
            # Apply block-diagonal orthogonal transform
            inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
            # Reshape back
            inp = (
                inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
            )

        # Apply rescale if present
        # NOTE(review): transpose(0, -1) suggests rescale is stored
        # output-dim-first (e.g. (out, 1)) and must be flipped for the
        # channel-last layout used here — confirm against calculate_weight.
        if rescale is not None:
            rescale = rescale.to(device=y.device, dtype=y.dtype)
            inp = inp * rescale.transpose(0, -1)

        if is_conv:
            # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
            inp = inp.transpose(1, -1)

        return inp
