# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from tqdm import tqdm

from typing import Optional

import logging

import comfy.ops
ops = comfy.ops.disable_weight_init

def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
    """Farthest point sampling over a flattened batch of point clouds.

    Args:
        src: points of all clouds stacked along dim 0, shape (total_points, D).
        batch: one cloud index per row of ``src``. Each cloud is sliced out of
            ``src`` through cumulative per-cloud counts, so rows of the same
            cloud must be contiguous and clouds must appear in ascending order.
        sampling_ratio: fraction of every cloud to keep; at least one point is
            selected from each non-empty cloud.
        start_random: pick the first point of each cloud at random when True,
            otherwise start from the first point of the cloud.

    Returns:
        1-D long tensor on ``src.device`` holding indices into ``src`` of the
        selected points, cloud after cloud.
    """

    # manually create the pointer vector
    assert src.size(0) == batch.numel()

    batch_size = int(batch.max()) + 1
    deg = src.new_zeros(batch_size, dtype = torch.long)

    deg.scatter_add_(0, batch, torch.ones_like(batch))

    ptr_vec = deg.new_zeros(batch_size + 1)
    ptr_vec[1:] = torch.cumsum(deg, 0)

    # read every offset at once instead of one .item() call per cloud
    offsets = ptr_vec.tolist()

    sampled_indicies = []

    for b in range(batch_size):
        # start and the end of each batch
        start, end = offsets[b], offsets[b + 1]
        # points from the point cloud
        points = src[start:end]

        num_points = points.size(0)
        if num_points == 0:
            # a cloud index that never occurs in `batch` has nothing to sample
            continue

        num_samples = max(1, math.ceil(num_points * sampling_ratio))

        selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
        distances = torch.full((num_points,), float("inf"), device = src.device)

        # select a random start point
        if start_random:
            farthest = torch.randint(0, num_points, (1,), device = src.device)
        else:
            farthest = torch.tensor([0], device = src.device, dtype = torch.long)

        for i in range(num_samples):
            selected[i] = farthest
            centroid = points[farthest].squeeze(0)
            dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
            # distance of every point to its closest already-selected point
            distances = torch.minimum(distances, dist)
            farthest = torch.argmax(distances)

        # Shift local indices to global ones directly on src.device; indexing a
        # CPU arange with indices living on another device raises an error.
        sampled_indicies.append(selected + start)

    return torch.cat(sampled_indicies, dim = 0)
class PointCrossAttention(nn.Module):
    """Encode a point cloud into latent tokens via cross attention.

    A random subset of the input points (the "data") and a farthest-point
    sampled subset of that subset (the "query") are Fourier-embedded, joined
    with their per-point features, projected to ``width`` channels and passed
    through one cross-attention block followed by optional self-attention
    layers and an optional final LayerNorm.
    """

    def __init__(self,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        fourier_embedder,
        normal_pe: bool = False,
        qkv_bias: bool = False,
        use_ln_post: bool = True,
        qk_norm: bool = True):
        """
        Args:
            num_latents: total number of query (latent) points to produce.
            downsample_ratio: the randomly subsampled input holds
                ``downsample_ratio`` times as many points as the query.
            pc_size: number of uniformly sampled surface points in the input.
            pc_sharpedge_size: number of sharp-edge points in the input.
            point_feats: number of extra feature channels per point.
            width: channel width of the attention blocks.
            heads: number of attention heads.
            layers: number of self-attention layers after the cross attention
                (0 disables them).
            fourier_embedder: module exposing ``out_dim`` and mapping
                coordinates to positional embeddings.
            normal_pe: also Fourier-embed the first three feature channels.
            qkv_bias: use a bias in the q/k/v projections.
            use_ln_post: apply a final LayerNorm to the latents.
            qk_norm: normalise queries and keys inside the attention.
        """

        super().__init__()

        self.fourier_embedder = fourier_embedder

        self.pc_size = pc_size
        self.normal_pe = normal_pe
        self.downsample_ratio = downsample_ratio
        self.pc_sharpedge_size = pc_sharpedge_size
        self.num_latents = num_latents
        self.point_feats = point_feats

        # NOTE(review): the input width is fixed to embedding + raw feature
        # channels; it does not account for the extra channels produced when
        # normal_pe is enabled — confirm before turning normal_pe on.
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)

        self.cross_attn = ResidualCrossAttentionBlock(
            width = width,
            heads = heads,
            qkv_bias = qkv_bias,
            qk_norm = qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                width = width,
                heads = heads,
                qkv_bias = qkv_bias,
                qk_norm = qk_norm,
                layers = layers
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):

        """
        Subsample points randomly from the point cloud (input_pc)
        Further sample the subsampled points to get query_pc
        take the fourier embeddings for both input and query pc

        Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
        Goal: get a smaller representation (query_pc) to represent the entire scene structure by learning from a broader subset (input_pc).
        More computationally efficient.

        Features are additional information for each point in the cloud

        Args:
            point_cloud: (B, pc_size + pc_sharpedge_size, D) coordinates, the
                uniformly sampled points first and the sharp-edge points last.
            features: per-point features laid out like ``point_cloud``; only
                read when ``point_feats > 0``.

        Returns:
            Tuple ``(query, data)`` of embedded query points and embedded
            input points, each of shape (B, num_points, channels).
        """

        B, _, D = point_cloud.shape

        num_latents = int(self.num_latents)

        # share the latent budget between both point sets in proportion to their sizes
        num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharpedge surface points
        random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)

        # assert statements
        assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"

        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
            self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)

        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)

        if input_sharpedge_pc_size == 0:
            # no sharp-edge budget: use empty placeholders so the concatenations below still work
            sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
            sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)

        else:
            sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
            self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)

        # concat the random and sharpedges
        query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
        input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)

        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        if self.point_feats > 0:
            random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)

            # reuse the indices chosen for the coordinates so features stay aligned with their points
            input_random_surface_features, query_random_features = \
                self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
                                     input_pc_size = input_random_pc_size, idx_query = random_idx_query)

            if input_sharpedge_pc_size == 0:
                input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
                                                               dtype = input_random_surface_features.dtype, device = point_cloud.device)

                query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
                                                       dtype = query_random_features.dtype, device = point_cloud.device)
            else:

                input_sharpedge_surface_features, query_sharpedge_features = \
                    self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
                                         batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)

            query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
            input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)

            if self.normal_pe:
                # apply the fourier embeddings on the first 3 dims (xyz)
                input_features_pe = self.fourier_embedder(input_features[..., :3])
                query_features_pe = self.fourier_embedder(query_features[..., :3])
                # replace the first 3 dims with the new PE ones and keep the
                # remaining channels ([..., 3:]); slicing [..., :3] here would
                # re-append the raw dims and drop every other feature channel
                input_features = torch.cat([input_features_pe, input_features[..., 3:]], dim = -1)
                query_features = torch.cat([query_features_pe, query_features[..., 3:]], dim = -1)

            # concat at the channels dim
            query = torch.cat([query, query_features], dim = -1)
            data = torch.cat([data, input_features], dim = -1)

        # don't return pc_info to avoid unnecessary memory usage
        return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])

    def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
        """Return the latent tokens computed from ``point_cloud`` and ``features``."""

        query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)

        # apply projections
        query = self.input_proj(query)
        data = self.input_proj(data)

        # apply cross attention between query and data
        latents = self.cross_attn(query, data)

        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents


    def subsample(self, pc, num_query, input_pc_size: int):

        """
        Randomly subsample ``pc`` and pick query points from the subset with FPS.

        num_query: number of points to keep after FPS
        input_pc_size: number of points to select before FPS

        Returns ``(query_pc, input_pc, idx_pc, idx_query)``: the FPS-selected
        points, the random subset, the indices of the subset inside ``pc`` and
        the FPS indices into the batch-flattened subset.
        """

        B, _, D = pc.shape
        query_ratio = num_query / input_pc_size

        # random subsampling of points inside the point cloud
        # (the same permutation is applied to every batch element)
        idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
        input_pc = pc[:, idx_pc, :]

        # flatten to allow applying fps across the whole batch
        flattent_input_pc = input_pc.view(B * input_pc_size, D)

        # construct a batch_down tensor to tell fps
        # which points belong to which batch
        N_down = int(flattent_input_pc.shape[0] / B)
        batch_down = torch.arange(B).to(pc.device)
        batch_down = torch.repeat_interleave(batch_down, N_down)

        idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
        query_pc = flattent_input_pc[idx_query].view(B, -1, D)

        return query_pc, input_pc, idx_pc, idx_query

    def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
        """Gather the features of the subsampled input points and of the query points.

        ``idx_pc`` selects along dim 1 of ``features``; ``idx_query`` indexes the
        batch-flattened selection. Returns ``(input_features, query_features)``.
        """

        B = batch_size

        input_surface_features = features[:, idx_pc, :]
        flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
        query_features = flattent_input_features[idx_query].view(B, -1,
                                                                 flattent_input_features.shape[-1])

        return input_surface_features, query_features

def normalize_mesh(mesh, scale = 0.9999):
    """Center the mesh on the origin and rescale it in place.

    The bounding-box center is moved to [0, 0, 0] and the mesh is scaled so
    that its longest bounding-box side spans ``2 * scale``. The same (mutated)
    mesh object is returned.
    """

    bounds = mesh.bounds
    lower, upper = bounds[0], bounds[1]

    midpoint = (upper + lower) / 2
    # measured before any transform is applied
    longest_side = (upper - lower).max()

    mesh.apply_translation(-midpoint)
    mesh.apply_scale((2 * scale) / longest_side)

    return mesh

def sample_pointcloud(mesh, num = 200000):
    """Sample ``num`` surface points with ``mesh.sample`` and look up their face normals.

    Returns a ``(points, normals)`` pair of float32 torch tensors; each normal
    is the normal of the face its point was sampled from.
    """

    positions, hit_faces = mesh.sample(num, return_index = True)
    hit_normals = mesh.face_normals[hit_faces]

    return tuple(torch.from_numpy(arr.astype(np.float32)) for arr in (positions, hit_normals))

def detect_sharp_edges(mesh, threshold=0.985):
    """Return edge indices (a, b) that lie on sharp boundaries of the mesh.

    For every vertex the smallest dot product between its vertex normal and
    the normals of the faces that use it is computed. A vertex counts as sharp
    when that value is below ``threshold``; an edge is returned when both of
    its end points are sharp.

    Returns:
        Two integer arrays ``(edge_a, edge_b)`` holding the start and end
        vertex index of each sharp edge. Edges are listed once per face that
        contains them.
    """

    vertices, faces = np.asarray(mesh.vertices), np.asarray(mesh.faces)
    vertex_normals, face_normals = np.asarray(mesh.vertex_normals), np.asarray(mesh.face_normals)

    min_alignment = np.ones(vertices.shape[0])
    for corner in range(3):
        corner_vertices = faces[:, corner]
        alignment = np.einsum('ij,ij->i', vertex_normals[corner_vertices], face_normals)
        # A vertex usually appears in several faces. Plain fancy-index
        # assignment keeps only the last write for repeated indices, so the
        # minimum has to be accumulated with the unbuffered ufunc.at.
        np.minimum.at(min_alignment, corner_vertices, alignment)

    # the three directed edges of every face: (v0, v1), (v1, v2), (v2, v0)
    edge_a = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
    edge_b = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
    sharp_edges = (min_alignment[edge_a] < threshold) & (min_alignment[edge_b] < threshold)

    return edge_a[sharp_edges], edge_b[sharp_edges]


def sharp_sample_pointcloud(mesh, num = 16384):
    """ Sample points preferentially from sharp edges in the mesh.

    Edges reported by ``detect_sharp_edges`` are chosen with probability
    proportional to their length, and one point is placed uniformly at random
    along each chosen edge. Normals are the linear interpolation of the
    vertex normals of the two edge end points.

    Returns:
        ``(samples, normals)`` float32 numpy arrays with ``num`` rows each.

    Raises:
        ValueError: if the mesh has no sharp edge of non-zero length.
    """

    edge_a, edge_b = detect_sharp_edges(mesh)
    V, VN = mesh.vertices, mesh.vertex_normals

    va, vb = V[edge_a], V[edge_b]
    na, nb = VN[edge_a], VN[edge_b]

    edge_lengths = np.linalg.norm(vb - va, axis=-1)
    total_length = edge_lengths.sum()

    if edge_lengths.size == 0 or not total_length > 0:
        # without this guard the weights are empty/NaN and the lookup below
        # fails with an unrelated-looking indexing error
        raise ValueError("mesh has no sharp edges with non-zero length to sample from")

    weights = edge_lengths / total_length

    indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
    # rounding can leave the last cumulative weight slightly below 1.0, in
    # which case searchsorted may return len(weights); clamp to the last edge
    indices = np.minimum(indices, edge_lengths.size - 1)
    t = np.random.rand(num, 1)

    samples = t * va[indices] + (1 - t) * vb[indices]
    normals = t * na[indices] + (1 - t) * nb[indices]

    return samples.astype(np.float32), normals.astype(np.float32)

def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
    """Build the encoder input (uniform + sharp-edge samples) from a trimesh mesh.

    The mesh is merged into a single mesh, normalized, sampled uniformly on
    its surface and along its sharp edges, and a random subset of each sample
    set is kept.

    Args:
        mesh: trimesh mesh or scene.
        num_points: number of uniformly sampled points to keep.
        num_sharp_points: number of sharp-edge points to keep.
        sharpedge_flag: append a label column (0 for uniform samples, 1 for
            sharp-edge samples) when True.
        device: device the returned tensor is placed on.

    Returns:
        float16 tensor of shape (1, num_points + num_sharp_points, C) whose
        columns are the point, its normal and, if requested, the label.
    """

    import trimesh

    # scenes expose dump(); fall back to concatenating the object directly
    try:
        merged = trimesh.util.concatenate(mesh.dump())
    except Exception:
        merged = trimesh.util.concatenate(mesh)

    merged = normalize_mesh(merged)

    all_faces = merged.faces
    all_vertices = merged.vertices
    surface_face_count = all_faces.shape[0]

    # NOTE(review): surface_face_count equals the total face count here, so the
    # "fill" mesh receives an empty face slice — presumably kept for parity
    # with the upstream code; confirm.
    mesh_surface = trimesh.Trimesh(vertices=all_vertices, faces=all_faces[:surface_face_count])
    mesh_fill = trimesh.Trimesh(vertices=all_vertices, faces=all_faces[surface_face_count:])

    area_surface = mesh_surface.area
    area_fill = mesh_fill.area
    total_area = area_surface + area_fill

    # number of candidate samples drawn per point set before the final random pick
    sample_num = 499712 // 2
    if total_area > 0:
        fill_ratio = area_fill / total_area
    else:
        fill_ratio = 0

    num_fill = int(sample_num * fill_ratio)
    num_surface = sample_num - num_fill

    surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
    if num_fill == 0:
        fill_pts, fill_normals = torch.zeros(0, 3), torch.zeros(0, 3)
    else:
        fill_pts, fill_normals = sample_pointcloud(mesh_fill, num_fill)

    sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)

    def pack(points, normals, label=None):
        # [point | normal] rows in half precision, optionally followed by a constant label column
        packed = torch.cat([points, normals], dim=1).half().to(device)
        if label is None:
            return packed

        label_column = torch.full((packed.shape[0], 1), float(label), dtype=torch.float16).to(device)
        return torch.cat([packed, label_column], dim=1)

    uniform_points = torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0)
    uniform_normals = torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0)
    surface = pack(uniform_points, uniform_normals, label = 0 if sharpedge_flag else None)

    sharp_surface = pack(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
                         label = 1 if sharpedge_flag else None)

    rng = np.random.default_rng()

    # keep a random subset (without repetition) of each candidate set
    keep_uniform = rng.choice(surface.shape[0], num_points, replace = False)
    surface = surface[keep_uniform]
    keep_sharp = rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)
    sharp_surface = sharp_surface[keep_sharp]

    return torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)

class SharpEdgeSurfaceLoader:
    """ Callable that turns a mesh (or a path to one) into uniform + sharp-edge surface samples. """

    def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
        """Store how many uniform and sharp-edge points each call should return."""

        self.num_uniform_points = num_uniform_points
        self.num_sharp_points = num_sharp_points
        self.total_points = num_uniform_points + num_sharp_points

    def __call__(self, mesh_input, device = "cuda"):
        """Load ``mesh_input`` and return the sampled surface tensor on ``device``."""
        return load_surface_sharpedge(
            self._load_mesh(mesh_input),
            self.num_uniform_points,
            self.num_sharp_points,
            device = device,
        )

    @staticmethod
    def _load_mesh(mesh_input):
        """Return a mesh for ``mesh_input``: strings are loaded with trimesh, scenes are merged."""
        import trimesh

        mesh = mesh_input
        if isinstance(mesh_input, str):
            mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)

        if not isinstance(mesh, trimesh.Scene):
            return mesh

        # merge every geometry of the scene with `+`; an empty scene yields None
        geometries = iter(mesh.geometry.values())
        combined = next(geometries, None)
        for geometry in geometries:
            combined = combined + geometry

        return combined

class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance parameterised by mean and log-variance."""

    def __init__(self, params: torch.Tensor, feature_dim: int = -1):
        """Split ``params`` in two halves along ``feature_dim``: mean and log-variance."""

        mean, raw_logvar = params.chunk(2, dim = feature_dim)

        self.mean = mean
        # clamp the log-variance to [-30, 20] before exponentiating
        self.logvar = raw_logvar.clamp(-30.0, 20.0)
        self.std = (0.5 * self.logvar).exp()

    def sample(self):
        """Draw one sample as ``mean + noise * std`` with standard normal noise."""

        noise = torch.randn_like(self.std)
        return self.mean + noise * self.std

################################################
# Volume Decoder
################################################

class VanillaVolumeDecoder():
    """Evaluate a geometry decoder on a dense regular grid, chunk by chunk."""

    @torch.no_grad()
    def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
                 num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
        """
        Args:
            latents: latent tokens; dim 0 is the batch size. Device and dtype
                of the grid queries follow this tensor.
            geo_decoder: called as ``geo_decoder(queries=..., latents=...)``
                with queries of shape (batch, n, 3); expected to return one
                value per query so the result can be viewed as a grid.
            octree_resolution: number of cells per axis; the grid has
                ``octree_resolution + 1`` samples per axis.
            bounds: either a number ``b`` (cube [-b, b]^3) or a sequence
                ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: number of grid points decoded per decoder call.
            enable_pbar: show a tqdm progress bar.
            **kwargs: accepted and ignored.

        Returns:
            float32 tensor of shape (batch, R + 1, R + 1, R + 1) with
            ``R = octree_resolution``.
        """

        # accept ints as well as floats for a symmetric bounding cube
        if isinstance(bounds, (int, float)):
            half_extent = float(bounds)
            bounds = [-half_extent, -half_extent, -half_extent, half_extent, half_extent, half_extent]

        bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])

        steps = int(octree_resolution) + 1
        x, y, z = (
            torch.linspace(bbox_min[axis], bbox_max[axis], steps, dtype = torch.float32)
            for axis in range(3)
        )

        xs, ys, zs = torch.meshgrid(x, y, z, indexing = "ij")
        xyz = torch.stack((xs, ys, zs), dim = -1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
        grid_size = [steps, steps, steps]

        chunk_starts = range(0, xyz.shape[0], num_chunks)
        if enable_pbar:
            # only build the progress bar when it is actually displayed
            chunk_starts = tqdm(chunk_starts, desc="Volume Decoding")

        batch_logits = []
        for start in chunk_starts:
            chunk_queries = xyz[start: start + num_chunks, :]
            # every batch element is queried at the same grid positions
            chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
            logits = geo_decoder(queries = chunk_queries, latents = latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim = 1)
        grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()

        return grid_logits

class FourierEmbedder(nn.Module):
    """Sin/cosine positional embedding.

    Every channel ``x[..., i]`` of the input is multiplied by each frequency
    ``f_k``. With ``p`` denoting all those products (channel-major, i.e. all
    frequencies of channel 0 first), the output is the concatenation along
    the last dim of

        [x (only if include_input), sin(p), cos(p)]

    When ``num_freqs`` is 0 the input is returned unchanged.

    Args:
        num_freqs (int): number of frequencies, default is 6;
        logspace (bool): if True the frequencies are ``2^0, 2^1, ..., 2^(num_freqs - 1)``,
            otherwise ``num_freqs`` values linearly spaced in ``[1.0, 2^(num_freqs - 1)]``;
        input_dim (int): the input dimension used to compute ``out_dim``, default is 3;
        include_input (bool): prepend the raw input to the embedding, default is True;
        include_pi (bool): multiply every frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequencies described above, kept as a
            non-persistent buffer;
        out_dim (int): the embedding size, ``input_dim * (num_freqs * 2 + 1)`` if the
            input is included (or ``num_freqs`` is 0), otherwise ``input_dim * num_freqs * 2``.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:

        """Build the frequency buffer and pre-compute the output size."""

        super().__init__()

        if not logspace:
            freqs = torch.linspace(1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32)
        else:
            freqs = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)

        if include_pi:
            freqs = freqs * torch.pi

        self.register_buffer("frequencies", freqs, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the size of the last output dim for inputs with ``input_dim`` channels."""
        per_channel = self.num_freqs * 2
        # the raw input is part of the output when requested, or when there is
        # no frequency at all (forward is then the identity)
        if self.include_input or self.num_freqs == 0:
            per_channel += 1

        return input_dim * per_channel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs <= 0:
            return x

        freqs = self.frequencies.to(device=x.device, dtype=x.dtype)
        # (..., dim, 1) * (num_freqs,) -> (..., dim, num_freqs), then merge the last two dims
        phases = (x.unsqueeze(-1).contiguous() * freqs).view(*x.shape[:-1], -1)

        pieces = [phases.sin(), phases.cos()]
        if self.include_input:
            pieces.insert(0, x)

        return torch.cat(pieces, dim=-1)

class CrossAttentionProcessor:
    """Callable that runs scaled dot-product attention through ``comfy.ops``."""

    def __call__(self, attn, q, k, v):
        # `attn` is accepted but not used by this processor
        return comfy.ops.scaled_dot_product_attention(q, k, v)

class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch during training."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop each sample (dim 0) of ``x`` with probability ``drop_prob``.

        The input is returned untouched in eval mode or when ``drop_prob`` is 0.
        Kept samples are divided by the keep probability when ``scale_by_keep``
        is set, so the expected value of the output matches the input.
        """
        if not self.training or self.drop_prob == 0.:
            return x

        survival = 1 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dims
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(survival)

        if self.scale_by_keep and survival > 0.0:
            mask.div_(survival)

        return x * mask

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'


class MLP(nn.Module):
    """Two-layer feed-forward block: linear -> GELU -> linear, with optional drop path."""

    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        """
        Args:
            width: input width.
            expand_ratio: the hidden layer has ``width * expand_ratio`` units.
            output_width: output width; defaults to ``width``.
            drop_path_rate: drop-path probability applied to the output (0 disables it).
        """
        super().__init__()

        hidden_width = width * expand_ratio
        out_width = width if output_width is None else output_width

        self.width = width
        self.c_fc = ops.Linear(width, hidden_width)
        self.c_proj = ops.Linear(hidden_width, out_width)
        self.gelu = nn.GELU()

        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.drop_path(projected)

class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head attention core operating on projected queries and a packed key/value tensor."""

    def __init__(
        self,
        heads: int,
        n_data = None,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        """
        Args:
            heads: number of attention heads.
            n_data: stored on the module; not read by ``forward``.
            width: model width; ``width // heads`` is the normalised per-head size.
            qk_norm: normalise queries and keys per head before attention.
            norm_layer: normalisation layer class used when ``qk_norm`` is set.
        """
        super().__init__()
        self.heads = heads
        self.n_data = n_data

        head_dim = width // heads
        self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, q, kv):
        """Attend ``q`` (bs, n_ctx, C) over ``kv`` (bs, n_data, 2 * C) and return (bs, n_ctx, C)."""

        n_ctx = q.shape[1]
        bs, n_data, kv_width = kv.shape

        # each head owns a contiguous [key | value] slice of the last dim
        per_head = kv_width // self.heads // 2

        q = q.view(bs, n_ctx, self.heads, -1)
        k, v = torch.split(kv.view(bs, n_data, self.heads, -1), per_head, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # (bs, tokens, heads, ch) -> (bs, heads, tokens, ch)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attended = F.scaled_dot_product_attention(q, k, v)

        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)

class MultiheadCrossAttention(nn.Module):
    """Cross attention with learned query, key/value and output projections."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        """
        Args:
            width: width of the queries and of the output.
            heads: number of attention heads.
            qkv_bias: use a bias in the query and key/value projections.
            data_width: width of the attended data; defaults to ``width``.
            norm_layer: normalisation layer class for the optional q/k norm.
            qk_norm: normalise queries and keys per head.
            kv_cache: compute the key/value projection once and reuse it on later calls.
        """
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = data_width if data_width is not None else width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        # cached key/value projection; only filled when kv_cache is enabled
        self.data = None

    def forward(self, x, data):
        queries = self.c_q(x)

        if not self.kv_cache:
            keys_values = self.c_kv(data)
        else:
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Save kv cache,this should be called only once for one mesh')
            # once filled, the cached projection is used regardless of `data`
            keys_values = self.data

        attended = self.attention(queries, keys_values)
        return self.c_proj(attended)

class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross attention over ``data`` followed by an MLP."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        """
        Args:
            width: width of the query stream.
            heads: number of attention heads.
            mlp_expand_ratio: hidden expansion factor of the MLP.
            data_width: width of the attended data; defaults to ``width``.
            qkv_bias: use a bias in the attention projections.
            norm_layer: normalisation layer class.
            qk_norm: normalise queries and keys inside the attention.
        """
        super().__init__()

        data_width = width if data_width is None else data_width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        # ln_1 normalises the queries, ln_2 the data, ln_3 the MLP input
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        attended = self.attn(self.ln_1(x), self.ln_2(data))
        x = x + attended

        fed_forward = self.mlp(self.ln_3(x))
        return x + fed_forward


class QKVMultiheadAttention(nn.Module):
    """Multi-head self-attention core operating on a packed query/key/value tensor."""

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        """
        Args:
            heads: number of attention heads.
            width: model width; ``width // heads`` is the normalised per-head size.
            qk_norm: normalise queries and keys per head before attention.
            norm_layer: normalisation layer class used when ``qk_norm`` is set.
        """
        super().__init__()
        self.heads = heads

        head_dim = width // heads
        self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        """Self-attend over ``qkv`` of shape (bs, n_ctx, 3 * C) and return (bs, n_ctx, C)."""
        bs, n_ctx, packed_width = qkv.shape

        # each head owns a contiguous [query | key | value] slice of the last dim
        per_head = packed_width // self.heads // 3
        q, k, v = torch.split(qkv.view(bs, n_ctx, self.heads, -1), per_head, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # (bs, tokens, heads, ch) -> (bs, heads, tokens, ch)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attended = F.scaled_dot_product_attention(q, k, v)
        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)


class MultiheadAttention(nn.Module):
    """Self-attention with a fused q/k/v projection, an output projection and optional drop path."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        """
        Args:
            width: model width.
            heads: number of attention heads.
            qkv_bias: use a bias in the fused q/k/v projection.
            norm_layer: normalisation layer class for the optional q/k norm.
            qk_norm: normalise queries and keys per head.
            drop_path_rate: drop-path probability applied to the output (0 disables it).
        """
        super().__init__()

        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )

        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        packed = self.c_qkv(x)
        attended = self.attention(packed)
        return self.drop_path(self.c_proj(attended))


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP, both residual."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        """
        Args:
            width: model width.
            heads: number of attention heads.
            qkv_bias: use a bias in the fused q/k/v projection.
            norm_layer: normalisation layer class.
            qk_norm: normalise queries and keys inside the attention.
            drop_path_rate: drop-path probability used by both sub-layers.
        """
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        attended = self.attn(self.ln_1(x))
        x = x + attended

        fed_forward = self.mlp(self.ln_2(x))
        return x + fed_forward


class Transformer(nn.Module):
    """Stack of ``layers`` identical residual self-attention blocks applied in sequence."""

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        """
        Args:
            width: model width.
            layers: number of residual attention blocks.
            heads: number of attention heads per block.
            qkv_bias: use a bias in the fused q/k/v projections.
            norm_layer: normalisation layer class.
            qk_norm: normalise queries and keys inside the attention.
            drop_path_rate: drop-path probability passed to every block.
        """
        super().__init__()
        self.width = width
        self.layers = layers

        block_kwargs = dict(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )

        blocks = nn.ModuleList()
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(**block_kwargs))
        self.resblocks = blocks

    def forward(self, x: torch.Tensor):
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return hidden


class CrossAttentionDecoder(nn.Module):
    """Decode query points against a set of latents with one cross-attention block.

    Queries are Fourier-embedded and projected to ``width`` (unless
    pre-computed embeddings are supplied), cross-attended against the
    latents, optionally layer-normalised, and projected to ``out_channels``.

    Args:
        out_channels: Output size of the final linear projection.
        fourier_embedder: Embedder applied to raw queries; its ``out_dim``
            sets the input size of ``query_proj``.
        width: Feature size of the cross-attention block.
        heads: Head count forwarded to ``ResidualCrossAttentionBlock``.
        mlp_expand_ratio: Forwarded to ``ResidualCrossAttentionBlock``.
        downsample_ratio: When not 1, latents are first projected from
            ``width * downsample_ratio`` to ``width`` by ``latents_proj``.
        enable_ln_post: Whether ``ln_post`` is created and applied before the
            output projection. When false, ``qk_norm`` is forced to ``False``.
        qkv_bias: Forwarded to ``ResidualCrossAttentionBlock``.
        qk_norm: Forwarded to ``ResidualCrossAttentionBlock`` (see above).
        label_type: Stored as ``self.label_type``; not used inside this class.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        # Plain (non-module) configuration state.
        self.enable_ln_post = enable_ln_post
        self.downsample_ratio = downsample_ratio
        self.label_type = label_type
        # Running total of query_embeddings.shape[1] over all forward() calls.
        self.count = 0

        # Layers, created in the same order as before so that module
        # registration order is unchanged.
        self.fourier_embedder = fourier_embedder
        self.query_proj = ops.Linear(fourier_embedder.out_dim, width)
        if downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)

        # qk-normalisation is only kept when the post layer norm is enabled.
        effective_qk_norm = qk_norm if enable_ln_post else False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=effective_qk_norm,
        )

        if enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Decode queries against ``latents``.

        Args:
            queries: Raw queries; only used when ``query_embeddings`` is None.
            query_embeddings: Already projected queries; skips the embedding
                step when given.
            latents: Latents to attend over. Their dtype is used to cast the
                Fourier embedding of ``queries``.

        Returns:
            The output of ``output_proj`` applied to the decoded queries.
        """
        if query_embeddings is None:
            embedded = self.fourier_embedder(queries).to(latents.dtype)
            query_embeddings = self.query_proj(embedded)

        self.count += query_embeddings.shape[1]

        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)

        decoded = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            decoded = self.ln_post(decoded)
        return self.output_proj(decoded)


class ShapeVAE(nn.Module):
    """Shape VAE made of a point-cloud encoder, a latent transformer and a
    cross-attention geometry decoder evaluated by a volume decoder.

    ``encode`` maps a surface point cloud to sampled latents; ``decode`` maps
    latents to grid logits.

    Args:
        num_latents: Forwarded to ``PointCrossAttention``.
        embed_dim: Size of the latent feature consumed by ``post_kl``.
        width: Hidden size of the encoder, transformer and decoder.
        heads: Attention head count.
        num_decoder_layers: Number of blocks in the latent ``Transformer``.
        num_encoder_layers: Forwarded to ``PointCrossAttention`` as ``layers``.
        pc_size: Forwarded to ``PointCrossAttention``.
        pc_sharpedge_size: Forwarded to ``PointCrossAttention``.
        point_feats: Forwarded to ``PointCrossAttention``.
        downsample_ratio: Forwarded to ``PointCrossAttention``.
        geo_decoder_downsample_ratio: Divides ``width`` and ``heads`` for the
            geometry decoder and is forwarded as its ``downsample_ratio``.
        geo_decoder_mlp_expand_ratio: Forwarded to the geometry decoder.
        geo_decoder_ln_post: Forwarded to the geometry decoder as
            ``enable_ln_post``.
        num_freqs: Forwarded to ``FourierEmbedder``.
        qkv_bias: Forwarded to the transformer and the geometry decoder.
        qk_norm: Forwarded to the transformer and the geometry decoder.
        drop_path_rate: Forwarded to the transformer.
        include_pi: Forwarded to ``FourierEmbedder``.
        scale_factor: Stored as ``self.scale_factor``; not used in this class.
        label_type: Forwarded to the geometry decoder.
    """

    def __init__(
            self,
            *,
            num_latents: int = 4096,
            embed_dim: int = 64,
            width: int = 1024,
            heads: int = 16,
            num_decoder_layers: int = 16,
            num_encoder_layers: int = 8,
            pc_size: int = 81920,
            pc_sharpedge_size: int = 0,
            point_feats: int = 4,
            downsample_ratio: int = 20,
            geo_decoder_downsample_ratio: int = 1,
            geo_decoder_mlp_expand_ratio: int = 4,
            geo_decoder_ln_post: bool = True,
            num_freqs: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = True,
            drop_path_rate: float = 0.0,
            include_pi: bool = False,
            scale_factor: float = 1.0039506158752403,
            label_type: str = "binary",
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        # Shared between the encoder and the geometry decoder.
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.encoder = PointCrossAttention(layers = num_encoder_layers,
                                    num_latents = num_latents,
                                    downsample_ratio = downsample_ratio,
                                    heads = heads,
                                    pc_size = pc_size,
                                    width = width,
                                    point_feats = point_feats,
                                    fourier_embedder = self.fourier_embedder,
                                    pc_sharpedge_size = pc_sharpedge_size)

        # encode() projects encoder features through ``pre_kl``; the layer was
        # never created, so encode() always failed with AttributeError.
        # NOTE(review): the output size ``embed_dim * 2`` assumes that
        # DiagonalGaussianDistribution splits the moments into a mean and a
        # log-variance of ``embed_dim`` features each (``post_kl`` consumes
        # ``embed_dim``) -- confirm against DiagonalGaussianDistribution and
        # the checkpoint's ``pre_kl`` weight shape.
        self.pre_kl = ops.Linear(width, embed_dim * 2)
        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        """Decode latents into grid logits.

        Args:
            latents: Latent tensor whose second-to-last dimension is the
                ``embed_dim`` feature axis (it is moved last before
                ``post_kl`` is applied).
            **kwargs: Optional settings forwarded to the volume decoder:
                ``bounds`` (default 1.01), ``num_chunks`` (default 8000),
                ``octree_resolution`` (default 256) and ``enable_pbar``
                (default True). Other keys are ignored.

        Returns:
            The volume decoder output with its last two dimensions swapped.
        """
        # Move the feature axis last so the linear layer acts on it.
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(
            latents,
            self.geo_decoder,
            bounds=bounds,
            num_chunks=num_chunks,
            octree_resolution=octree_resolution,
            enable_pbar=enable_pbar,
        )
        return grid_logits.movedim(-2, -1)

    def encode(self, surface):
        """Encode a surface point cloud and sample latents from the posterior.

        Args:
            surface: Tensor indexed as ``[:, :, channel]``; the first three
                channels are passed to the encoder as point coordinates and
                the remaining channels as per-point features.

        Returns:
            A sample drawn from the ``DiagonalGaussianDistribution`` built
            from the ``pre_kl`` projection of the encoder output.
        """
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents = self.encoder(pc, feats)

        moments = self.pre_kl(latents)
        posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)

        return posterior.sample()
