

import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser
from persim import plot_diagrams, bottleneck
from scipy.spatial import distance_matrix
import time
import gudhi

def make_noisy_circle(n=500, radius=1.0, noise=0.05, seed=0):
    """Sample ``n`` points on a circle centred at the origin and add Gaussian noise.

    Angles are drawn uniformly in [0, 2*pi); every coordinate then receives
    independent normal noise with standard deviation ``noise``.

    Parameters
    ----------
    n : int
        Number of points to generate.
    radius : float
        Radius of the underlying circle.
    noise : float
        Standard deviation of the additive Gaussian noise.
    seed : int
        Seed for ``numpy.random.default_rng``; equal seeds give equal clouds.

    Returns
    -------
    numpy.ndarray of shape (n, 2)
        The noisy points, one per row.
    """
    generator = np.random.default_rng(seed)
    # The angles must be drawn before the noise so the random stream is consumed
    # in a fixed order (this keeps results reproducible for a given seed).
    angles = generator.random(n) * 2 * np.pi
    clean = np.stack((radius * np.cos(angles), radius * np.sin(angles)), axis=1)
    jitter = generator.normal(scale=noise, size=clean.shape)
    return clean + jitter

# Farthest-point sampling (FPS) and insertion radii t_p

def farthest_point_sampling(X):
    """Greedy farthest-point ordering of a point cloud, starting at point 0.

    At each step the point farthest from all previously selected points is
    selected next (ties broken by lowest index).

    Parameters
    ----------
    X : array_like of shape (n, d)
        Point cloud, one point per row (lists are accepted).

    Returns
    -------
    perm : numpy.ndarray of int, shape (n,)
        Indices of ``X`` in selection order; ``perm[0] == 0``. Every index
        appears exactly once, even when ``X`` contains duplicate points.
    t_out : numpy.ndarray of float, shape (n,)
        Insertion radius of each point, indexed in the ORIGINAL order of ``X``:
        the Euclidean distance from the point to its nearest already-selected
        point at the moment it was selected. The starting point has no finite
        radius, so it is assigned the largest radius among the other points
        (0.0 when ``X`` has a single point).
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    perm = np.empty(n, dtype=np.intp)
    t_out = np.zeros(n)
    if n == 0:
        return perm, t_out

    selected = np.zeros(n, dtype=bool)
    # Start from point 0.
    current = 0
    perm[0] = current
    selected[current] = True
    # Distance from every point to its nearest selected point. Only distances
    # to the newest landmark are needed per step, so no n x n matrix is built.
    min_dist = np.linalg.norm(X - X[current], axis=1)
    for j in range(1, n):
        # Mask already-selected points: with duplicate points the remaining
        # candidates can all be at distance 0, and a plain argmax would then
        # re-select an index that is already in the permutation.
        idx = int(np.argmax(np.where(selected, -np.inf, min_dist)))
        perm[j] = idx
        selected[idx] = True
        t_out[idx] = min_dist[idx]
        min_dist = np.minimum(min_dist, np.linalg.norm(X - X[idx], axis=1))

    # t_out[current] is still 0 and every radius is >= 0, so the maximum over
    # the whole array equals the largest radius of the other points.
    t_out[current] = t_out.max()
    return perm, t_out

def _bottleneck_or_none(dgms_a, dgms_b, dim):
    """Return the bottleneck distance between the ``dim``-th diagrams, or None.

    None is returned when either list of diagrams has no entry for ``dim``
    (for example when ``maxdim`` is smaller than ``dim``).
    """
    if dim < len(dgms_a) and dim < len(dgms_b):
        return bottleneck(dgms_a[dim], dgms_b[dim])
    return None


def _fmt_distance(value):
    """Format a bottleneck distance with 4 decimals, or 'n/a' when it is None."""
    return "n/a" if value is None else f"{value:.4f}"


def compare_rips_sparsified(X, gamma, maxdim=1, n_gudhi=375):
    """Compare Rips persistence of a full cloud with two subsampled versions.

    Three persistence computations are run with ``ripser`` and timed:

    * the full point cloud ``X``;
    * the subsample N_gamma of points whose farthest-point insertion radius
      ``t_p`` (from ``farthest_point_sampling``) is strictly greater than
      ``gamma``; its timing includes the farthest-point pass, the selection
      and the persistence computation;
    * a subsample of ``n_gudhi`` points chosen by
      ``gudhi.subsampling.choose_n_farthest_points``; its timing includes
      the GUDHI sampling and the persistence computation.

    The three diagram sets are plotted side by side, together with the
    bottleneck distances (H0, H1) of each subsample to the full diagrams and
    the timings.

    Parameters
    ----------
    X : numpy.ndarray of shape (n, d)
        Point cloud, one point per row.
    gamma : float
        Insertion-radius threshold; points with ``t_p > gamma`` are kept.
    maxdim : int, optional
        Maximum homology dimension passed to ``ripser``.
    n_gudhi : int, optional
        Number of points requested from GUDHI's farthest-point sampling
        (clamped to the number of points in ``X``).

    Returns
    -------
    dict
        ``gamma``, ``n_full``, ``n_sparse``, ``bottleneck_H0``,
        ``bottleneck_H1``, ``tiempo_full``, ``tiempo_sparse_total`` (as
        before), plus ``n_gudhi``, ``bottleneck_H0_gudhi``,
        ``bottleneck_H1_gudhi`` and ``tiempo_gudhi`` for the GUDHI subsample.
        Bottleneck entries are None when that dimension was not computed.
    """
    # --- Persistence of the full cloud ---
    t0 = time.perf_counter()
    res_full = ripser(X, maxdim=maxdim)
    tiempo_full = time.perf_counter() - t0

    # --- N_gamma subsample: farthest-point ordering + selection + persistence ---
    # The FPS pass is the costly part of building the subsample, so it is timed
    # here; the GUDHI branch below likewise includes its own sampling time.
    t0 = time.perf_counter()
    _, t_p = farthest_point_sampling(X)
    sel = np.where(t_p > gamma)[0]  # subsample selection
    Xs = X[sel]
    res_sparse = ripser(Xs, maxdim=maxdim)
    tiempo_sparse_total = time.perf_counter() - t0

    # --- GUDHI farthest-point subsample + persistence ---
    # Never request more points than the cloud has.
    nb_points = min(n_gudhi, X.shape[0])
    t0 = time.perf_counter()
    # NOTE(review): starting_point=None leaves the first landmark to GUDHI
    # (documented as chosen randomly), so this subsample may vary between runs.
    fps = np.array(
        gudhi.subsampling.choose_n_farthest_points(
            X, nb_points=nb_points, starting_point=None, fast=True
        )
    )
    res_gud = ripser(fps, maxdim=maxdim)
    tiempo_gd = time.perf_counter() - t0

    dgms_full = res_full['dgms']
    dgms_sparse = res_sparse['dgms']
    dgms_gd = res_gud['dgms']

    # Bottleneck distance of each subsample's diagrams to the full ones;
    # None when the dimension exceeds maxdim.
    b0 = _bottleneck_or_none(dgms_full, dgms_sparse, 0)
    b1 = _bottleneck_or_none(dgms_full, dgms_sparse, 1)
    b0p = _bottleneck_or_none(dgms_full, dgms_gd, 0)
    b1p = _bottleneck_or_none(dgms_full, dgms_gd, 1)

    # Plot the three diagram sets side by side.
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    plot_diagrams(dgms_full, ax=axes[0], show=False, title="Rips completo")
    plot_diagrams(dgms_sparse, ax=axes[1], show=False, title=f"Sparse Rips (γ={gamma:.3f})")
    plot_diagrams(dgms_gd, ax=axes[2], show=False, title="Sparse gudhi")

    # Summary text under the plots.
    total_pts = X.shape[0]
    sparse_pts = Xs.shape[0]
    gudhi_pts = fps.shape[0]
    text = (
        f"Nube original: {total_pts} puntos\n"
        f"Submuestra (N_γ): {sparse_pts} puntos\n"
        f"Bottleneck H0 = {_fmt_distance(b0)}   H1 = {_fmt_distance(b1)}  \n   H0p = {_fmt_distance(b0p)}   H1p = {_fmt_distance(b1p)}\n"
        f"Tiempo full = {tiempo_full:.3f}s   Tiempo submuestra = {tiempo_sparse_total:.3f}s tiempo gd ={tiempo_gd:.3f}s "
    )
    fig.text(0.5, 0.02, text, ha='center', va='bottom', fontsize=11)
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.show()

    return {
        "gamma": gamma,
        "n_full": total_pts,
        "n_sparse": sparse_pts,
        "bottleneck_H0": b0,
        "bottleneck_H1": b1,
        "tiempo_full": tiempo_full,
        "tiempo_sparse_total": tiempo_sparse_total,
        "n_gudhi": gudhi_pts,
        "bottleneck_H0_gudhi": b0p,
        "bottleneck_H1_gudhi": b1p,
        "tiempo_gudhi": tiempo_gd,
    }

# ===============================
# Example run on a noisy circle
# ===============================
if __name__ == "__main__":
    X = make_noisy_circle(n=1500, radius=1.0, noise=0.15, seed=42)
    radii = farthest_point_sampling(X)[1]

    # Threshold gamma: 75th percentile of the strictly positive insertion radii.
    positive_radii = radii[radii > 0]
    gamma = np.percentile(positive_radii, 75)

    results = compare_rips_sparsified(X, gamma)
    print("\n--- Resultados ---")
    print("\n".join(f"{key}: {value}" for key, value in results.items()))
