import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops

class WhisperFeatureExtractor(nn.Module):
    """Convert raw waveforms into normalised log-mel features.

    The waveform is averaged over dim 1, cropped or zero-padded to exactly
    ``n_samples`` samples, run through a slaney-scaled mel spectrogram and
    finally log-compressed and rescaled.
    """

    def __init__(self, n_mels=128, device=None):
        """Build the mel transform.

        Args:
            n_mels: Number of mel filterbank channels.
            device: Device the mel transform is moved to.
        """
        super().__init__()
        self.sample_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = n_mels
        self.chunk_length = 30
        self.n_samples = 480000  # equals chunk_length * sample_rate

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=0,
            f_max=8000,
            norm="slaney",
            mel_scale="slaney",
        ).to(device)

    def __call__(self, audio):
        """Compute log-mel features for a batch of waveforms.

        NOTE: this overrides ``nn.Module.__call__`` directly instead of
        defining ``forward``, so module hooks are not dispatched here.

        Args:
            audio: Waveform tensor whose dim 1 is averaged away and whose
                last dim holds samples (presumably
                ``(batch, channels, samples)`` - confirm with callers).

        Returns:
            Log-mel tensor on the same device as ``audio``, with the last
            spectrogram frame dropped.
        """
        audio = torch.mean(audio, dim=1)

        # Every row of a batch tensor has the same length, so one slice and
        # one pad fit the whole batch to n_samples.
        audio = audio[:, :self.n_samples]
        missing = self.n_samples - audio.shape[1]
        if missing > 0:
            audio = F.pad(audio, (0, missing))

        # Run the transform on whatever device its window buffer lives on,
        # then bring the result back to the caller's device.
        mel_device = self.mel_spectrogram.spectrogram.window.device
        mel_spec = self.mel_spectrogram(audio.to(mel_device))[:, :, :-1].to(audio.device)

        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Floor each sample 8 log10 units below its OWN peak. A batch-wide
        # max would make one clip's features depend on its batch neighbours.
        peak = log_mel_spec.amax(dim=(1, 2), keepdim=True)
        log_mel_spec = torch.maximum(log_mel_spec, peak - 8.0)
        log_mel_spec = (log_mel_spec + 4.0) / 4.0

        return log_mel_spec


class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from four ``operations.Linear`` layers.

    The query, value and output projections use the layer's default bias
    setting; the key projection is created with ``bias=False``. The attention
    itself is delegated to ``optimized_attention_masked``.
    """

    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
        """Create the projections.

        Args:
            d_model: Width of the input and output features; must be
                divisible by ``n_heads``.
            n_heads: Number of attention heads.
            dtype: Forwarded to every ``operations.Linear``.
            device: Forwarded to every ``operations.Linear``.
            operations: Namespace providing the ``Linear`` layer class.
        """
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        def square_linear(**overrides):
            # All four projections map d_model -> d_model.
            return operations.Linear(d_model, d_model, dtype=dtype, device=device, **overrides)

        self.q_proj = square_linear()
        self.k_proj = square_linear(bias=False)
        self.v_proj = square_linear()
        self.out_proj = square_linear()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project the inputs, attend, and project the result back.

        Args:
            query: Rank-3 tensor to project into queries.
            key: Tensor to project into keys.
            value: Tensor to project into values.
            mask: Optional mask handed to ``optimized_attention_masked``.

        Returns:
            Output of ``out_proj`` applied to the attention result.
        """
        # The values are unused; the unpack only enforces a rank-3 query.
        _batch, _length, _width = query.shape

        attended = optimized_attention_masked(
            self.q_proj(query),
            self.k_proj(key),
            self.v_proj(value),
            self.n_heads,
            mask,
        )
        return self.out_proj(attended)


class EncoderLayer(nn.Module):
    """Transformer encoder layer with normalisation applied before each sub-block.

    Two residual stages run in sequence: layer-normed self-attention, then a
    layer-normed two-layer feed-forward network with a GELU in between.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
        """Create the attention, feed-forward and normalisation sub-modules.

        Args:
            d_model: Feature width of the layer's input and output.
            n_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward network.
            dtype: Forwarded to every sub-module.
            device: Forwarded to every sub-module.
            operations: Namespace providing ``Linear`` and ``LayerNorm``.
        """
        super().__init__()
        factory = dict(dtype=dtype, device=device)

        self.self_attn = MultiHeadAttention(d_model, n_heads, operations=operations, **factory)
        self.self_attn_layer_norm = operations.LayerNorm(d_model, **factory)

        self.fc1 = operations.Linear(d_model, d_ff, **factory)
        self.fc2 = operations.Linear(d_ff, d_model, **factory)
        self.final_layer_norm = operations.LayerNorm(d_model, **factory)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the attention stage followed by the feed-forward stage.

        Args:
            x: Input hidden states.
            attention_mask: Optional mask passed through to the attention.

        Returns:
            Hidden states after both residual stages.
        """
        # Stage 1: self-attention over the normalised input, added back to x.
        normed = self.self_attn_layer_norm(x)
        x = x + self.self_attn(normed, normed, normed, attention_mask)

        # Stage 2: feed-forward over the normalised stage-1 output, added back.
        hidden = self.fc1(self.final_layer_norm(x))
        x = x + self.fc2(F.gelu(hidden))

        return x


class AudioEncoder(nn.Module):
    """Convolutional front end plus a stack of ``EncoderLayer`` blocks.

    Two 1-D convolutions (the second with stride 2) lift the mel features to
    ``n_state`` channels, learned positional embeddings are added, and the
    result is passed through ``n_layer`` encoder layers and a final layer norm.
    """

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        """Create the convolutions, positional table, layers and final norm.

        Args:
            n_mels: Number of input channels of the first convolution.
            n_ctx: Number of rows in the positional embedding table.
            n_state: Hidden width used throughout the encoder.
            n_head: Attention heads per encoder layer.
            n_layer: Number of encoder layers.
            dtype: Forwarded to every sub-module.
            device: Forwarded to every sub-module.
            operations: Namespace providing ``Conv1d``, ``Embedding``,
                ``Linear`` and ``LayerNorm``.
        """
        super().__init__()

        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)

        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)

        self.layers = nn.ModuleList([
            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
            for _ in range(n_layer)
        ])

        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, tuple[torch.Tensor, ...]]":
        """Encode mel features.

        Args:
            x: Mel features with channels on dim 1 and time on dim 2.

        Returns:
            A pair ``(x, all_x)``. ``x`` is the layer-normed output of the
            last layer. ``all_x`` is a tuple holding the input to every
            encoder layer, in order, followed by that final ``x``.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        # Move channels last so the sequence axis is dim 1.
        x = x.transpose(1, 2)

        # The positional table is (n_ctx, n_state), assuming
        # operations.Embedding follows torch.nn.Embedding's layout. Slice its
        # rows to the sequence length; slicing columns instead only works when
        # the sequence happens to fill the whole table.
        positions = self.embed_positions.weight[:x.shape[1]]
        x = x + comfy.ops.cast_to_input(positions, x)

        hidden_states = []
        for layer in self.layers:
            hidden_states.append(x)
            x = layer(x)

        x = self.layer_norm(x)
        hidden_states.append(x)
        return x, tuple(hidden_states)


class WhisperLargeV3(nn.Module):
    """Waveform-to-hidden-states model: feature extractor followed by encoder."""

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        """Build the feature extractor and the audio encoder.

        Args:
            n_mels: Mel channels, shared by the extractor and the encoder.
            n_audio_ctx: Positional table size of the encoder.
            n_audio_state: Hidden width of the encoder.
            n_audio_head: Attention heads per encoder layer.
            n_audio_layer: Number of encoder layers.
            dtype: Forwarded to the encoder.
            device: Forwarded to both sub-modules.
            operations: Layer namespace forwarded to the encoder.
        """
        super().__init__()

        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)

        self.encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=n_audio_ctx,
            n_state=n_audio_state,
            n_head=n_audio_head,
            n_layer=n_audio_layer,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, audio):
        """Extract features from ``audio`` and encode them.

        Args:
            audio: Raw waveform tensor accepted by the feature extractor.

        Returns:
            The encoder's ``(x, all_x)`` pair, unchanged.
        """
        features = self.feature_extractor(audio)
        final_hidden, all_hidden = self.encoder(features)
        return final_hidden, all_hidden
