import logging
from functools import cache
from typing import Optional

import torch
import torch.nn.functional as F
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose


@cache
def _warn_loha_bypass_inefficient():
    """Log the LoHa bypass-mode performance warning.

    The function takes no arguments and is wrapped in ``functools.cache``,
    so the body (and therefore the log call) runs only on the first
    invocation; later calls return the memoized ``None``.
    """
    message = (
        "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
        "Consider using LoRA or LoKr for training with bypass mode."
    )
    logging.warning(message)


class HadaWeight(torch.autograd.Function):
    """Autograd function for ``((w1u @ w1d) * (w2u @ w2d)) * scale``.

    The backward pass is written by hand and recomputes each low-rank
    product from the saved factors.
    """

    @staticmethod
    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
        """Return the scaled element-wise product of the two matrix products."""
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        first_product = w1u @ w1d
        second_product = w2u @ w2d
        return (first_product * second_product) * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for (w1u, w1d, w2u, w2d) and ``None`` for scale."""
        w1d, w1u, w2d, w2u, scale = ctx.saved_tensors
        scaled_grad = grad_out * scale

        # Gradient flowing into (w1u @ w1d): the upstream gradient weighted
        # by the other Hadamard operand.
        into_first = scaled_grad * (w2u @ w2d)
        grad_w1u = into_first @ w1d.T
        grad_w1d = w1u.T @ into_first
        del into_first

        # Same for (w2u @ w2d), weighted by the first operand.
        into_second = scaled_grad * (w1u @ w1d)
        grad_w2u = into_second @ w2d.T
        grad_w2d = w2u.T @ into_second
        del into_second

        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None


class HadaWeightTucker(torch.autograd.Function):
    """Autograd function for the Tucker-style Hadamard weight difference.

    Each branch rebuilds ``sum_{i,j} t[i, j, ...] * wd[j, r] * wu[i, p]``
    (result indexed ``[p, r, ...]``); the output is the element-wise product
    of both rebuilt tensors times ``scale``. The backward pass is written by
    hand and recomputes intermediates from the saved factors.
    """

    @staticmethod
    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
        """Return ``rebuild(t1, w1d, w1u) * rebuild(t2, w2d, w2u) * scale``."""
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients in forward-argument order; ``None`` for scale."""
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Per-branch intermediates t x wd, indexed [i, r, ...]. Both are
        # needed twice: to rebuild that branch's full tensor (which weights
        # the *other* branch's gradient) and to form that branch's own wu
        # gradient.
        temp1 = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        temp2 = torch.einsum("i j ..., j r -> i r ...", t2, w2d)

        # ---- branch 1: gradient w.r.t. rebuild1 is grad_out * rebuild2 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp2, w2u)
        grad_w = rebuild * grad_out
        del rebuild

        # The wu gradient contracts with this branch's own intermediate.
        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        # ---- branch 2: gradient w.r.t. rebuild2 is grad_out * rebuild1 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp1, w1u)
        grad_w = rebuild * grad_out
        del rebuild

        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp2, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp1, temp2

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp
        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None


class LohaDiff(WeightAdapterTrainBase):
    """Trainable LoHa adapter that adds a Hadamard-product diff to a weight.

    The factors are registered as ``torch.nn.Parameter`` objects; ``alpha``
    is stored as a parameter with ``requires_grad=False``.
    """

    def __init__(self, weights):
        """Build trainable parameters from a LoHa weight tuple.

        Args:
            weights: 8-tuple ``(w1a, w1b, alpha, w2a, w2b, t1, t2, _)``; the
                last element is ignored here. ``t1``/``t2`` may be ``None``
                (no Tucker factors). ``alpha`` may be ``None``, in which case
                the effective scale is 1.0.
        """
        super().__init__()
        # Unpack weights tuple from LoHaAdapter
        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights

        # Create trainable parameters
        self.hada_w1_a = torch.nn.Parameter(w1a)
        self.hada_w1_b = torch.nn.Parameter(w1b)
        self.hada_w2_a = torch.nn.Parameter(w2a)
        self.hada_w2_b = torch.nn.Parameter(w2b)

        self.use_tucker = t1 is not None and t2 is not None
        if self.use_tucker:
            self.hada_t1 = torch.nn.Parameter(t1)
            self.hada_t2 = torch.nn.Parameter(t2)
        else:
            # Keep the attributes for consistent access
            self.hada_t1 = None
            self.hada_t2 = None

        # Store rank and non-trainable alpha
        self.rank = w1b.shape[0]
        if alpha is None:
            # A missing alpha means "scale 1.0" in LoHaAdapter.calculate_weight
            # and LoHaAdapter.h; alpha == rank yields the same alpha / rank.
            alpha = float(self.rank)
        if isinstance(alpha, torch.Tensor):
            # Copy instead of re-wrapping with torch.tensor(), which warns.
            alpha_value = alpha.detach().clone()
        else:
            alpha_value = torch.tensor(alpha)
        self.alpha = torch.nn.Parameter(alpha_value, requires_grad=False)

    def __call__(self, w):
        """Return ``w`` plus the scaled LoHa diff, in ``w``'s original dtype.

        Args:
            w: Weight tensor to patch; the diff is reshaped to ``w.shape``.
        """
        org_dtype = w.dtype

        scale = self.alpha / self.rank
        if self.use_tucker:
            diff_weight = HadaWeightTucker.apply(
                self.hada_t1,
                self.hada_w1_a,
                self.hada_w1_b,
                self.hada_t2,
                self.hada_w2_a,
                self.hada_w2_b,
                scale,
            )
        else:
            diff_weight = HadaWeight.apply(
                self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
            )

        # Add the scaled difference to the original weight; the sum is done
        # in the diff's dtype/device and cast back afterwards.
        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)

        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Return the total size in bytes of everything in ``self.parameters()``."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoHaAdapter(WeightAdapterBase):
    name = "loha"

    def __init__(self, loaded_keys, weights):
        """Keep references to the consumed key names and the weight tuple.

        Args:
            loaded_keys: Key names this adapter has consumed (``load`` passes
                a set of strings).
            weights: Tuple laid out as
                ``(w1a, w1b, alpha, w2a, w2b, t1, t2, dora_scale)``.
        """
        self.weights = weights
        self.loaded_keys = loaded_keys

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        """Create a freshly initialised trainable LoHa adapter for ``weight``.

        Args:
            weight: Tensor to adapt; dim 0 is used as the output dimension and
                all remaining dims are flattened into the input dimension.
            rank: Inner dimension of the low-rank factors.
            alpha: Scaling numerator stored on the returned adapter.

        Returns:
            A ``LohaDiff`` without Tucker factors. Because the second factor
            of the first branch is zero-initialised, the initial weight diff
            is zero.
        """
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        factory_kwargs = {"device": weight.device, "dtype": torch.float32}

        mat1 = torch.empty(out_dim, rank, **factory_kwargs)
        mat2 = torch.empty(rank, in_dim, **factory_kwargs)
        # NOTE: the second positional argument of normal_ is the *mean*;
        # the spread must be passed as ``std=``.
        torch.nn.init.normal_(mat1, std=0.1)
        torch.nn.init.constant_(mat2, 0.0)

        mat3 = torch.empty(out_dim, rank, **factory_kwargs)
        mat4 = torch.empty(rank, in_dim, **factory_kwargs)
        torch.nn.init.normal_(mat3, std=0.1)
        torch.nn.init.normal_(mat4, std=0.01)
        return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))

    def to_train(self):
        """Wrap this adapter's weight tuple in a trainable ``LohaDiff``."""
        trainable = LohaDiff(self.weights)
        return trainable

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoHaAdapter"]:
        """Build a LoHa adapter from the entries of ``lora`` prefixed by ``x``.

        Args:
            x: Key prefix; factors are looked up as ``"<x>.hada_w1_a"`` etc.
            lora: Mapping from key name to tensor.
            alpha: Stored unchanged in the weight tuple.
            dora_scale: Stored unchanged in the weight tuple.
            loaded_keys: Set that receives every key consumed here; a new
                set is created when omitted.

        Returns:
            A new adapter, or ``None`` when ``"<x>.hada_w1_a"`` is absent.
        """
        if loaded_keys is None:
            loaded_keys = set()

        w1_a_key = f"{x}.hada_w1_a"
        w1_b_key = f"{x}.hada_w1_b"
        w2_a_key = f"{x}.hada_w2_a"
        w2_b_key = f"{x}.hada_w2_b"
        t1_key = f"{x}.hada_t1"
        t2_key = f"{x}.hada_t2"

        # The first factor's presence decides whether this prefix is LoHa.
        if w1_a_key not in lora:
            return None

        # Optional Tucker factors; t2 is read whenever t1 is present.
        tucker1 = None
        tucker2 = None
        if t1_key in lora:
            tucker1 = lora[t1_key]
            tucker2 = lora[t2_key]
            for consumed in (t1_key, t2_key):
                loaded_keys.add(consumed)

        weights = (
            lora[w1_a_key],
            lora[w1_b_key],
            alpha,
            lora[w2_a_key],
            lora[w2_b_key],
            tucker1,
            tucker2,
            dora_scale,
        )
        for consumed in (w1_a_key, w1_b_key, w2_a_key, w2_b_key):
            loaded_keys.add(consumed)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this LoHa adapter to ``weight`` and return the result.

        Args:
            weight: Tensor to patch. Without a DoRA scale it is updated in
                place with ``+=``.
            key: Name used only in the error log message.
            strength: Multiplier applied to the diff.
            strength_model: Unused here.
            offset: Unused here.
            function: Callable applied to the scaled diff before it is added
                (or handed to ``weight_decompose`` in the DoRA path).
            intermediate_dtype: Dtype the factors are cast to for the products.
            original_weight: Unused here.

        Returns:
            The patched weight. If reshaping or applying the diff raises, the
            error is logged and ``weight`` is returned as it was.
        """
        v = self.weights
        w1a = v[0]
        w1b = v[1]
        # alpha / rank; a missing alpha means no extra scaling.
        if v[2] is not None:
            alpha = v[2] / w1b.shape[0]
        else:
            alpha = 1.0

        w2a = v[3]
        w2b = v[4]
        dora_scale = v[7]

        def _cast(tensor):
            # Move a factor to the weight's device in the intermediate dtype.
            return comfy.model_management.cast_to_device(
                tensor, weight.device, intermediate_dtype
            )

        if v[5] is not None:  # cp decomposition
            # Ellipsis keeps any number of trailing kernel dims, matching
            # HadaWeightTucker.forward; for 4-D factors this is identical to
            # "i j k l, j r, i p -> p r k l".
            tucker_eq = "i j ..., j r, i p -> p r ..."
            m1 = torch.einsum(tucker_eq, _cast(v[5]), _cast(w1b), _cast(w1a))
            m2 = torch.einsum(tucker_eq, _cast(v[6]), _cast(w2b), _cast(w2a))
        else:
            m1 = torch.mm(_cast(w1a), _cast(w1b))
            m2 = torch.mm(_cast(w2a), _cast(w2b))

        # Best-effort application: a failure is logged, not raised.
        try:
            lora_diff = (m1 * m2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoHa: h(x) = op(x, diff_weight)

        WARNING: Inefficient - computes full Hadamard product each forward.

        Note:
            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        Args:
            x: Input tensor; the factors are cast to ``x.dtype``.
            base_out: Output from base forward (unused, for API consistency)

        Returns:
            Result of ``F.linear`` or the matching ``F.convNd`` applied to
            ``x`` with the rebuilt diff weight.

        The optional attributes ``multiplier``, ``is_conv``, ``conv_dim``,
        ``kw_dict``, ``kernel_size`` and ``in_channels`` are read with
        ``getattr`` defaults.

        Reference: LyCORIS functional/loha.py bypass_forward_diff
        """
        _warn_loha_bypass_inefficient()

        # Indexed by conv_dim + 2: 2 -> linear, 3/4/5 -> conv1d/2d/3d.
        FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]

        v = self.weights
        # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
        w1a = v[0]
        w1b = v[1]
        alpha = v[2]
        w2a = v[3]
        w2b = v[4]
        t1 = v[5]
        t2 = v[6]

        # Compute scale
        rank = w1b.shape[0]
        scale = (alpha / rank if alpha is not None else 1.0) * getattr(
            self, "multiplier", 1.0
        )

        # Cast dtype
        w1a = w1a.to(dtype=x.dtype)
        w1b = w1b.to(dtype=x.dtype)
        w2a = w2a.to(dtype=x.dtype)
        w2b = w2b.to(dtype=x.dtype)

        # Use module info from bypass injection, not weight dimension
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        # Compute diff weight using Hadamard product
        if t1 is not None and t2 is not None:
            t1 = t1.to(dtype=x.dtype)
            t2 = t2.to(dtype=x.dtype)
            # Ellipsis keeps any number of trailing kernel dims so conv1d and
            # conv3d Tucker factors work too; identical to the
            # "i j k l" form for 4-D factors.
            tucker_eq = "i j ..., j r, i p -> p r ..."
            m1 = torch.einsum(tucker_eq, t1, w1b, w1a)
            m2 = torch.einsum(tucker_eq, t2, w2b, w2a)
        else:
            m1 = w1a @ w1b
            m2 = w2a @ w2b
        diff_weight = (m1 * m2) * scale

        if is_conv:
            op = FUNC_LIST[conv_dim + 2]
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Reshape 2D diff_weight to conv format using kernel_size
            # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
            if diff_weight.dim() == 2:
                if in_channels is not None:
                    diff_weight = diff_weight.view(
                        diff_weight.shape[0], in_channels, *kernel_size
                    )
                else:
                    # Without in_channels, only singleton kernel dims are
                    # appended.
                    diff_weight = diff_weight.view(
                        *diff_weight.shape, *([1] * conv_dim)
                    )
        else:
            op = F.linear
            # Extra op kwargs are dropped for the linear path.
            kw_dict = {}

        return op(x, diff_weight, **kw_dict)
