import torch
from einops import rearrange

from src.geometry.projection import sample_image_grid, project


def reproj_loss(pts, intrinsics, extrinsics, h, w, tolerance=0.9, mask=None, downsample=1):
    """
    Compute the reprojection loss between projected 3D points and a regular 2D grid.

    The points are handed to ``project`` together with the extrinsics and intrinsics, and the resulting 2D
    coordinates are compared with the grid returned by ``sample_image_grid``.
    Per-axis offsets whose magnitude is below the tolerance contribute zero; the remaining offsets are squared.

    :param pts: 3D points (N, V, M, 3) in world coordinates; M has to broadcast against the number of grid
        cells, (h // downsample) * (w // downsample)
    :param intrinsics: camera intrinsics (N, V, 3, 3)
    :param extrinsics: camera extrinsics (N, V, 4, 4)
    :param h: height of 2D grid
    :param w: width of 2D grid
    :param tolerance: tolerance in pixels; it is scaled by (1 / w, 1 / h), i.e. by the full-resolution pixel
        size, independently of ``downsample``
    :param mask: optional per-cell weights (N, V, H, W); H * W has to broadcast against the number of grid
        cells. When given, the loss is the weighted sum of squared offsets divided by ``mask.sum()``
    :param downsample: integer factor by which ``h`` and ``w`` are floor-divided before sampling the grid
    :return: scalar reprojection loss; 0 when a mask is given but sums to zero
    """

    # add a singleton axis so the per-view camera matrices broadcast over the M points
    intrinsics = rearrange(intrinsics, "b v i j -> b v () i j")
    extrinsics = rearrange(extrinsics, "b v i j -> b v () i j")  # the inversion operation is done within the project() function

    # the second return value (in-front-of-camera flag) is not used by this loss
    pts_2d, _ = project(pts, extrinsics, intrinsics)  # presumably (N, V, M, 2) -- TODO confirm against project()

    # sample 2D grid
    grid_2d, _ = sample_image_grid((h // downsample, w // downsample), pts.device)  # (H, W, 2)
    grid_2d = rearrange(grid_2d, "h w xy -> () () (h w) xy")  # (1, 1, HW, 2)

    # normalize tolerance by the pixel size
    # NOTE(review): uses the full-resolution (w, h) even when downsample > 1 -- confirm this is intended
    pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=pts.device)
    tolerance = tolerance * pixel_size  # (2,)

    # compute the reprojection error between 2D grid and projected 2D points
    distance = grid_2d - pts_2d
    distance = torch.where(torch.abs(distance) < tolerance, torch.zeros_like(distance), distance)  # apply tolerance

    # squared error per axis
    loss = distance ** 2  # (N, V, HW, 2)

    if mask is None:
        return loss.mean()  # return mean loss

    mask = rearrange(mask, "b v h w -> b v (h w) ()").to(loss.dtype)
    # An empty mask would give 0 / 0 = NaN; clamping the denominator yields 0 (with zero gradient) instead.
    denominator = mask.sum().clamp_min(torch.finfo(loss.dtype).eps)
    return (loss * mask).sum() / denominator


def epipolar_loss():
    """Placeholder for an epipolar loss; not implemented yet, it takes no arguments and returns ``None``."""
    return None
