"""
Análisis Topológico de Datos de Diabetes usando KMapper

Dataset: 442 pacientes
- 10 Variables (datos en R^10)
- Objetivo: entender la estructura de los datos y cómo se relaciona con diabetes
coordenadas corresponden a:
edad
sexo
masa corporal
presion sanguinea
s1 Colesterol total
s2 LDL ("colesterol malo")
s3 HDL ("colesterol bueno")
s4 Ratio colesterol/HDL
s5 Log de triglicéridos
s6 Glucosa en sangre


Filtro (función a R): Usamos PCA1 (componente principal)

- Captura la dirección de máxima variabilidad

la progresion de la enfermedad (que viene con los datos) no se usa para definir la funcion de filtro sino para interpretar los resultados
vamos a colorear los nodos del grafo obtenido por mapper segun progresion de la enfermedad para observar correlacion entre variables y progreso de enfermedad

"""

import numpy as np
import pandas as pd
import kmapper as km
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

# Cargar datos de diabetes
from sklearn.datasets import load_diabetes

print("="*70)
print("ANÁLISIS TOPOLÓGICO DE DATOS DE DIABETES CON MAPPER")
print("="*70)

# Cargar dataset
data = load_diabetes()
X = data.data  # Features
y = data.target  # Progresión de diabetes (variable continua)

print(f"\nDatos cargados:")
print(f"  - Muestras: {X.shape[0]}")
print(f"  - Variables: {X.shape[1]}")

# Nombres de las variables
feature_names = data.feature_names
print(f"\nVariables disponibles:")
for i, name in enumerate(feature_names):
    print(f"  {i+1}. {name}")

# Crear DataFrame para análisis
df = pd.DataFrame(X, columns=feature_names)
df['progression'] = y

# Estadísticas descriptivas
print("\n" + "-"*70)
print("ESTADÍSTICAS DESCRIPTIVAS")
print("-"*70)
print(df.describe().round(2))

# ===== PREPROCESAMIENTO =====
print("\n" + "="*70)
print("PASO 1: PREPROCESAMIENTO")
print("="*70)

# Estandarizar (importante para PCA y DBSCAN)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# ===== ANÁLISIS PCA =====
print("\n" + "="*70)
print("PASO 2: ANÁLISIS DE COMPONENTES PRINCIPALES (PCA)")
print("="*70)

pca = PCA()
X_pca = pca.fit_transform(X_scaled)

# Varianza explicada
var_exp = pca.explained_variance_ratio_
var_cum = np.cumsum(var_exp)

print(f"\nVarianza explicada por componente:")
for i in range(min(5, len(var_exp))):
    print(f"  PC{i+1}: {var_exp[i]*100:.1f}% (acumulado: {var_cum[i]*100:.1f}%)")

# Gráfico de varianza explicada
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.bar(range(1, len(var_exp)+1), var_exp * 100)
ax1.set_xlabel('Componente Principal')
ax1.set_ylabel('Varianza Explicada (%)')
ax1.set_title('Varianza por Componente')
ax1.set_xticks(range(1, len(var_exp)+1))

ax2.plot(range(1, len(var_cum)+1), var_cum * 100, 'bo-')
ax2.axhline(y=80, color='r', linestyle='--', label='80% varianza')
ax2.set_xlabel('Número de Componentes')
ax2.set_ylabel('Varianza Acumulada (%)')
ax2.set_title('Varianza Acumulada')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('pca_variance.png', dpi=150, bbox_inches='tight')
print("\n✓ Gráfico guardado: pca_variance.png")

# Interpretación del PC1
print(f"\n{'-'*70}")
print("INTERPRETACIÓN DEL COMPONENTE PRINCIPAL 1 (PC1)")
print("-"*70)
loadings = pca.components_[0]
loading_df = pd.DataFrame({
    'Variable': feature_names,
    'Loading': loadings,
    'Abs_Loading': np.abs(loadings)
}).sort_values('Abs_Loading', ascending=False)

print("\nContribuciones al PC1 (ordenadas por importancia):")
print(loading_df.to_string(index=False))

print(f"\n💡 Interpretación:")
top_var = loading_df.iloc[0]['Variable']
print(f"   PC1 está dominado por: {top_var}")

# ===== DEFINIR FILTRO =====
print("\n" + "="*70)
print("PASO 3: DEFINIR FUNCIÓN DE FILTRO")
print("="*70)

# Usaremos PC1 como filtro
lens = X_pca[:, 0].reshape(-1, 1)

print(f"Filtro elegido: Componente Principal 1 (PC1)")
print(f"  - Captura {var_exp[0]*100:.1f}% de la varianza")
print(f"  - Rango: [{lens.min():.2f}, {lens.max():.2f}]")


# ===== CORREMOS MAPPER =====

mapper = km.KeplerMapper(verbose=1)

print("\nParámetros de Mapper:")
print("  - Cover: 10 cubos, 50% overlap")
print("  - Clustering: DBSCAN (eps=1.5, min_samples=2)")


graph = mapper.map(
    lens,
    X_scaled,
    cover=km.Cover(n_cubes=10, perc_overlap=0.5),
    clusterer=DBSCAN(eps=1.5, min_samples=2)
)






# Coloreamos por progresión de diabetes - ese dato viene con la muestra (variable "y"
#  que no se usa para hacer el grafo, solo para colorear e interpretar 
# en cada nodo cuál es la progresión de la enfermedad)
# KMapper automáticamente hace el tamaño de nodo proporcional al número de puntos 

mapper.visualize(
    graph,
    path_html="mapper_diabetes_progression.html",
    title="Análisis Topológico de Diabetes - Color: Progresión de la enfermedad",
    color_function=y,
    color_function_name='Progresión de Diabetes'
)
print("\n Grafo guardado: mapper_diabetes_progression.html")


# ===== ANÁLISIS DE NODOS EXTREMOS =====
print("\n" + "="*70)
print("PASO 6: ANÁLISIS DE GRUPOS EXTREMOS")
print("="*70)

if len(graph['nodes']) == 0:
    print("\n⚠️ No hay nodos para analizar. El grafo está vacío.")
    print("\n💡 Sugerencias para solucionar:")
    print("   1. Aumentar 'eps' en DBSCAN (ej: eps=2.0)")
    print("   2. Reducir 'min_samples' en DBSCAN (ej: min_samples=2)")
    print("   3. Aumentar 'perc_overlap' en Cover (ej: 0.6)")
else:
    # Encontrar nodos con progresión alta y baja
    node_avg_progression = {}
    for node_id, members in graph['nodes'].items():
        node_avg_progression[node_id] = y[members].mean()

    sorted_nodes = sorted(node_avg_progression.items(), key=lambda x: x[1])

    n_show = min(3, len(sorted_nodes))
    
    print(f"\n📊 Top {n_show} nodos con MENOR progresión (más sanos):")
    for i, (node_id, avg_prog) in enumerate(sorted_nodes[:n_show]):
        members = list(graph['nodes'][node_id])
        print(f"\n{i+1}. Nodo {node_id}:")
        print(f"   - Pacientes: {len(members)}")
        print(f"   - Progresión promedio: {avg_prog:.1f}")
        print(f"   - PC1 promedio: {lens[members].mean():.2f}")
        # Características promedio
        avg_features = X_scaled[members].mean(axis=0)
        top_features = np.argsort(np.abs(avg_features))[-3:][::-1]
        print(f"   - Características destacadas:")
        for idx in top_features:
            print(f"     • {feature_names[idx]}: {avg_features[idx]:.2f} (std)")

    print("\n" + "-"*70)
    print(f"\n📊 Top {n_show} nodos con MAYOR progresión (más severos):")
    for i, (node_id, avg_prog) in enumerate(sorted_nodes[-n_show:]):
        members = list(graph['nodes'][node_id])
        print(f"\n{i+1}. Nodo {node_id}:")
        print(f"   - Pacientes: {len(members)}")
        print(f"   - Progresión promedio: {avg_prog:.1f}")
        print(f"   - PC1 promedio: {lens[members].mean():.2f}")
        avg_features = X_scaled[members].mean(axis=0)
        top_features = np.argsort(np.abs(avg_features))[-3:][::-1]
        print(f"   - Características destacadas:")
        for idx in top_features:
            print(f"     • {feature_names[idx]}: {avg_features[idx]:.2f} (std)")

