import torch
import torch.nn.functional as F
import numpy as np


def transl_ang_loss(t, tgt, eps=1e-6):
    """
    Mean angle (radians) between estimated and ground-truth translation directions.

    Args:
        t: estimated translation vector [B, 3]
        tgt: ground-truth translation vector [B, 3]
        eps: added to each vector norm before dividing, and used as the margin
            that keeps the cosine away from -1 and 1 before acos
    Returns:
        T_err: translation direction angular error, averaged over the batch
    """
    # Rescale both vectors to (nearly) unit length; eps keeps the division
    # finite when a vector has zero length.
    pred_dir = t / (t.norm(dim=1, keepdim=True) + eps)
    gt_dir = tgt / (tgt.norm(dim=1, keepdim=True) + eps)

    # Row-wise dot product of the two directions is the cosine of their angle.
    cos_angle = (pred_dir * gt_dir).sum(dim=1)

    # Clamp slightly inside [-1, 1] so round-off cannot push acos to NaN.
    cos_angle = cos_angle.clamp(-1.0 + eps, 1.0 - eps)
    return cos_angle.acos().mean()


def transl_l2_loss(t, tgt):
    """
    Mean Euclidean distance between estimated and ground-truth translations.

    Args:
        t: estimated translation vector [B, 3]
        tgt: ground-truth translation vector [B, 3]
    Returns:
        T_err: translation L2 error, averaged over the batch
    """
    offset = t - tgt
    per_sample_dist = offset.norm(dim=1)  # one L2 distance per batch row
    return per_sample_dist.mean()


def transl_huber_loss(t, tgt, delta=1.0):
    """
    Huber loss between estimated and ground-truth translations.

    Args:
        t: estimated translation vector [B, 3]
        tgt: ground-truth translation vector [B, 3]
        delta: huber loss threshold
    Returns:
        T_err: translation Huber loss, averaged over every element
    """
    # Evaluate the loss element-wise first, then average all elements at once.
    per_element = F.huber_loss(input=t, target=tgt, reduction='none', delta=delta)
    return torch.mean(per_element)


def rot_ang_loss(R, Rgt, eps=1e-6):
    """
    Mean rotation angle (radians) of the residual rotation R^T @ Rgt.

    Args:
        R: estimated rotation matrix [..., 3, 3] (e.g. [B, 3, 3] or [B, N, 3, 3])
        Rgt: ground-truth rotation matrix, same shape as R
        eps: margin that keeps the cosine away from -1 and 1 before acos
    Returns:
        R_err: rotation angular error, averaged over all leading dimensions
    """
    # Transpose the last two axes (not fixed axes 1 and 2) so any number of
    # leading batch dimensions is accepted.
    residual = torch.matmul(R.transpose(-2, -1), Rgt)
    trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
    # A rotation by angle theta has trace 1 + 2*cos(theta).
    cosine = (trace - 1) / 2
    # Clamp slightly inside [-1, 1] so round-off cannot push acos to NaN.
    R_err = torch.acos(torch.clamp(cosine, -1.0 + eps, 1.0 - eps))
    return R_err.mean()


def pose_loss(pred_pose, tgt_pose, eps=1e-6):
    """
    Combined translation (Huber) and rotation (angular) loss between poses.

    The input pose should be c2w pose given as a homogeneous 4x4 matrix: the
    upper-left 3x3 block is read as the rotation and the first three entries
    of the last column as the translation.

    Args:
        pred_pose: predicted pose [..., 4, 4] (e.g. [B, 4, 4] or [B, N, 4, 4])
        tgt_pose: ground-truth pose, same shape as pred_pose
        eps: clamp margin forwarded to rot_ang_loss
    Returns:
        loss: combined translation and rotation loss (unweighted sum)
    """
    # Collapse every leading dimension into a single batch axis. Both terms
    # are plain means, so this leaves the [B, 4, 4] result unchanged while
    # also accepting stacked inputs such as [B, N, 4, 4].
    pred_pose = pred_pose.flatten(0, -3)
    tgt_pose = tgt_pose.flatten(0, -3)

    pred_R = pred_pose[:, :3, :3]
    pred_t = pred_pose[:, :3, 3]

    tgt_R = tgt_pose[:, :3, :3]
    tgt_t = tgt_pose[:, :3, 3]

    # Translation uses a Huber loss on the raw vectors (threshold 1.0);
    # rotation uses the angular error of the residual rotation.
    T_err = transl_huber_loss(pred_t, tgt_t, delta=1.0)
    R_err = rot_ang_loss(pred_R, tgt_R, eps)

    return T_err + R_err


def compute_pairwise_relative_poses(pred_poses, target_poses):
    """
    Compute pairwise relative camera poses for loss calculation.

    For every ordered pair of distinct views (i, j) the relative pose is
    T_rel = T_i^{-1} @ T_j. Pairs are ordered by i first, then j; for V=3:
    (0,1), (0,2), (1,0), (1,2), (2,0), (2,1).

    Args:
        pred_poses: [B, V, 4, 4] predicted camera-to-world matrices
        target_poses: [B, V, 4, 4] target camera-to-world matrices

    Returns:
        pred_relative: [B, N, 4, 4] predicted relative poses where N = V*(V-1)
        target_relative: [B, N, 4, 4] target relative poses where N = V*(V-1)
        With fewer than two views there are no pairs, so N = 0.
    """
    B, V, _, _ = pred_poses.shape
    device = pred_poses.device

    if V < 2:
        # No ordered pair of distinct views exists.
        return (pred_poses.new_zeros((B, 0, 4, 4)),
                target_poses.new_zeros((B, 0, 4, 4)))

    # All ordered pairs (i, j) with i != j, built once on the host and moved
    # to the device as two index tensors.
    pairs = [(i, j) for i in range(V) for j in range(V) if i != j]
    i_indices = torch.tensor([i for i, _ in pairs], dtype=torch.long, device=device)
    j_indices = torch.tensor([j for _, j in pairs], dtype=torch.long, device=device)

    # Invert each pose once (V inversions per batch item) and gather the
    # inverses per pair, instead of inverting the gathered V*(V-1) copies.
    pred_inv = torch.inverse(pred_poses)
    target_inv = torch.inverse(target_poses)

    # Batched matmul over the leading [B, N] dims: T_rel = T_i^{-1} @ T_j.
    pred_relative = torch.matmul(pred_inv[:, i_indices], pred_poses[:, j_indices])
    target_relative = torch.matmul(target_inv[:, i_indices], target_poses[:, j_indices])

    return pred_relative, target_relative


def create_random_se3_matrix(device='cpu'):
    """Create a random valid SE(3) transformation matrix.

    The rotation comes from a random unit axis and a random angle in
    [0, 2*pi) via the Rodrigues formula; the translation is a standard
    normal sample scaled by 5. Returns a [4, 4] tensor on `device`.
    """
    # Random rotation axis (normalised) and angle
    unit_axis = torch.randn(3, device=device)
    unit_axis = unit_axis / torch.norm(unit_axis)
    theta = torch.rand(1, device=device) * 2 * np.pi

    # Skew-symmetric cross-product matrix of the axis
    ax, ay, az = unit_axis[0], unit_axis[1], unit_axis[2]
    skew = torch.zeros(3, 3, device=device)
    skew[0, 1] = -az
    skew[0, 2] = ay
    skew[1, 0] = az
    skew[1, 2] = -ax
    skew[2, 0] = -ay
    skew[2, 1] = ax

    # Rodrigues: R = I + sin(theta) * K + (1 - cos(theta)) * K^2
    rotation = (torch.eye(3, device=device)
                + torch.sin(theta) * skew
                + (1 - torch.cos(theta)) * (skew @ skew))

    # Random translation
    translation = torch.randn(3, device=device) * 5  # scale translation

    # Assemble [[R, t], [0, 0, 0, 1]]
    top_rows = torch.cat([rotation, translation.unsqueeze(1)], dim=1)
    bottom_row = torch.eye(4, device=device)[3:4]
    return torch.cat([top_rows, bottom_row], dim=0)


def naive_pairwise_computation(poses):
    """Naive implementation for verification - computes pairwise relative poses one by one.

    Takes poses of shape [B, V, 4, 4] and returns [B, V*(V-1), 4, 4], where
    slot k of each batch item holds T_i^{-1} @ T_j for the k-th ordered pair
    (i, j), i != j, enumerated with i as the outer index and j as the inner.
    """
    B, V, _, _ = poses.shape
    ordered_pairs = [(i, j) for i in range(V) for j in range(V) if i != j]
    out = torch.zeros(B, len(ordered_pairs), 4, 4, device=poses.device, dtype=poses.dtype)

    for b in range(B):
        # The slot index restarts at 0 for every batch item.
        for slot, (i, j) in enumerate(ordered_pairs):
            # T_rel = T_i^{-1} @ T_j
            out[b, slot] = torch.mm(torch.inverse(poses[b, i]), poses[b, j])

    return out


def verify_indexing():
    """Print the (i, j) index pairs for V=4 and check them against the expected list.

    Expected pairs are all ordered (i, j) with i != j, i as the outer index.
    """
    print("=== Verifying Indexing ===")
    V = 4
    device = 'cpu'

    # Create indices the same way as in the main function:
    # each i repeated V-1 times, and for each i every other view index.
    i_indices = torch.arange(V, device=device).repeat_interleave(V - 1)
    other_views = []
    for j in range(V):
        before = torch.arange(j, device=device)
        after = torch.arange(j + 1, V, device=device)
        other_views.append(torch.cat([before, after]))
    j_indices = torch.cat(other_views)

    i_list = i_indices.tolist()
    j_list = j_indices.tolist()

    print(f"For V={V} views:")
    print(f"i_indices: {i_list}")
    print(f"j_indices: {j_list}")

    # Verify we have all pairs except (i,i)
    expected_pairs = []
    for i in range(V):
        for j in range(V):
            if i != j:
                expected_pairs.append((i, j))
    actual_pairs = list(zip(i_list, j_list))

    print(f"Expected pairs: {expected_pairs}")
    print(f"Actual pairs:   {actual_pairs}")
    print(f"Pairs match: {expected_pairs == actual_pairs}")
    print(f"Number of pairs: {len(actual_pairs)} (expected: {V * (V - 1)})")
    print()


def verify_relative_pose_correctness():
    """Verify that T_rel = T_i^{-1} @ T_j is computed correctly.

    Uses two pure-translation poses, computes T1^{-1} @ T2 by hand, and prints
    the maximum absolute difference to the first pair returned by
    compute_pairwise_relative_poses.
    """
    print("=== Verifying Relative Pose Computation ===")
    device = 'cpu'

    def _pure_translation(offset):
        # Identity rotation with the given translation in the last column.
        pose = torch.eye(4, device=device)
        pose[:3, 3] = torch.tensor(offset, device=device)
        return pose

    # Create two specific poses
    T1 = _pure_translation([1, 0, 0])  # Translation along x
    T2 = _pure_translation([0, 1, 0])  # Translation along y

    # Manual computation
    T1_inv = torch.inverse(T1)
    T_rel_manual = torch.mm(T1_inv, T2)

    for heading, matrix in (
        ("T1 (translate by [1,0,0]):", T1),
        ("\nT2 (translate by [0,1,0]):", T2),
        ("\nT1_inv:", T1_inv),
        ("\nT_rel = T1_inv @ T2 (manual):", T_rel_manual),
    ):
        print(heading)
        print(matrix)

    # Using our function
    poses = torch.stack([T1, T2])[None]  # [1, 2, 4, 4]
    rel_poses, _ = compute_pairwise_relative_poses(poses, poses)
    first_pair = rel_poses[0, 0]  # First pair should be T1->T2

    print(f"\nUsing our function, relative poses shape: {rel_poses.shape}")
    print("T_rel from our function (first pair, T1->T2):")
    print(first_pair)

    # Check if they match
    diff = (T_rel_manual - first_pair).abs().max()
    print(f"\nMax difference: {diff.item():.2e}")
    print(f"Computation correct: {diff < 1e-6}")
    print()


def verify_against_naive_implementation():
    """Compare our vectorized implementation against naive loop-based implementation.

    Builds a random [B, V, 4, 4] batch of SE(3) poses, runs both
    implementations, and prints the maximum absolute difference.
    """
    print("=== Verifying Against Naive Implementation ===")
    device = 'cpu'

    # Test with random poses
    B, V = 2, 3
    # torch.stack accepts a sequence of tensors, not nested lists, so stack
    # the V views of each batch item first and then stack the batch items.
    poses = torch.stack([
        torch.stack([create_random_se3_matrix(device) for _ in range(V)])
        for _ in range(B)
    ])

    print(f"Testing with B={B}, V={V}")
    print(f"Input shape: {poses.shape}")

    # Our implementation
    rel_ours, _ = compute_pairwise_relative_poses(poses, poses)

    # Naive implementation
    rel_naive = naive_pairwise_computation(poses)

    print(f"Our output shape: {rel_ours.shape}")
    print(f"Naive output shape: {rel_naive.shape}")

    # Compare
    # NOTE(review): 1e-6 is an absolute tolerance; with float32 poses and
    # translations of magnitude ~5 it may be tight - confirm it is intended.
    diff = torch.abs(rel_ours - rel_naive).max()
    print(f"Max difference: {diff.item():.2e}")
    print(f"Implementations match: {diff < 1e-6}")
    print()


def verify_properties():
    """Verify mathematical properties of relative poses.

    Prints, for each relative pose of a random [1, 3, 4, 4] batch, the
    rotation determinant, the orthogonality error and the bottom-row error,
    followed by the error of the composition T_AC = T_AB @ T_BC.
    """
    print("=== Verifying Mathematical Properties ===")
    device = 'cpu'

    B, V = 1, 3
    # torch.stack accepts a sequence of tensors, not nested lists, so stack
    # the V views of each batch item first and then stack the batch items.
    poses = torch.stack([
        torch.stack([create_random_se3_matrix(device) for _ in range(V)])
        for _ in range(B)
    ])

    rel_poses, _ = compute_pairwise_relative_poses(poses, poses)

    print(f"Testing with B={B}, V={V}")

    # Property 1: T_rel should be a valid SE(3) matrix (det=1, orthogonal rotation)
    print("\nProperty 1: Valid SE(3) matrices")
    for i in range(rel_poses.shape[1]):
        T = rel_poses[0, i]
        R_part = T[:3, :3]

        # Check determinant
        det = torch.det(R_part)
        print(f"  Pair {i}: det(R) = {det.item():.6f} (should be ≈ 1)")

        # Check orthogonality (reference identity built on T's dtype/device)
        should_be_identity = torch.mm(R_part, R_part.t())
        identity = torch.eye(3, dtype=T.dtype, device=T.device)
        identity_error = torch.abs(should_be_identity - identity).max()
        print(f"  Pair {i}: R@R^T error = {identity_error.item():.2e} (should be ≈ 0)")

        # Check bottom row
        bottom_row = T[3, :]
        expected_bottom = torch.tensor([0, 0, 0, 1], dtype=T.dtype, device=T.device)
        bottom_error = torch.abs(bottom_row - expected_bottom).max()
        print(f"  Pair {i}: bottom row error = {bottom_error.item():.2e} (should be ≈ 0)")

    # Property 2: Composition property
    # If we have T_AB = T_A^{-1} @ T_B and T_BC = T_B^{-1} @ T_C
    # Then T_AC should equal T_AB @ T_BC
    print("\nProperty 2: Composition property (limited test)")
    T_A = poses[0, 0]
    T_B = poses[0, 1]
    T_C = poses[0, 2]

    # Direct computation
    T_AC_direct = torch.mm(torch.inverse(T_A), T_C)

    # Via composition
    T_AB = torch.mm(torch.inverse(T_A), T_B)
    T_BC = torch.mm(torch.inverse(T_B), T_C)
    T_AC_composed = torch.mm(T_AB, T_BC)

    composition_error = torch.abs(T_AC_direct - T_AC_composed).max()
    print(f"  Composition error: {composition_error.item():.2e} (should be ≈ 0)")
    print()


def verify_gradient_flow():
    """Verify that gradients flow correctly through the computation.

    Backpropagates the sum of all relative-pose entries to a random
    [1, 3, 4, 4] pose batch and prints whether a gradient was populated.
    """
    print("=== Verifying Gradient Flow ===")
    device = 'cpu'

    B, V = 1, 3
    # torch.stack accepts a sequence of tensors, not nested lists, so stack
    # the V views of each batch item first and then stack the batch items.
    poses = torch.stack([
        torch.stack([create_random_se3_matrix(device) for _ in range(V)])
        for _ in range(B)
    ])
    poses.requires_grad_(True)

    rel_poses, _ = compute_pairwise_relative_poses(poses, poses)

    # Simple loss: sum of all elements
    loss = rel_poses.sum()
    loss.backward()

    has_grad = poses.grad is not None
    print(f"Input poses shape: {poses.shape}")
    print(f"Input poses require grad: {poses.requires_grad}")
    print(f"Poses grad is not None: {has_grad}")
    print(f"Poses grad shape: {poses.grad.shape if has_grad else 'None'}")
    print(f"Poses grad sum: {poses.grad.sum().item() if has_grad else 'None'}")
    print("Gradient flow: OK" if has_grad else "Gradient flow: FAILED")
    print()


def run_all_verifications():
    """Run all verification tests in a fixed order, with a banner before and after."""
    print("VERIFYING PAIRWISE RELATIVE POSES COMPUTATION")
    print("=" * 50)

    checks = (
        verify_indexing,
        verify_relative_pose_correctness,
        verify_against_naive_implementation,
        verify_properties,
        verify_gradient_flow,
    )
    for check in checks:
        check()

    print("All verifications completed!")


# Script entry point: run every verification when the file is executed
# directly (nothing runs on import).
if __name__ == "__main__":
    run_all_verifications()