from dataclasses import dataclass
from typing import Literal

import torch
from einops import rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from ...dataset import DatasetCfg
from ..types import Gaussians
from .cuda_splatting import DepthRenderingMode, render_cuda
from .decoder import Decoder, DecoderOutput


@dataclass
class DecoderSplattingCUDACfg:
    """Settings for the CUDA splatting decoder and its Gaussian pruning step."""

    # Config tag; the only value the annotation admits is "splatting_cuda".
    name: Literal["splatting_cuda"]
    # Background color components; the decoder turns this list into a
    # float32 tensor buffer.
    background_color: list[float]
    # Forwarded by the decoder to the renderer as its `scale_invariant` flag.
    make_scale_invariant: bool
    # Opacity cutoff handed to `prune_gaussians` as `opacity_threshold`.
    prune_opacity_threshold: float = 0.005
    # Fraction of Gaussians to drop, handed to `prune_gaussians` as
    # `prune_ratio`; 0 disables ratio-based pruning.
    training_prune_ratio: float = 0.0
    # Share of the kept Gaussians that is picked at random rather than by
    # opacity rank, handed to `prune_gaussians` as `random_keep_ratio`.
    training_prune_keep_ratio: float = 0.1


def _gather_along_gaussians(tensor: Tensor, keep_idx: Tensor) -> Tensor:
    """Select the entries `keep_idx` (B, K) along dim 1 of `tensor` (B, G, ...)."""
    trailing = tensor.shape[2:]
    # Add one singleton axis per trailing dim so the index can be broadcast
    # over them; `gather` needs an index with the same rank as `tensor`.
    singleton_axes = (1,) * len(trailing)
    index = keep_idx.reshape(*keep_idx.shape, *singleton_axes)
    return tensor.gather(1, index.expand(-1, -1, *trailing))


def prune_gaussians(
    gaussians: Gaussians,
    opacity_threshold: float,
    prune_ratio: float,
    random_keep_ratio: float,
    inference: bool = False,
) -> Gaussians:
    """Drop low-opacity Gaussians, either by threshold or by a fixed ratio.

    Two pruning modes exist, tried in this order:

    * Threshold mode (``inference`` is true and ``opacity_threshold > 0``):
      keep only the Gaussians whose opacity is strictly greater than
      ``opacity_threshold``. Only a batch size of 1 is supported.
    * Ratio mode (``prune_ratio > 0``): keep a fraction ``1 - prune_ratio`` of
      the Gaussians per batch element. A share ``random_keep_ratio`` of that
      budget is sampled at random from the Gaussians that did not make the
      opacity cut; the remainder of the budget goes to the highest-opacity
      Gaussians. Each count is truncated to an integer separately.

    If neither mode applies, ``gaussians`` is returned unchanged.

    NOTE(review): ratio mode is not gated on ``inference``, so it also runs at
    inference time when ``opacity_threshold <= 0``. Confirm this is intended.

    Args:
        gaussians: Gaussians to prune. ``means`` and ``opacities`` are
            expected to be ``(B, G, 3)`` and ``(B, G)``; every field must have
            batch and Gaussian as its first two axes.
        opacity_threshold: Opacity cutoff for threshold mode.
        prune_ratio: Fraction of Gaussians to discard in ratio mode, in
            ``(0, 1]`` to take effect; values ``<= 0`` disable ratio mode.
        random_keep_ratio: Share of the kept Gaussians chosen at random, in
            ``[0, 1]``.
        inference: Whether the caller is in inference mode.

    Returns:
        A new ``Gaussians`` holding the selected entries, or the input object
        itself when no pruning applies.

    Raises:
        AssertionError: If ``inference`` is true and the batch size is above 1.
        ValueError: If ratio mode is active and ``prune_ratio > 1`` or
            ``random_keep_ratio`` lies outside ``[0, 1]``.
    """
    means = gaussians.means  # (B, G, 3)
    opacities = gaussians.opacities  # (B, G)

    if means.shape[0] > 1:
        assert not inference, "Inference mode is not supported when bs > 1."

    if inference and opacity_threshold > 0:
        keep_mask = opacities > opacity_threshold  # (B, G) with B == 1

        def trim(element: Tensor) -> Tensor:
            # Boolean indexing merges the batch and Gaussian axes; restore
            # the batch axis afterwards.
            return element[keep_mask].unsqueeze(0)

        return Gaussians(
            means=trim(gaussians.means),
            covariances=trim(gaussians.covariances),
            harmonics=trim(gaussians.harmonics),
            opacities=trim(gaussians.opacities),
            rotations=trim(gaussians.rotations),
            scales=trim(gaussians.scales),
        )

    if prune_ratio > 0:
        # Out-of-range ratios would produce negative slice bounds below and
        # silently select the wrong Gaussians, so reject them up front.
        if prune_ratio > 1:
            raise ValueError(f"prune_ratio must be at most 1, got {prune_ratio}.")
        if not 0 <= random_keep_ratio <= 1:
            raise ValueError(
                f"random_keep_ratio must be in [0, 1], got {random_keep_ratio}."
            )

        num_gaussians = means.shape[1]

        # Split the keep budget into a random share and an opacity-ranked share.
        total_keep_fraction = 1 - prune_ratio
        random_fraction = total_keep_fraction * random_keep_ratio
        ranked_fraction = total_keep_fraction - random_fraction
        num_keep = int(num_gaussians * ranked_fraction)
        num_keep_random = int(num_gaussians * random_fraction)

        # Rank by opacity, most opaque first, and keep the head of the ranking.
        idx_sort = opacities.argsort(dim=1, descending=True)
        keep_idx = idx_sort[:, :num_keep]
        if num_keep_random > 0:
            rest_idx = idx_sort[:, num_keep:]
            # Build the permutation on the indices' device so indexing needs no
            # host-to-device copy. The same rank positions are drawn for every
            # batch element.
            perm = torch.randperm(rest_idx.shape[1], device=rest_idx.device)
            random_idx = rest_idx[:, perm[:num_keep_random]]
            keep_idx = torch.cat([keep_idx, random_idx], dim=1)

        return Gaussians(
            means=_gather_along_gaussians(gaussians.means, keep_idx),
            covariances=_gather_along_gaussians(gaussians.covariances, keep_idx),
            harmonics=_gather_along_gaussians(gaussians.harmonics, keep_idx),
            opacities=_gather_along_gaussians(gaussians.opacities, keep_idx),
            rotations=_gather_along_gaussians(gaussians.rotations, keep_idx),
            scales=_gather_along_gaussians(gaussians.scales, keep_idx),
        )

    return gaussians


class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]):
    """Decoder that prunes the input Gaussians and renders them with `render_cuda`."""

    background_color: Float[Tensor, "3"]

    def __init__(
        self,
        cfg: DecoderSplattingCUDACfg,
    ) -> None:
        """Store the scale-invariance flag and register the background color.

        Args:
            cfg: Decoder configuration; also passed to the base class.
        """
        super().__init__(cfg)
        self.make_scale_invariant = cfg.make_scale_invariant
        # Registered as a non-persistent buffer, i.e. it is not written to
        # the module's state dict.
        background = torch.tensor(cfg.background_color, dtype=torch.float32)
        self.register_buffer("background_color", background, persistent=False)

    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
        cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
        cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Prune the Gaussians and render color and depth for every view.

        The Gaussians are first passed through `prune_gaussians`, with
        inference mode enabled whenever the module is not in training mode.
        Batch and view axes are then merged for the renderer, each batch
        element's Gaussians being repeated once per view, and split again on
        the rendered outputs.

        Args:
            gaussians: Gaussians to render; only means, covariances,
                harmonics and opacities are handed to the renderer.
            extrinsics: Per-view camera extrinsics.
            intrinsics: Per-view camera intrinsics.
            near: Per-view near value.
            far: Per-view far value.
            image_shape: Passed through to the renderer unchanged.
            depth_mode: Accepted but not used by this method.
            cam_rot_delta: Optional per-view rotation delta, forwarded to the
                renderer when given.
            cam_trans_delta: Optional per-view translation delta, forwarded to
                the renderer when given.

        Returns:
            A `DecoderOutput` built from the color, reshaped to
            "b v c h w", and the depth, reshaped to "b v h w".
        """
        batch, views, _, _ = extrinsics.shape

        cfg = self.cfg
        pruned = prune_gaussians(
            gaussians,
            cfg.prune_opacity_threshold,
            cfg.training_prune_ratio,
            cfg.training_prune_keep_ratio,
            inference=not self.training,
        )

        # Camera parameters: merge the batch and view axes.
        flat_extrinsics = rearrange(extrinsics, "b v i j -> (b v) i j")
        flat_intrinsics = rearrange(intrinsics, "b v i j -> (b v) i j")
        flat_near = rearrange(near, "b v -> (b v)")
        flat_far = rearrange(far, "b v -> (b v)")

        # One background color per rendered view.
        background = repeat(self.background_color, "c -> (b v) c", b=batch, v=views)

        # Gaussian parameters: duplicate each batch element once per view.
        means = repeat(pruned.means, "b g xyz -> (b v) g xyz", v=views)
        covariances = repeat(pruned.covariances, "b g i j -> (b v) g i j", v=views)
        harmonics = repeat(pruned.harmonics, "b g c d_sh -> (b v) g c d_sh", v=views)
        opacities = repeat(pruned.opacities, "b g -> (b v) g", v=views)

        # Optional camera deltas stay None when the caller did not supply them.
        rot_delta = None
        if cam_rot_delta is not None:
            rot_delta = rearrange(cam_rot_delta, "b v i -> (b v) i")
        trans_delta = None
        if cam_trans_delta is not None:
            trans_delta = rearrange(cam_trans_delta, "b v i -> (b v) i")

        color, depth = render_cuda(
            flat_extrinsics,
            flat_intrinsics,
            flat_near,
            flat_far,
            image_shape,
            background,
            means,
            covariances,
            harmonics,
            opacities,
            scale_invariant=self.make_scale_invariant,
            cam_rot_delta=rot_delta,
            cam_trans_delta=trans_delta,
        )

        # Split the merged axis back into separate batch and view axes.
        return DecoderOutput(
            rearrange(color, "(b v) c h w -> b v c h w", b=batch, v=views),
            rearrange(depth, "(b v) h w -> b v h w", b=batch, v=views),
        )
