from dataclasses import dataclass
from typing import Literal

import torch
from jaxtyping import Float
from torch import Tensor

from ..types import Gaussians
from .decoder import Decoder, DecoderOutput
from math import sqrt
from gsplat import rasterization

# Allowed values for the ``depth_mode`` argument of the decoder's rendering methods.
DepthRenderingMode = Literal[
    "depth",
    "disparity",
    "relative_disparity",
    "log",
]


@dataclass
class DecoderSplattingSplatCfg:
    """Configuration for the gsplat-based splatting decoder.

    Field order matters: it fixes the positional order of the generated
    ``__init__``.
    """

    name: Literal["splatting_gsplat"]
    # Per-channel background colour passed to the rasterizer.
    background_color: list[float]
    make_scale_invariant: bool
    # Opacity cutoff for inference-time pruning in ``prune_gaussians``
    # (Gaussians with opacity above it are kept; <= 0 disables it).
    prune_opacity_threshold: float = 0.005
    # Fraction of Gaussians removed by ratio-based pruning (0 disables it).
    training_prune_ratio: float = 0.0
    # Share of the kept budget drawn at random from lower-opacity Gaussians.
    training_prune_keep_ratio: float = 0.1


def _gather_gaussian_dim(values: Tensor, index: Tensor) -> Tensor:
    """Select entries of ``values`` along dim 1 using a ``(B, K)`` index tensor.

    All dimensions after the first two are kept whole, so the same helper
    handles ``(B, G)``, ``(B, G, D)`` and ``(B, G, M, N)`` tensors alike.
    """
    trailing = values.shape[2:]
    # Broadcast the index over the trailing dims so gather copies whole entries.
    expanded = index.reshape(*index.shape, *([1] * len(trailing))).expand(*index.shape, *trailing)
    return values.gather(1, expanded)


def prune_gaussians(
    gaussians: Gaussians,
    opacity_threshold: float,
    prune_ratio: float,
    random_keep_ratio: float,
    inference: bool = False,
) -> Gaussians:
    """Drop low-opacity Gaussians to reduce rendering cost.

    Two mutually exclusive modes, selected by ``inference``:

    * Inference (``inference=True``): if ``opacity_threshold > 0``, keep only
      Gaussians whose opacity is strictly greater than the threshold. The
      number of survivors differs per batch item, so this requires batch
      size 1. With ``opacity_threshold <= 0`` the input is returned unchanged.
    * Training (``inference=False``): if ``prune_ratio > 0``, keep a fixed
      budget of ``int(G * (1 - prune_ratio))``-ish Gaussians per batch item.
      A ``random_keep_ratio`` share of that budget is sampled at random from
      the lower-opacity Gaussians; the rest are the most opaque ones. With
      ``prune_ratio <= 0`` the input is returned unchanged.

    Args:
        gaussians: Batched Gaussians; ``means`` is (B, G, 3) and ``opacities``
            is (B, G). Every per-Gaussian field is indexed along dim 1.
        opacity_threshold: Inference-time opacity cutoff (<= 0 disables).
        prune_ratio: Training-time fraction of Gaussians to remove (<= 0 disables).
        random_keep_ratio: Fraction of the kept budget chosen at random
            instead of by opacity rank.
        inference: Selects inference-mode pruning instead of training-mode pruning.

    Returns:
        A (possibly) smaller ``Gaussians`` with the batch dimension preserved.

    Raises:
        ValueError: If inference pruning is requested with batch size > 1.
    """
    means = gaussians.means  # (B, G, 3)
    opacities = gaussians.opacities  # (B, G)
    num_gaussians = means.shape[1]

    if inference:
        if opacity_threshold <= 0:
            return gaussians
        # Checked only when this path actually prunes; raised explicitly so it
        # survives ``python -O``, where the boolean mask below would otherwise
        # silently merge batch items.
        if means.shape[0] > 1:
            raise ValueError("Inference-time opacity pruning is not supported when batch size > 1.")

        keep_mask = opacities > opacity_threshold  # (1, G)
        num_pruned = num_gaussians - int(keep_mask.sum().item())
        pruned_fraction = num_pruned / num_gaussians if num_gaussians else 0.0
        print(f"Pruned {num_pruned} gaussians out of {num_gaussians} ({pruned_fraction:.2%})")

        def trim(element: Tensor) -> Tensor:
            # Boolean masking flattens (1, G, ...) to (K, ...); restore the batch dim.
            return element[keep_mask].unsqueeze(0)

        return Gaussians(
            means=trim(gaussians.means),
            covariances=trim(gaussians.covariances),
            harmonics=trim(gaussians.harmonics),
            opacities=trim(gaussians.opacities),
            rotations=trim(gaussians.rotations),
            scales=trim(gaussians.scales),
        )

    if prune_ratio <= 0:
        return gaussians

    # Split the kept budget into a "most opaque" part and a "random" part.
    keep_fraction = 1 - prune_ratio
    random_fraction = keep_fraction * random_keep_ratio
    top_fraction = keep_fraction - random_fraction
    num_top = int(num_gaussians * top_fraction)
    num_random = int(num_gaussians * random_fraction)

    order = opacities.argsort(dim=1, descending=True)  # (B, G), most opaque first
    keep_idx = order[:, :num_top]
    if num_random > 0:
        rest = order[:, num_top:]
        # One permutation of rank positions is shared by every batch item; since
        # ``rest`` is sorted per item, the selected Gaussians still differ per item.
        pick = torch.randperm(rest.shape[1])[:num_random]
        keep_idx = torch.cat([keep_idx, rest[:, pick]], dim=1)

    return Gaussians(
        means=_gather_gaussian_dim(gaussians.means, keep_idx),
        covariances=_gather_gaussian_dim(gaussians.covariances, keep_idx),
        harmonics=_gather_gaussian_dim(gaussians.harmonics, keep_idx),
        opacities=_gather_gaussian_dim(gaussians.opacities, keep_idx),
        rotations=_gather_gaussian_dim(gaussians.rotations, keep_idx),
        scales=_gather_gaussian_dim(gaussians.scales, keep_idx),
    )


class DecoderSplattingGSPlat(Decoder[DecoderSplattingSplatCfg]):
    """Decoder that renders Gaussians with gsplat's ``rasterization`` in RGB+D mode.

    Gaussians are first pruned with ``prune_gaussians`` (inference mode when the
    module is in eval mode), then rasterized into colour images and depth maps
    for every (batch, view) camera.
    """

    background_color: Float[Tensor, "3"]

    def __init__(
            self,
            cfg: DecoderSplattingSplatCfg,
    ) -> None:
        super().__init__(cfg)
        # Stored from the config; not read anywhere in this class.
        self.make_scale_invariant = cfg.make_scale_invariant
        # Non-persistent buffer: moves with the module's device but is not
        # saved in the state dict.
        self.register_buffer(
            "background_color",
            torch.tensor(cfg.background_color, dtype=torch.float32),
            persistent=False,
        )

    def rendering_fn(
            self,
            gaussians: Gaussians,
            extrinsics: Float[Tensor, "batch view 4 4"],
            intrinsics: Float[Tensor, "batch view 3 3"],
            near: Float[Tensor, "batch view"],
            far: Float[Tensor, "batch view"],
            image_shape: tuple[int, int],
            depth_mode: DepthRenderingMode | None = None,
            cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
            cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Render colour and depth for every camera in the batch.

        Args:
            gaussians: Batched Gaussians to render.
            extrinsics: Camera poses; they are inverted to obtain the
                world-to-camera view matrices gsplat expects, so they are
                presumably camera-to-world.
            intrinsics: Camera intrinsics; rows 0 and 1 are scaled by the image
                width and height, so they are presumably normalized.
            near, far, depth_mode, cam_rot_delta, cam_trans_delta: Accepted for
                interface compatibility; not used by this implementation.
            image_shape: Output ``(height, width)`` in pixels.

        Returns:
            ``DecoderOutput`` with colour of shape (B, V, 3, H, W) clamped to
            [0, 1] and depth of shape (B, V, H, W).
        """
        gaussians = prune_gaussians(
            gaussians,
            self.cfg.prune_opacity_threshold,
            self.cfg.training_prune_ratio,
            self.cfg.training_prune_keep_ratio,
            inference=not self.training,
        )

        B, V, _, _ = intrinsics.shape
        H, W = image_shape

        means = gaussians.means
        opacities = gaussians.opacities
        rotations = gaussians.rotations
        scales = gaussians.scales
        covars = gaussians.covariances
        # After this permute, dim -2 holds the SH coefficient count.
        features = gaussians.harmonics.permute(0, 1, 3, 2).contiguous()
        sh_degree = int(sqrt(features.shape[-2])) - 1

        # Camera matrices are cast to float32 consistently before going to gsplat.
        w2c = extrinsics.float().inverse()  # (B, V, 4, 4)
        intrinsics_px = intrinsics.float().clone()
        intrinsics_px[:, :, 0] *= W
        intrinsics_px[:, :, 1] *= H

        backgrounds = self.background_color.unsqueeze(0).unsqueeze(0).repeat(B, V, 1)  # (B, V, 3)

        # Only the rendered RGB+D tensor is used; alphas and metadata are discarded.
        rendering, _alphas, _meta = rasterization(
            means, rotations, scales, opacities, features,
            w2c,
            intrinsics_px,
            W, H,
            sh_degree=sh_degree,
            render_mode="RGB+D", packed=False,
            backgrounds=backgrounds,
            radius_clip=0.1,
            covars=covars,
            rasterize_mode='classic',
        )  # (B, V, H, W, 4): RGB channels followed by depth
        rendering_img, rendering_depth = torch.split(rendering, [3, 1], dim=-1)
        rendering_img = rendering_img.clamp(0.0, 1.0)
        return DecoderOutput(rendering_img.permute(0, 1, 4, 2, 3), rendering_depth.squeeze(-1))

    def forward(
            self,
            gaussians: Gaussians,
            extrinsics: Float[Tensor, "batch view 4 4"],
            intrinsics: Float[Tensor, "batch view 3 3"],
            near: Float[Tensor, "batch view"],
            far: Float[Tensor, "batch view"],
            image_shape: tuple[int, int],
            depth_mode: DepthRenderingMode | None = None,
            cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
            cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
    ) -> DecoderOutput:
        """Render the Gaussians; thin wrapper around :meth:`rendering_fn`."""
        return self.rendering_fn(gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode, cam_rot_delta,
                                 cam_trans_delta)
