import numpy as np
import kmapper as km
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 1. Generate 2000 random points on the sphere S^2.
# Isotropic Gaussian samples rescaled to unit length land on the unit sphere
# (and are uniformly distributed on it, by rotational symmetry).
np.random.seed(42)  # fixed seed so every run produces the same point cloud
n_points = 2000

points = np.random.randn(n_points, 3)
row_norms = np.linalg.norm(points, axis=1, keepdims=True)  # shape (n_points, 1)
points = points / row_norms

print(f"Puntos generados: {points.shape}")

# ===== Filter 1: z coordinate =====
separator = "=" * 60
print("\n" + separator)
print("EXPERIMENTO 1: Filtro = coordenada z")
print(separator)

def filter_z(data):
    """Mapper lens: project every point onto its z coordinate.

    Parameters
    ----------
    data : ndarray of shape (n_samples, n_features), n_features >= 3
        Point cloud; column 2 is read as the z coordinate.

    Returns
    -------
    ndarray of shape (n_samples, 1)
        The z values as a single-column (one lens dimension) array.
    """
    z_column = data[:, 2]
    return np.reshape(z_column, (-1, 1))

# Lens values: one z coordinate per sphere point.
lens_z = filter_z(points)

# One KeplerMapper instance, reused by every experiment in this script.
mapper = km.KeplerMapper(verbose=0)

# Cover the z-range with 15 overlapping intervals (30% overlap) and let
# DBSCAN cluster the points falling into each interval.
cover_z = km.Cover(n_cubes=15, perc_overlap=0.3)
clusterer_z = DBSCAN(eps=0.3, min_samples=5)
graph_z = mapper.map(lens_z, points, cover=cover_z, clusterer=clusterer_z)

n_nodes_z = len(graph_z['nodes'])
n_links_z = len(graph_z['links'])
print(f"Nodos: {n_nodes_z}, Enlaces: {n_links_z}")

# Export an interactive HTML view coloured by the z coordinate.
mapper.visualize(
    graph_z,
    color_function=lens_z.flatten(),
    color_function_name='Coordenada z',
    path_html="mapper_coordz.html",
    title="Filtro: coordenada z ",
)
print("→ Grafo guardado en: mapper_coordz.html")


# ===== Filter 2: (cos θ, sin θ), the azimuth around the z axis =====
banner = "=" * 60
print("\n" + banner)
print("EXPERIMENTO 2: Filtro = (cos θ, sin θ) ")
print(banner)

def filter_circle_embedding(data):
    """Mapper lens: embed each point's azimuth on the unit circle.

    The azimuth θ = atan2(y, x) is returned as (cos θ, sin θ), so the lens
    lives on a circle instead of an interval and has no jump at θ = ±π.
    Going through atan2, rather than dividing (x, y) by its length, keeps
    the result defined on the z axis: atan2(0, 0) = 0 maps to (1, 0).

    Parameters
    ----------
    data : ndarray of shape (n_samples, n_features), n_features >= 2
        Point cloud; columns 0 and 1 are read as x and y.

    Returns
    -------
    ndarray of shape (n_samples, 2)
        Column 0 is cos θ, column 1 is sin θ.
    """
    azimuth = np.arctan2(data[:, 1], data[:, 0])
    return np.stack((np.cos(azimuth), np.sin(azimuth)), axis=1)

# Lens values: (cos θ, sin θ) for every sphere point.
lens_circle = filter_circle_embedding(points)

# The lens is 2-D, so the cover is a grid of overlapping squares
# (15 per lens dimension, 50% overlap).
cover_circle = km.Cover(n_cubes=15, perc_overlap=0.5)
clusterer_circle = DBSCAN(eps=0.3, min_samples=5)
graph_circle = mapper.map(
    lens_circle,
    points,
    cover=cover_circle,
    clusterer=clusterer_circle,
)

n_nodes_circle = len(graph_circle['nodes'])
n_links_circle = len(graph_circle['links'])
print(f"Nodos: {n_nodes_circle}, Enlaces: {n_links_circle}")

# Colour by the raw angle θ in [-π, π] rather than by the 2-D lens.
# NOTE(review): θ wraps at ±π. If kmapper colours a node by averaging its
# members' values, a node straddling the seam gets a misleading mid-range
# colour. Confirm the node colouring behaviour.
theta_for_color = np.arctan2(points[:, 1], points[:, 0])
mapper.visualize(
    graph_circle,
    color_function=theta_for_color,
    color_function_name='Ángulo θ',
    path_html="mapper_angulo.html",
    title="Filtro: (cos θ, sin θ) ",
)
print("→ Grafo guardado en: mapper_angulo.html")




# ===== Experiment 3: torus =====
print("\n" + "=" * 60)
print("EXPERIMENTO 3: TORO ")
print("=" * 60)

# Torus parametrisation:
#   x = (R + r cos v) cos u,  y = (R + r cos v) sin u,  z = r sin v
R, r = 2, 1  # major radius, minor (tube) radius
n_toro = 2000

# u: angle around the z axis; v: angle around the tube.
# NOTE(review): drawing (u, v) uniformly is not area-uniform on the surface.
# The area element grows with R + r cos v, so points are denser on the inner
# side of the torus (cos v < 0).
u = np.random.uniform(0, 2*np.pi, n_toro)
v = np.random.uniform(0, 2*np.pi, n_toro)

cos_u = np.cos(u)
sin_u = np.sin(u)
axis_distance = R + r*np.cos(v)  # distance of each point from the z axis
toro = np.column_stack([
    axis_distance * cos_u,
    axis_distance * sin_u,
    r * np.sin(v),
])

print(f"Puntos en el toro: {toro.shape}")

# Lens: (cos u, sin u), i.e. the angle around the z axis embedded on a circle.
lens_toro = np.column_stack([cos_u, sin_u])

cover_toro = km.Cover(n_cubes=15, perc_overlap=0.5)
clusterer_toro = DBSCAN(eps=0.5, min_samples=5)
graph_toro = mapper.map(lens_toro, toro, cover=cover_toro, clusterer=clusterer_toro)

print(f"Nodos: {len(graph_toro['nodes'])}, Enlaces: {len(graph_toro['links'])}")

# NOTE(review): u lies in [0, 2π) and wraps at 0. Nodes straddling that
# seam may get a misleading colour if node colours are averaged.
mapper.visualize(
    graph_toro,
    color_function=u,
    color_function_name='Ángulo u',
    path_html="mapper_toro.html",
    title="Filtro: ángulo u del toro",
)
print("→ Grafo guardado en: mapper_toro.html")

# ===== VISUALIZATION =====
# Three side-by-side 3-D scatter plots, each coloured by the quantity its
# Mapper lens was built from. A 1 x 3 grid is used: the previous 2 x 2 grid
# only filled cells 1, 2 and 4, leaving a blank quadrant in the saved image.
fig = plt.figure(figsize=(18, 6))

# Panel 1: sphere coloured by the z filter
ax1 = fig.add_subplot(131, projection='3d')
scatter1 = ax1.scatter(points[:, 0], points[:, 1], points[:, 2],
                       c=lens_z.flatten(), cmap='viridis', s=3, alpha=0.6)
ax1.set_title('Esfera: filtro z → intervalo')
ax1.set_xlabel('X'); ax1.set_ylabel('Y'); ax1.set_zlabel('Z')
plt.colorbar(scatter1, ax=ax1, shrink=0.5)

# Panel 2: sphere coloured by the azimuth θ. The cyclic 'hsv' colormap
# matches the wrap-around of θ at ±π.
ax2 = fig.add_subplot(132, projection='3d')
theta_color = np.arctan2(points[:, 1], points[:, 0])
scatter2 = ax2.scatter(points[:, 0], points[:, 1], points[:, 2],
                       c=theta_color, cmap='hsv', s=3, alpha=0.6)
ax2.set_title('Esfera: filtro (cos θ, sin θ) → ciclo')
ax2.set_xlabel('X'); ax2.set_ylabel('Y'); ax2.set_zlabel('Z')
plt.colorbar(scatter2, ax=ax2, shrink=0.5)

# Panel 3: torus coloured by the angle u (also cyclic, so 'hsv' again).
# The ax4/scatter4 names are kept unchanged for any code that refers to them.
ax4 = fig.add_subplot(133, projection='3d')
scatter4 = ax4.scatter(toro[:, 0], toro[:, 1], toro[:, 2],
                       c=u, cmap='hsv', s=3, alpha=0.6)
ax4.set_title('Toro: filtro u → ciclo perfecto')
ax4.set_xlabel('X'); ax4.set_ylabel('Y'); ax4.set_zlabel('Z')
plt.colorbar(scatter4, ax=ax4, shrink=0.5)

plt.tight_layout()
plt.savefig('comparacion_completa.png', dpi=150, bbox_inches='tight')
print("\n→ Visualización guardada en: comparacion_completa.png")
