
import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser
from persim import plot_diagrams, bottleneck
from scipy.spatial import distance_matrix
import time 


def puntos_en_esfera(n=400, radio=1.0):
    """
    Genera n puntos uniformemente distribuidos sobre S^2 de radio dado.
    """
    phi = np.random.uniform(0, 2*np.pi, n)
    cos_theta = np.random.uniform(-1, 1, n)
    theta = np.arccos(cos_theta)
    x = radio * np.sin(theta) * np.cos(phi)
    y = radio * np.sin(theta) * np.sin(phi)
    z = radio * np.cos(theta)
    return np.stack((x, y, z), axis=1)


def farthest_point_sampling(X):
    """
    Construye orden de los puntos (perm) agregando el más lejano a los anteriores
    y devuelve los 'exit-times' t_p 
    """
    n = X.shape[0]
    D = distance_matrix(X, X)
    perm = []
    t = np.zeros(n)
    # arrancamos desde el punto 0
    current = 0
    perm.append(current)
    t[current] = np.inf
    min_dist = D[current].copy()
    for _ in range(1, n):
        idx = np.argmax(min_dist)
        perm.append(idx)
        t[idx] = min_dist[idx]
        min_dist = np.minimum(min_dist, D[idx])
    # t_out reordenado en el orden original
    t_out = np.zeros(n)
    t_out[perm[0]] = np.max(t[np.isfinite(t)])
    for j in range(1, n):
        t_out[perm[j]] = t[perm[j]]
    return np.array(perm), t_out

def compare_rips_sparsified(X, gamma, maxdim=2):
    """
    Calcula homología persistente para:
      - nube completa X
      - submuestra N_gamma = {p : t_p > gamma}
    y muestra los dos diagramas y las distancias bottleneck.
    """
    perm, t_p = farthest_point_sampling(X)

    # ===============================
    # Medimos el tiempo de persistencia de la nube completa
    # ===============================
    t0_full = time.time()
    res_full = ripser(X, maxdim=maxdim)
    t1_full = time.time()
    tiempo_full = t1_full - t0_full

    # ===============================
    # Medimos tiempo total de submuestra + persistencia
    # ===============================
    t0_sparse_total = time.time()
    sel = np.where(t_p > gamma)[0]  # selección de submuestra
    Xs = X[sel]
    res_sparse = ripser(Xs, maxdim=maxdim)
    t1_sparse_total = time.time()
    tiempo_sparse_total = t1_sparse_total - t0_sparse_total

    dgms_full = res_full['dgms']
    dgms_sparse = res_sparse['dgms']

    # Bottleneck 
    b0 = bottleneck(dgms_full[0], dgms_sparse[0]) if len(dgms_full) > 0 else None
    b1 = bottleneck(dgms_full[1], dgms_sparse[1]) if len(dgms_full) > 1 else None
    b2 = bottleneck(dgms_full[2], dgms_sparse[2]) if len(dgms_full) > 2 else None

    # Gráfico 
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_diagrams(dgms_full, ax=axes[0], show=False, title="Rips completo")
    plot_diagrams(dgms_sparse, ax=axes[1], show=False, title=f"Sparse Rips (γ={gamma:.3f})")

    # Información 
    total_pts = X.shape[0]
    sparse_pts = Xs.shape[0]
    text = (
        f"Nube original: {total_pts} puntos\n"
        f"Submuestra (N_γ): {sparse_pts} puntos\n"
        f"Bottleneck H0 = {b0:.4f}   H1 = {b1:.4f}  H2 = {b2:.4f}\n"
        f"Tiempo full = {tiempo_full:.3f}s   Tiempo submuestra = {tiempo_sparse_total:.3f}s"
    )
    fig.text(0.5, 0.02, text, ha='center', va='bottom', fontsize=11)
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.show()

    return {
        "gamma": gamma,
        "n_full": total_pts,
        "n_sparse": sparse_pts,
        "bottleneck_H0": b0,
        "bottleneck_H1": b1,
        "bottleneck_H2": b2,
        "tiempo_full": tiempo_full,
        "tiempo_sparse_total": tiempo_sparse_total
    }

# ===============================
# Corremos el programa. Tiene el ejemplo incorporado, toma puntos en la esfera, le agrega ruido, hace el submuestreo  y compara 
# ===============================
if __name__ == "__main__":
    Y = puntos_en_esfera(600, 1.0)
    ruido = 0.05 * np.random.randn(*Y.shape)  
    X = Y + ruido

    _, t_p = farthest_point_sampling(X)

    # Elegimos gamma como percentil 75 (por ejemplo)
    gamma = np.percentile(t_p[t_p > 0], 75)

    results = compare_rips_sparsified(X, gamma)
    print("\n--- Resultados ---")
    for k, v in results.items():
        print(f"{k}: {v}")
