import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser
from persim import plot_diagrams, bottleneck

# ===============================
# 1. Funciones para generar nubes de puntos
# ===============================

def puntos_en_esfera(n=400, radio=1.0):
    """
    Sample n points uniformly (by surface area) on the sphere S^2 of radius `radio`.

    The azimuth is drawn uniformly in [0, 2*pi) and the cosine of the polar
    angle uniformly in [-1, 1]; a uniform z-coordinate is what makes the
    distribution area-uniform on the sphere (Archimedes' hat-box theorem).

    Returns an array of shape (n, 3) with the Cartesian coordinates.
    """
    azimut = np.random.uniform(0, 2*np.pi, n)
    polar = np.arccos(np.random.uniform(-1, 1, n))
    sen_polar = np.sin(polar)
    return np.column_stack((
        radio * sen_polar * np.cos(azimut),
        radio * sen_polar * np.sin(azimut),
        radio * np.cos(polar),
    ))

def puntos_en_toro(n=400, R=1.0, r=0.5):
    """
    Generate n points uniformly distributed (by surface area) on the torus

        ((R + r*cos(v)) cos(u), (R + r*cos(v)) sin(u), r*sin(v)),
        u, v in [0, 2*pi).

    Drawing u and v independently and uniformly is NOT area-uniform: the
    surface element is |r| * |R + r*cos(v)| du dv, so that approach
    over-samples the inner side of the ring. Here u is uniform, while v is
    drawn by rejection sampling with acceptance probability
    |R + r*cos(v)| / (|R| + |r|), giving v a density proportional to the
    area element.

    Parameters
    ----------
    n : int
        Number of points to generate.
    R : float
        Distance from the centre of the torus to the centre of the tube.
    r : float
        Radius of the tube.

    Returns
    -------
    numpy.ndarray of shape (n, 3)
        Cartesian coordinates of the sampled points.
    """
    u = np.random.uniform(0, 2*np.pi, n)

    cota = abs(R) + abs(r)  # upper bound of |R + r*cos(v)|
    if cota == 0:
        # Degenerate torus (every point is the origin): no weighting needed.
        v = np.random.uniform(0, 2*np.pi, n)
    else:
        bloques = []
        faltan = n
        while faltan > 0:
            # Over-draw candidates; for a ring torus (R >= r >= 0) the
            # acceptance rate is R / (R + r), e.g. 2/3 for the defaults.
            candidatos = np.random.uniform(0, 2*np.pi, 2 * faltan)
            prob = np.abs(R + r*np.cos(candidatos)) / cota
            acepta = np.random.uniform(0, 1, candidatos.size) < prob
            aceptados = candidatos[acepta]
            bloques.append(aceptados[:faltan])
            faltan -= min(faltan, aceptados.size)
        v = np.concatenate(bloques) if bloques else np.empty(0)

    x = (R + r*np.cos(v)) * np.cos(u)
    y = (R + r*np.cos(v)) * np.sin(u)
    z = r * np.sin(v)
    return np.stack((x, y, z), axis=1)

# ===============================
# 2. Generate the three samples
# ===============================

# Fixed seed so the diagrams and bottleneck distances are reproducible.
SEMILLA = 0
np.random.seed(SEMILLA)

X1 = puntos_en_esfera(400, 1.0)
X2 = puntos_en_esfera(200, 1.0)
# Torus with R=1, r=0.5: outer diameter 2*(R + r) = 3 and hole diameter
# 2*(R - r) = 1, so it is NOT the same size as the unit sphere (diameter 2).
X3 = puntos_en_toro(400, R=1.0, r=0.5)

# ===============================
# 3. Compute persistent homology
# ===============================

# The Rips filtration is truncated at this scale to keep the H2 computation
# cheap. Classes still alive at UMBRAL are returned by ripser with
# death = inf (presumably including the sphere's 2-dimensional void).
# Raise UMBRAL to observe real death times, at a higher computational cost.
UMBRAL = 1.0
MAXDIM = 2

r1 = ripser(X1, maxdim=MAXDIM, thresh=UMBRAL)
r2 = ripser(X2, maxdim=MAXDIM, thresh=UMBRAL)
r3 = ripser(X3, maxdim=MAXDIM, thresh=UMBRAL)

dgms1, dgms2, dgms3 = r1['dgms'], r2['dgms'], r3['dgms']

# ===============================
# 4. Compute bottleneck distances
# ===============================


def _truncar_infinitos(dgm, umbral=UMBRAL):
    """Return a float copy of `dgm` with infinite death times replaced by `umbral`.

    persim.bottleneck discards (with a warning) points whose death is not
    finite, which would exclude exactly the classes that survive the
    truncated filtration. Capping them at the threshold keeps them in the
    comparison as features that die at `umbral`.
    """
    dgm = np.array(dgm, dtype=float)  # np.array copies; ripser's output is untouched
    if dgm.size:
        dgm[np.isinf(dgm[:, 1]), 1] = umbral
    return dgm


def _distancias_bottleneck(dgms_a, dgms_b):
    """Bottleneck distance per homology dimension between two diagram lists."""
    return [
        bottleneck(_truncar_infinitos(a), _truncar_infinitos(b))
        for a, b in zip(dgms_a, dgms_b)
    ]


def _formatear(dists):
    """Format per-dimension distances as 'H0=...  H1=...  H2=...'."""
    return "  ".join(f"H{dim}={d:.4f}" for dim, d in enumerate(dists))


dist_12 = _distancias_bottleneck(dgms1, dgms2)
dist_13 = _distancias_bottleneck(dgms1, dgms3)
dist_23 = _distancias_bottleneck(dgms2, dgms3)

# ===============================
# 5. Plot the diagrams
# ===============================

fig, axes = plt.subplots(1, 3, figsize=(14, 5))
titles = ['Muestra 1 (Esfera)', 'Muestra 2 (Esfera)', 'Muestra 3 (Toro)']

for ax, dgms, title in zip(axes, [dgms1, dgms2, dgms3], titles):
    plot_diagrams(dgms, show=False, ax=ax)
    ax.set_title(title)

# Reserve the bottom quarter of the figure for the distance summary.
plt.tight_layout(rect=[0, 0.25, 1, 1])

# ===============================
# 6. Show the distances below the plots
# ===============================

texto = (
    "Distancias Bottleneck (por dimensión)\n"
    f"(muertes infinitas truncadas en thresh={UMBRAL})\n\n"
    f"Entre Esfera1 y Esfera2: {_formatear(dist_12)}\n"
    f"Entre Esfera1 y Toro:    {_formatear(dist_13)}\n"
    f"Entre Esfera2 y Toro:    {_formatear(dist_23)}"
)
# Anchor the text block by its bottom edge so it grows upward into the
# reserved strip instead of being clipped at the figure's lower border.
plt.figtext(0.5, 0.02, texto, ha='center', va='bottom', fontsize=10, family='monospace')

plt.show()
