import json
from dataclasses import dataclass
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Literal

import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, UInt8
from PIL import Image
from torch import Tensor
from torch.utils.data import IterableDataset

from ..geometry.projection import get_fov
from .dataset import DatasetCfgCommon
from .shims.augmentation_shim import apply_augmentation_shim
from .shims.crop_shim import apply_crop_shim
from .types import Stage
from .view_sampler import ViewSampler
from ..misc.cam_utils import camera_normalization


@dataclass
class DatasetRE10kCfg(DatasetCfgCommon):
    """Configuration consumed by `DatasetRE10k`; extends the common dataset config."""

    # Dataset name; copied onto the dataset instance as `DatasetRE10k.name`.
    name: str
    # Root directories. Each is expected to contain one sub-directory per data
    # stage holding `.torch` chunk files and an `index.json`.
    roots: list[Path]
    # Examples whose pose-normalization scale lies outside
    # [baseline_min, baseline_max] are skipped.
    baseline_min: float
    baseline_max: float
    # Examples with any field of view (in degrees) above this value are skipped.
    max_fov: float
    # NOTE(review): not read by the dataset code visible in this file — confirm
    # whether it is still used elsewhere.
    make_baseline_1: bool
    # Apply the augmentation shim to examples in the "train" stage.
    augment: bool
    # Pass the selected poses through `camera_normalization`, using the first
    # context view as the reference camera.
    relative_pose: bool
    # Skip examples whose decoded images do not match `original_image_shape`.
    skip_bad_shape: bool
    # How the scene scale is derived from the context camera positions. Values
    # handled by `DatasetRE10k.__iter__`: "none", "start_end_1", "max_1",
    # "avg_norm", "max_trans", "relative_world_1".
    pose_norm_method: str = "max_1"


@dataclass
class DatasetRE10kCfgWrapper:
    """Wrapper that holds a `DatasetRE10kCfg` under the `re10k` field."""

    # Presumably the field name selects the dataset variant when configs are
    # loaded — TODO confirm against the config loader.
    re10k: DatasetRE10kCfg


@dataclass
class DatasetDL3DVCfgWrapper:
    """Wrapper that holds a `DatasetRE10kCfg` under the `dl3dv` field."""

    # DL3DV reuses the RE10k config schema; presumably the field name selects
    # the dataset variant when configs are loaded — TODO confirm.
    dl3dv: DatasetRE10kCfg


@dataclass
class DatasetScannetppCfgWrapper:
    """Wrapper that holds a `DatasetRE10kCfg` under the `scannetpp` field."""

    # ScanNet++ reuses the RE10k config schema; presumably the field name
    # selects the dataset variant when configs are loaded — TODO confirm.
    scannetpp: DatasetRE10kCfg


class DatasetRE10k(IterableDataset):
    """Iterable dataset that streams context/target view examples from `.torch` chunks."""

    # Dataset configuration, the stage this instance serves, and the sampler
    # that picks context/target frame indices for each scene.
    cfg: DatasetRE10kCfg
    stage: Stage
    view_sampler: ViewSampler

    # Converter applied to each decoded PIL image.
    to_tensor: tf.ToTensor
    # Paths of the chunk files this dataset iterates over.
    chunks: list[Path]
    # Default near/far bounds; each example divides them by its
    # pose-normalization scale.
    near: float = 0.1
    far: float = 100.0

    def __init__(
        self,
        cfg: DatasetRE10kCfg,
        stage: Stage,
        view_sampler: ViewSampler,
    ) -> None:
        """Store the configuration and gather the chunk files to iterate over.

        Args:
            cfg: Dataset configuration (roots, filters, normalization options).
            stage: Stage this dataset instance serves.
            view_sampler: Sampler that selects context/target frame indices.
        """
        super().__init__()
        self.cfg = cfg
        self.stage = stage
        self.view_sampler = view_sampler
        self.to_tensor = tf.ToTensor()
        self.name = cfg.name

        # Gather every `.torch` file below each root's stage directory, sorted
        # per root so the order is deterministic.
        stage_name = self.data_stage
        collected = []
        for dataset_root in cfg.roots:
            stage_dir = dataset_root / stage_name
            collected += sorted(
                candidate
                for candidate in stage_dir.iterdir()
                if candidate.suffix == ".torch"
            )

        # When overfitting, every slot points at the chunk holding that scene so
        # the number of chunks visited per pass stays the same.
        scene_key = self.cfg.overfit_to_scene
        if scene_key is not None:
            collected = [self.index[scene_key]] * len(collected)
        self.chunks = collected

    def shuffle(self, lst: list) -> list:
        """Return a new list with the items of `lst` in a random order.

        The permutation comes from `torch.randperm`; `lst` itself is not modified.
        """
        order = torch.randperm(len(lst)).tolist()
        return [lst[position] for position in order]

    def __iter__(self):
        """Yield one example dict per usable scene.

        Every yielded example contains a "context" and a "target" group
        (extrinsics, intrinsics, images, near/far bounds, frame indices; the
        context group also carries the sampler's "overlap") plus the "scene"
        key. In the "train" stage with `cfg.augment` set, the example is passed
        through `apply_augmentation_shim`; it is always passed through
        `apply_crop_shim` with `cfg.input_image_shape`.

        A scene is skipped when the view sampler raises `ValueError`, when any
        field of view exceeds `cfg.max_fov` degrees, when its images cannot be
        indexed or decoded, when image shapes are wrong (only with
        `cfg.skip_bad_shape`), or when the pose-normalization scale is NaN,
        non-positive, or outside `[cfg.baseline_min, cfg.baseline_max]`.

        Raises:
            ValueError: If `cfg.pose_norm_method` is not a known method.
        """
        # Work on a local copy so that iterating never permanently reorders or
        # shrinks `self.chunks`. Writing the worker shard back to `self.chunks`
        # would shard the already-sharded list again on the next pass.
        chunks = list(self.chunks)

        # Chunks must be shuffled here (not inside __init__) for validation to show
        # random chunks.
        if self.stage in ("train", "val"):
            chunks = self.shuffle(chunks)

        # When testing, the data loaders alternate chunks: worker `id` takes every
        # `num_workers`-th chunk starting at position `id`.
        worker_info = torch.utils.data.get_worker_info()
        if self.stage == "test" and worker_info is not None:
            chunks = chunks[worker_info.id :: worker_info.num_workers]

        for chunk_path in chunks:
            # Load the chunk.
            chunk = torch.load(chunk_path)

            if self.cfg.overfit_to_scene is not None:
                # Repeat the single requested scene so the chunk keeps its length.
                item = [x for x in chunk if x["key"] == self.cfg.overfit_to_scene]
                assert len(item) == 1
                chunk = item * len(chunk)

            if self.stage in ("train", "val"):
                chunk = self.shuffle(chunk)

            for example in chunk:
                extrinsics, intrinsics = self.convert_poses(example["cameras"])
                scene = example["key"]

                try:
                    context_indices, target_indices, overlap = self.view_sampler.sample(
                        scene,
                        extrinsics,
                        intrinsics,
                    )
                except ValueError:
                    # Skip because the example doesn't have enough frames.
                    continue

                # Skip the example if the field of view is too wide.
                if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
                    continue

                # Load the images.
                try:
                    context_images = self.convert_images(
                        [example["images"][index.item()] for index in context_indices]
                    )
                    target_images = self.convert_images(
                        [example["images"][index.item()] for index in target_indices]
                    )
                except IndexError:
                    # A sampled frame index has no matching image; skip quietly.
                    continue
                except OSError:
                    print(f"Skipped bad example {example['key']}.")  # DL3DV-Full have some bad images
                    continue

                # Skip the example if the images don't have the right shape.
                if self.cfg.skip_bad_shape:
                    expected_shape = (3, *self.cfg.original_image_shape)
                    if (
                        context_images.shape[1:] != expected_shape
                        or target_images.shape[1:] != expected_shape
                    ):
                        print(
                            f"Skipped bad example {example['key']}. Context shape was "
                            f"{context_images.shape} and target shape was "
                            f"{target_images.shape}."
                        )
                        continue

                # Resize the world according to the configured normalization.
                context_extrinsics = extrinsics[context_indices]
                target_extrinsics = extrinsics[target_indices]
                num_context_images = len(context_extrinsics)

                scale = self._compute_pose_scale(context_extrinsics)

                # `not (lo <= scale <= hi)` is also true for NaN, which a plain
                # `scale < lo or scale > hi` test would let through. A
                # non-positive scale would divide the translations by zero.
                in_range = self.cfg.baseline_min <= scale <= self.cfg.baseline_max
                if not in_range or scale <= 0:
                    print(
                        f"Skipped {scene} because of baseline out of range: "
                        f"{scale:.6f}"
                    )
                    continue

                # `torch.cat` allocates a new tensor, so the in-place division
                # does not touch `extrinsics`.
                all_used_extrinsics = torch.cat(
                    [context_extrinsics, target_extrinsics], dim=0
                )
                all_used_extrinsics[:, :3, 3] /= scale

                if self.cfg.relative_pose:
                    all_used_extrinsics = camera_normalization(
                        all_used_extrinsics[0:1], all_used_extrinsics
                    )

                example = {
                    "context": {
                        "extrinsics": all_used_extrinsics[:num_context_images],
                        "intrinsics": intrinsics[context_indices],
                        "image": context_images,
                        "near": self.get_bound("near", len(context_indices)) / scale,
                        "far": self.get_bound("far", len(context_indices)) / scale,
                        "index": context_indices,
                        "overlap": overlap,
                    },
                    "target": {
                        "extrinsics": all_used_extrinsics[num_context_images:],
                        "intrinsics": intrinsics[target_indices],
                        "image": target_images,
                        "near": self.get_bound("near", len(target_indices)) / scale,
                        "far": self.get_bound("far", len(target_indices)) / scale,
                        "index": target_indices,
                    },
                    "scene": scene,
                }
                if self.stage == "train" and self.cfg.augment:
                    example = apply_augmentation_shim(example)
                yield apply_crop_shim(example, tuple(self.cfg.input_image_shape))

    def _compute_pose_scale(self, context_extrinsics: Float[Tensor, "view 4 4"]):
        """Return the scene scale selected by `cfg.pose_norm_method`.

        Supported methods:
            "none": constant 1.0.
            "start_end_1": distance between the first and last context camera.
            "max_1": largest distance between any two context cameras.
            "avg_norm": mean distance over all unordered context camera pairs.
            "max_trans": largest absolute translation coordinate.
            "relative_world_1": largest distance from the first context camera
                to any other context camera.

        The distance-based methods return a zero scale when fewer than two
        context views are given, since no baseline exists.

        Returns:
            A Python number or a 0-dim tensor.

        Raises:
            ValueError: If the configured method is unknown.
        """
        method = self.cfg.pose_norm_method
        if method == "none":
            return 1.0

        # For c2w matrices, the camera position is just the translation part.
        positions = context_extrinsics[:, :3, 3]  # Shape: (V, 3)

        if method == "max_trans":
            largest_coordinate = torch.max(torch.abs(positions))
            # The norm of a 0-dim tensor is its absolute value, so this keeps
            # the (already non-negative) value.
            return torch.norm(largest_coordinate)

        if method not in ("start_end_1", "max_1", "avg_norm", "relative_world_1"):
            raise ValueError(f"Unknown pose norm method {method}")

        num_views = positions.shape[0]
        if num_views < 2:
            # Without a second view, "avg_norm" would average an empty tensor
            # (NaN) and "relative_world_1" would take the max of an empty
            # tensor (error). Report zero so the caller skips the example.
            return positions.new_zeros(())

        if method == "start_end_1":
            return (positions[0] - positions[-1]).norm()

        if method == "relative_world_1":
            return (positions[0:1] - positions[1:]).norm(dim=-1).max()

        if method == "max_1":
            scale = 0
            for i in range(num_views):
                for j in range(i + 1, num_views):
                    scale = max(scale, (positions[i] - positions[j]).norm())
            return scale

        # "avg_norm": build the full pairwise distance matrix by broadcasting,
        # then keep the strict upper triangle so each pair is counted once.
        distance_matrix = torch.norm(
            positions.unsqueeze(1) - positions.unsqueeze(0), dim=2
        )  # Shape: (V, V)
        unique_pairs = torch.triu(
            torch.ones_like(distance_matrix, dtype=torch.bool), diagonal=1
        )
        return distance_matrix[unique_pairs].mean()

    def convert_poses(
        self,
        poses: Float[Tensor, "batch 18"],
    ) -> tuple[
        Float[Tensor, "batch 4 4"],  # extrinsics
        Float[Tensor, "batch 3 3"],  # intrinsics
    ]:
        """Turn flat per-frame camera rows into pose and intrinsics matrices.

        Columns 0-3 of each row are read as fx, fy, cx, cy and columns 6
        onwards as a row-major 3x4 matrix; columns 4-5 are not used.

        Returns:
            A pair `(extrinsics, intrinsics)`: the inverse of the 4x4 matrix
            built from the stored 3x4 block, and the 3x3 intrinsics.
        """
        num_frames, _ = poses.shape

        # Fill the focal lengths and principal point into identity matrices; the
        # values are copied exactly as stored in `poses`.
        k_matrices = repeat(
            torch.eye(3, dtype=torch.float32), "h w -> b h w", b=num_frames
        ).clone()
        fx, fy, cx, cy = poses[:, :4].unbind(dim=1)
        k_matrices[:, 0, 0] = fx
        k_matrices[:, 1, 1] = fy
        k_matrices[:, 0, 2] = cx
        k_matrices[:, 1, 2] = cy

        # The stored 3x4 block forms the top rows of a 4x4 OpenCV-style W2C
        # matrix; the identity supplies the bottom row.
        world_to_camera = repeat(
            torch.eye(4, dtype=torch.float32), "h w -> b h w", b=num_frames
        ).clone()
        top_rows = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4)
        world_to_camera[:, :3] = top_rows

        camera_to_world = world_to_camera.inverse()
        return camera_to_world, k_matrices

    def convert_images(
        self,
        images: list[UInt8[Tensor, "..."]],
    ) -> Float[Tensor, "batch 3 height width"]:
        """Decode encoded image byte tensors and stack them into one tensor.

        Each entry's raw bytes are opened with PIL and converted with
        `self.to_tensor`; the results are stacked along a new leading dimension.
        """
        decoded = [
            self.to_tensor(Image.open(BytesIO(encoded.numpy().tobytes())))
            for encoded in images
        ]
        return torch.stack(decoded)

    def get_bound(
        self,
        bound: Literal["near", "far"],
        num_views: int,
    ) -> Float[Tensor, " view"]:
        """Return the named bound (`self.near` or `self.far`) once per view.

        Args:
            bound: Name of the attribute to read.
            num_views: Length of the returned 1-D float32 tensor.
        """
        plane_distance = getattr(self, bound)
        scalar = torch.tensor(plane_distance, dtype=torch.float32)
        return repeat(scalar, "-> v", v=num_views)

    @property
    def data_stage(self) -> Stage:
        """Name of the on-disk stage directory to read from.

        Overfitting to a scene and the "val" stage both read the "test" data;
        any other stage reads the directory named after itself.
        """
        reads_test_split = (
            self.cfg.overfit_to_scene is not None or self.stage == "val"
        )
        return "test" if reads_test_split else self.stage

    @cached_property
    def index(self) -> dict[str, Path]:
        """Map every scene key to the path of the chunk file that contains it.

        Reads `index.json` from each root's stage directory and merges the
        results. When overfitting to a scene, both the "test" and "train"
        indices are loaded. The result is cached after the first access.
        """
        if self.cfg.overfit_to_scene is not None:
            stage_names = ("test", "train")
        else:
            stage_names = [self.data_stage]

        combined = {}
        for stage_name in stage_names:
            for root in self.cfg.roots:
                # Load this root's index and resolve its entries to full paths.
                with (root / stage_name / "index.json").open("r") as handle:
                    raw_index = json.load(handle)
                resolved = {
                    scene_key: Path(root / stage_name / relative_path)
                    for scene_key, relative_path in raw_index.items()
                }

                # The constituent datasets should have unique keys.
                assert combined.keys().isdisjoint(resolved.keys())

                # Merge the root's index into the main index.
                combined.update(resolved)
        return combined

    # def __len__(self) -> int:
    #     return len(self.index.keys())
    def test_len(self) -> int:
        """Return the number of entries in the view sampler's index for the test stage.

        Any stage other than "test" reports 0.
        """
        if self.stage != "test":
            return 0
        return len(self.view_sampler.index.keys())
