import numpy as np
import kmapper as km
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 1. Generar 5000 puntos aleatorios en la esfera S^2
np.random.seed(42)
n_points = 5000

# Método: generar puntos en 3D y normalizar
points = np.random.randn(n_points, 3)
# Normalizar para que estén en la esfera unitaria
points = points / np.linalg.norm(points, axis=1)[:, np.newaxis]

# 2. Definir función de filtro: usaremos la coordenada z (altura)
def filter_function(data):
    """Proyección en el eje z"""
    return data[:, 2].reshape(-1, 1)

# Aplicar la función de filtro
lens = filter_function(points)
print(f"\nValores del filtro - min: {lens.min():.3f}, max: {lens.max():.3f}")

# 3. Visualizar los datos originales coloreados por el filtro
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2], 
                     c=lens.flatten(), cmap='viridis', s=10)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('Esfera S² coloreada por coordenada z')
plt.colorbar(scatter, label='Valor del filtro (z)')
plt.tight_layout()
plt.savefig('sphere_original.png', dpi=150, bbox_inches='tight')
print("\nGráfico de la esfera guardado como 'sphere_original.png'")

# 4. Crear el mapper
mapper = km.KeplerMapper(verbose=1)

# 5. Construir el grafo usando KMapper
# - cover: define cómo cubrir el rango del filtro
#   n_cubes: número de intervalos
#   perc_overlap: porcentaje de solapamiento entre intervalos
# - clusterer: algoritmo de clustering (usamos DBSCAN, les recuerdo que pueden usar KMEANS,
# busquen otros métodos en manual mapper )
graph = mapper.map(
    lens,
    points,
    cover=km.Cover(n_cubes=15, perc_overlap=0.3),
    clusterer=DBSCAN(eps=0.3, min_samples=5)
)

# 6. Visualizar el grafo
# El color de cada nodo será el promedio de los valores del filtro
# de los puntos que contiene (usamos la función del filtro para colorear)
mapper.visualize(
    graph,
    path_html="mapper_graph.html",
    title="Mapper de S² usando coordenada z",
    color_function=lens.flatten(),  # Colorear por valores del filtro
    color_function_name='Coordenada z'
)

print("\nGráfico  guardado como 'mapper_graph.html'")
