import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import lil_matrix
from scipy.spatial.distance import pdist, squareform
from collections import defaultdict
from itertools import combinations
import time
import gudhi


# Render every piece of figure text through LaTeX (equivalent to setting
# rcParams['text.usetex'] = True); this requires a working LaTeX installation.
plt.rc('text', usetex=True)


def extract_complex_at_filtration(simplex_tree, filt_value, max_dim):
    """Collect the sub-complex present at a given filtration value.

    Parameters
    ----------
    simplex_tree : object
        Anything exposing ``get_filtration()`` that yields
        ``(vertex_list, filtration_value)`` pairs (e.g. a gudhi SimplexTree).
    filt_value : float
        Keep simplices whose filtration value is ``<= filt_value``.
    max_dim : int
        Keep only simplices of dimension ``<= max_dim``.

    Returns
    -------
    collections.defaultdict(list)
        Maps dimension -> list of simplices as sorted vertex tuples, in the
        order they were yielded by ``get_filtration()``.
    """
    by_dimension = defaultdict(list)
    for vertex_list, value in simplex_tree.get_filtration():
        dimension = len(vertex_list) - 1
        if value <= filt_value and dimension <= max_dim:
            by_dimension[dimension].append(tuple(sorted(vertex_list)))
    return by_dimension

def count_simplices(complex_data):
    """Return the total number of simplices over all dimensions of ``complex_data``."""
    total = 0
    for simplices in complex_data.values():
        total += len(simplices)
    return total


def compute_graph_core(edge_list, n_vertices):
    """Return the vertices that survive iterated strong collapses of a graph.

    A surviving vertex ``v`` is removed when some other surviving neighbour
    ``w`` dominates it: the closed neighbourhood of ``v`` restricted to the
    survivors is contained in the closed neighbourhood of ``w``.  Vertices
    are visited in order of increasing (surviving) degree, and full passes
    are repeated until one removes nothing.

    Parameters
    ----------
    edge_list : iterable of (int, int)
        Undirected edges given as vertex indices in ``range(n_vertices)``.
    n_vertices : int
        Number of vertices.

    Returns
    -------
    list
        Indices of the surviving vertices, in increasing order.
    """
    # Closed adjacency: every vertex is its own neighbour.
    closed_adj = np.zeros((n_vertices, n_vertices), dtype=bool)
    for a, b in edge_list:
        closed_adj[a, b] = closed_adj[b, a] = True
    np.fill_diagonal(closed_adj, True)

    survivors = np.ones(n_vertices, dtype=bool)
    changed = True
    while changed:
        changed = False
        live = np.flatnonzero(survivors)
        degrees = closed_adj[np.ix_(live, live)].sum(axis=1)

        for v in live[np.argsort(degrees)]:
            if not survivors[v]:
                continue
            nbhd = closed_adj[v] & survivors
            # An isolated vertex (only itself) can never be dominated.
            if np.count_nonzero(nbhd) <= 1:
                continue
            for w in np.flatnonzero(nbhd):
                if w == v:
                    continue
                # nbhd is a subset of N[w] iff no member of nbhd is missing from N[w].
                if not np.any(nbhd & ~closed_adj[w]):
                    survivors[v] = False
                    changed = True
                    break

    return list(np.flatnonzero(survivors))

def clique_complex(graph_edges, vertices, max_dim):
    """Build the clique (flag) complex of a graph, truncated at ``max_dim``.

    Every vertex becomes a 0-simplex and every edge a 1-simplex.  For
    ``2 <= k <= max_dim``, every set of ``k + 1`` pairwise-adjacent vertices
    becomes a k-simplex.  Higher simplices are generated as faces of the
    maximal cliques found by Bron-Kerbosch with pivoting.

    Parameters
    ----------
    graph_edges : iterable of (vertex, vertex)
        Undirected edges.  Endpoints are expected to appear in ``vertices``;
        cliques through vertices missing from ``vertices`` are not found.
    vertices : iterable
        Vertex labels (hashable and mutually orderable).
    max_dim : int
        Highest simplex dimension to include.

    Returns
    -------
    dict[int, list[tuple]]
        Maps each non-empty dimension to its simplices as sorted tuples.
        Dimensions with no simplices are absent.  List order is unspecified.
    """
    # Materialise both inputs: each is traversed more than once below, and a
    # one-shot iterator would otherwise be silently exhausted after the first pass.
    edges = list(graph_edges)
    vertices = list(vertices)

    complex_data = defaultdict(set)
    for v in vertices:
        complex_data[0].add((v,))

    def _as_lists():
        return {dim: list(s) for dim, s in complex_data.items()}

    if max_dim == 0:
        return _as_lists()

    for i, j in edges:
        complex_data[1].add(tuple(sorted((i, j))))

    if max_dim == 1:
        return _as_lists()

    adj = defaultdict(set)
    for i, j in edges:
        adj[i].add(j)
        adj[j].add(i)

    maximal_cliques = []

    def bron_kerbosch(R, P, X):
        """Enumerate maximal cliques: R = current clique, P = candidates, X = excluded."""
        if not P and not X:
            if len(R) >= 2:
                maximal_cliques.append(sorted(R))
            return
        # P | X is non-empty here, so a pivot always exists.  Use it
        # unconditionally: a truthiness test would mistake vertex 0 for
        # "no pivot" and branch on every candidate.
        pivot = max(P | X, key=lambda u: len(P & adj[u]))
        for v in list(P - adj[pivot]):
            bron_kerbosch(R | {v}, P & adj[v], X & adj[v])
            P.remove(v)
            X.add(v)

    bron_kerbosch(set(), set(vertices), set())

    # Every k-simplex (k >= 1) is a (k+1)-subset of some maximal clique.
    for clique in maximal_cliques:
        for size in range(2, min(len(clique), max_dim + 1) + 1):
            # ``clique`` is sorted, so combinations() yields sorted tuples.
            for face in combinations(clique, size):
                complex_data[size - 1].add(face)

    return _as_lists()

def core_flag_complex(complex_data, max_dim):
    """Reduce a flag complex to its core and rebuild it as a clique complex.

    Only the 0- and 1-simplices of ``complex_data`` are read.  Dominated
    vertices are removed by ``compute_graph_core``, and the clique complex
    of the surviving induced subgraph is rebuilt up to ``max_dim``.

    Parameters
    ----------
    complex_data : mapping int -> list of tuple
        Simplices per dimension (vertices as 1-tuples, edges as 2-tuples).
    max_dim : int
        Highest dimension of the rebuilt complex.

    Returns
    -------
    dict[int, list[tuple]]
        The core complex, or ``{0: [], ..., max_dim: []}`` when there are
        no vertices.
    """
    labels = [simplex[0] for simplex in complex_data.get(0, [])]
    edges = complex_data.get(1, [])
    if not labels:
        return {d: [] for d in range(max_dim + 1)}

    # compute_graph_core works on dense indices 0..n-1, so translate labels.
    index_of = {label: i for i, label in enumerate(sorted(labels))}
    label_of = {i: label for label, i in index_of.items()}
    indexed_edges = [(index_of[min(e)], index_of[max(e)]) for e in edges]

    kept = [label_of[i] for i in compute_graph_core(indexed_edges, len(labels))]
    kept_set = set(kept)
    kept_edges = [e for e in edges if e[0] in kept_set and e[1] in kept_set]
    return clique_complex(kept_edges, kept, max_dim)


def build_boundary_matrix(complex_data, dim):
    """Build the mod-2 boundary matrix of dimension ``dim``.

    Rows index the (dim-1)-simplices and columns the dim-simplices, both in
    the order they appear in ``complex_data``.  Entry (r, c) is 1 when
    simplex r is a codimension-1 face of simplex c.  Faces not listed in
    ``complex_data[dim - 1]`` are ignored.

    Parameters
    ----------
    complex_data : mapping int -> list of tuple
        Simplices per dimension as sorted vertex tuples.
    dim : int
        Dimension of the column simplices.

    Returns
    -------
    scipy.sparse.csr_matrix of int8, or None
        ``None`` for ``dim == 0``, a missing dimension, or an empty side.
    """
    if dim == 0 or dim not in complex_data or (dim - 1) not in complex_data:
        return None

    faces = complex_data[dim - 1]
    cells = complex_data[dim]
    if len(faces) == 0 or len(cells) == 0:
        return None

    row_of = {face: r for r, face in enumerate(faces)}
    matrix = lil_matrix((len(faces), len(cells)), dtype=np.int8)
    for c, cell in enumerate(cells):
        # Dropping one vertex at a time; order is preserved, so faces stay sorted.
        for face in combinations(cell, len(cell) - 1):
            r = row_of.get(face)
            if r is not None:
                matrix[r, c] = 1
    return matrix.tocsr()

def rank_mod2(matrix):
    """Return the rank of ``matrix`` over the field F_2 = Z/2Z.

    Parameters
    ----------
    matrix : scipy.sparse matrix, numpy.ndarray, or None
        Integer matrix; entries are reduced modulo 2 before elimination.
        ``None`` (as returned by ``build_boundary_matrix`` for an empty map)
        is treated as the zero map.

    Returns
    -------
    int
        Rank over F_2; 0 for ``None`` or a matrix with an empty dimension.
    """
    if matrix is None or matrix.shape[0] == 0 or matrix.shape[1] == 0:
        return 0

    # Accept both sparse matrices and dense arrays; always work on a copy.
    dense = matrix.toarray() if hasattr(matrix, "toarray") else np.array(matrix)
    M = (np.asarray(dense) % 2).astype(np.uint8)
    n_rows, n_cols = M.shape

    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        candidates = np.flatnonzero(M[rank:, col])
        if candidates.size == 0:
            continue
        pivot = rank + candidates[0]
        if pivot != rank:
            M[[rank, pivot]] = M[[pivot, rank]]
        # Row-echelon form suffices for the rank, so only rows below the
        # pivot are cleared, in a single vectorised XOR.
        below = rank + 1 + np.flatnonzero(M[rank + 1:, col])
        if below.size:
            M[below] ^= M[rank]
        rank += 1
    return rank

def compute_betti_numbers(complex_data, max_dim):
    """Return the mod-2 Betti numbers ``[beta_0, ..., beta_{max_dim-1}]``.

    Uses beta_k = dim ker(d_k) - rank(d_{k+1}) = n_k - rank(d_k) - rank(d_{k+1}).
    Here n_k is the number of k-simplices, d_k is the matrix from
    ``build_boundary_matrix``, and d_0 is the zero map.  Ranks are taken
    over F_2 with ``rank_mod2``.

    Parameters
    ----------
    complex_data : mapping int -> list of tuple
        Simplices per dimension as sorted vertex tuples.
    max_dim : int
        Number of Betti numbers to compute.  beta_{max_dim-1} needs the
        simplices of dimension ``max_dim``.

    Returns
    -------
    list[int]
        Betti numbers for dimensions ``0 .. max_dim - 1``.
    """
    # rank(d_k) enters both beta_{k-1} and beta_k; compute each rank once.
    boundary_rank = {
        k: rank_mod2(build_boundary_matrix(complex_data, k))
        for k in range(1, max_dim + 1)
    }

    betti = []
    for k in range(max_dim):
        n_k = len(complex_data.get(k, []))
        dim_ker_k = n_k - boundary_rank.get(k, 0)  # d_0 = 0: every vertex is a cycle
        betti.append(int(dim_ker_k - boundary_rank[k + 1]))
    return betti


def compute_betti_and_cores(simplex_tree, max_dim, n_samples):
    """Compute mod-2 Betti numbers of the cores of evenly spaced sub-complexes.

    The filtration range of ``simplex_tree`` is sampled at ``n_samples``
    equally spaced values.  At each value, the sub-complex of simplices with
    filtration <= that value (up to dimension ``max_dim``) is reduced to its
    core.  Betti numbers beta_0 .. beta_{max_dim-1} of that core are then
    computed.  Progress is printed to stdout.

    Parameters
    ----------
    simplex_tree : object
        A gudhi SimplexTree, or anything providing ``compute_persistence()``
        and ``get_filtration()``.
    max_dim : int
        Maximum simplex dimension kept; also the number of Betti numbers.
    n_samples : int
        Number of sampled filtration values.

    Returns
    -------
    tuple
        ``(filtration_values, betti_numbers_core, simplex_counts,
        computation_time)``:

        - ``filtration_values``: ndarray of the sampled values.
        - ``betti_numbers_core``: dict mapping dim -> ndarray of Betti
          numbers per sample.
        - ``simplex_counts``: list of ``(n_simplices_K_i, n_simplices_core)``
          pairs.
        - ``computation_time``: seconds spent in the sampling loop.

    Raises
    ------
    ValueError
        If ``simplex_tree`` yields no simplices.
    """
    print("Calculando homología persistente...")
    simplex_tree.compute_persistence()

    filtration_values_all = np.unique([filt for _, filt in simplex_tree.get_filtration()])
    if filtration_values_all.size == 0:
        # Without this guard the indexing below fails with an opaque IndexError.
        raise ValueError("El simplex tree está vacío: no hay valores de filtración.")
    epsilon_min = filtration_values_all[0]
    epsilon_max = filtration_values_all[-1]

    filtration_values = np.linspace(epsilon_min, epsilon_max, n_samples)

    print(f"Rango de filtración: [{epsilon_min:.6f}, {epsilon_max:.6f}]\n")

    betti_numbers_core = {dim: [] for dim in range(max_dim)}
    simplex_counts = []

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    start_time = time.perf_counter()

    for idx, filt_val in enumerate(filtration_values):
        print(f"Muestra {idx + 1}/{len(filtration_values)}: filtración = {filt_val:.6f}")

        complex_i = extract_complex_at_filtration(simplex_tree, filt_val, max_dim)
        n_simplices_original = count_simplices(complex_i)

        core_i = core_flag_complex(complex_i, max_dim)
        n_simplices_core = count_simplices(core_i)

        simplex_counts.append((n_simplices_original, n_simplices_core))

        betti_list = compute_betti_numbers(core_i, max_dim)

        betti_str = ", ".join([f"betti_{d}={betti_list[d]}" for d in range(max_dim)])
        print(f"  Betti: {betti_str}")
        print(f"  Símplices: K_i={n_simplices_original}, core={n_simplices_core}\n")

        for dim in range(max_dim):
            betti_numbers_core[dim].append(betti_list[dim])

    computation_time = time.perf_counter() - start_time

    print("="*70)
    print(f" Tiempo: {computation_time:.3f}s")
    print("="*70 + "\n")

    for dim in range(max_dim):
        betti_numbers_core[dim] = np.array(betti_numbers_core[dim])

    return filtration_values, betti_numbers_core, simplex_counts, computation_time


def _plot_persistence_fallback(ax_pers, simplex_tree, colors):
    """Draw a persistence diagram by hand on ``ax_pers`` (used when gudhi's plotter fails)."""
    persistence = simplex_tree.persistence()
    by_dim = defaultdict(list)
    # Largest finite coordinate, births included, so that every point lies
    # inside the axes.  Using deaths only clipped classes born after the last
    # finite death (e.g. late essential classes).
    max_finite = 0

    for dim, (birth, death) in persistence:
        by_dim[dim].append((birth, death))
        if np.isfinite(birth):
            max_finite = max(max_finite, birth)
        if np.isfinite(death):
            max_finite = max(max_finite, death)

    max_val = max_finite * 1.1 if max_finite > 0 else 1.0
    # Infinite deaths are drawn on a horizontal line just above the data.
    y_inf = max_val * 1.15

    for dim in sorted(by_dim.keys()):
        color = colors[dim % len(colors)]
        births, deaths = [], []

        for b, d in by_dim[dim]:
            births.append(b)
            deaths.append(d if np.isfinite(d) else y_inf)

        ax_pers.scatter(births, deaths, alpha=0.7, s=35, label=f'$H_{{{dim}}}$', c=color)

    ax_pers.plot([0, max_val], [0, max_val], 'k--', alpha=0.3, linewidth=1)
    ax_pers.axhline(y=y_inf, color='gray', linestyle=':', linewidth=1.5)
    ax_pers.text(max_val * 0.02, y_inf * 1.01, r"$\infty$", fontsize=11, color='gray')

    ax_pers.set_xlim(0, max_val)
    ax_pers.set_ylim(0, y_inf * 1.08)
    ax_pers.set_xlabel('Nacimiento', fontsize=11)
    ax_pers.set_ylabel('Muerte', fontsize=11)
    ax_pers.set_title('Diagrama de Persistencia', fontsize=12, fontweight='bold')
    ax_pers.legend(loc='lower right', fontsize=10, framealpha=0.9)
    ax_pers.grid(True, alpha=0.3)


def _print_reduction_report(filtration_values, simplex_counts, computation_time):
    """Print per-sample and aggregate simplex reduction achieved by taking cores."""
    print("\n" + "="*70)
    print("REDUCCIÓN DE SÍMPLICES POR EL CORE")
    print("="*70)
    print(f"{'#':<5} {'Filtración':<15} {'K_i':<12} {'core(K_i)':<12} {'Reducción':<12}")
    print("-"*70)

    for idx, (filt_val, (n_orig, n_core)) in enumerate(zip(filtration_values, simplex_counts)):
        reduction = (1 - n_core / n_orig) * 100 if n_orig > 0 else 0
        print(f"{idx+1:<5} {filt_val:<15.6f} {n_orig:<12} {n_core:<12} {reduction:<12.1f}%")

    print("="*70)

    total_orig = sum(n_orig for n_orig, _ in simplex_counts)
    total_core = sum(n_core for _, n_core in simplex_counts)
    # Aggregate reduction over all samples (ratio of totals, i.e. weighted by size).
    avg_reduction = (1 - total_core / total_orig) * 100 if total_orig > 0 else 0

    print(f"\nTotal símplices K_i:       {total_orig}")
    print(f"Total símplices core(K_i): {total_core}")
    print(f"Reducción promedio:        {avg_reduction:.1f}%")
    print(f"Tiempo de cálculo:         {computation_time:.3f}s")
    print("="*70)


def plot_results(simplex_tree, filtration_values, betti_numbers_core, simplex_counts, computation_time):
    """Plot the persistence diagram and core Betti curves, then print a reduction report.

    The persistence diagram occupies the left column.  One Betti curve per
    dimension is laid out two per column on the right.  After the figure is
    shown, a table of simplex counts before and after taking cores is printed.

    Parameters
    ----------
    simplex_tree : object
        Object with a ``persistence()`` method (a gudhi SimplexTree).
    filtration_values : array-like
        Sampled filtration values (x-axis of the Betti curves).
    betti_numbers_core : dict[int, array-like]
        Betti numbers of the cores per dimension, aligned with ``filtration_values``.
    simplex_counts : list of (int, int)
        ``(n_simplices_K_i, n_simplices_core)`` per sample.
    computation_time : float
        Seconds spent computing; shown in the title and the report.
    """
    colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray',
              'cyan', 'magenta', 'yellow', 'olive', 'navy', 'teal']

    max_dim = len(betti_numbers_core)

    # Left column: persistence diagram spanning both rows; right: 2 Betti plots per column.
    n_cols_right = int(np.ceil(max_dim / 2)) if max_dim > 0 else 1
    fig_width = 6 + 2 * n_cols_right
    fig_height = 6
    fig = plt.figure(figsize=(fig_width, fig_height))
    width_ratios = [2] + [1] * n_cols_right
    gs = fig.add_gridspec(2, 1 + n_cols_right, width_ratios=width_ratios)

    ax_pers = fig.add_subplot(gs[:, 0])

    try:
        gudhi.plot_persistence_diagram(simplex_tree.persistence(), axes=ax_pers, legend=True)
        ax_pers.set_title('Diagrama de Persistencia', fontsize=12, fontweight='bold')
    except Exception as e:
        # Deliberate best-effort: fall back to a hand-drawn diagram.
        print(f"Advertencia: {e}")
        print("Usando método alternativo...")
        _plot_persistence_fallback(ax_pers, simplex_tree, colors)

    for dim in range(max_dim):
        col_idx = 1 + (dim // 2)
        row_idx = dim % 2
        ax = fig.add_subplot(gs[row_idx, col_idx])

        color = colors[dim % len(colors)]
        ax.plot(filtration_values, betti_numbers_core[dim], 'o-',
                linewidth=2.5, markersize=5, color=color)

        ax.set_xlabel('Valor de filtración', fontsize=11)
        # LaTeX notation for labels and titles
        ax.set_ylabel(rf'$\beta_{{{dim}}}$', fontsize=12, fontweight='bold')
        ax.set_title(rf'$H_{{{dim}}}$ - Números de Betti $\beta_{{{dim}}}$ (core)', fontsize=11, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_xlim([filtration_values[0], filtration_values[-1]])
        ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))

    fig.suptitle(f'Análisis de Homología Persistente y Core\nTiempo de cálculo: {computation_time:.3f}s',
                 fontsize=13, fontweight='bold', y=0.98)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

    _print_reduction_report(filtration_values, simplex_counts, computation_time)


if __name__ == "__main__":

    # Example: noisy points sampled on the sphere S^2.
    separador = "=" * 70
    print("\n\n" + separador)
    print("EJEMPLO 2: PUNTOS EN LA ESFERA S^2")
    print(separador)

    def puntos_esfera_ruido_dist(n_puntos=100, sigma=0.0, random_state=None):
        """Sample points on the unit sphere, add Gaussian noise, return (points, distances).

        The azimuth is uniform on [0, 2*pi) and the polar angle is arccos of a
        uniform value in [-1, 1], which gives a uniform distribution on S^2.
        Each coordinate then receives N(0, sigma) noise.  The second return
        value is the full Euclidean distance matrix.
        """
        rng = np.random.default_rng(random_state)
        theta = rng.uniform(0, 2 * np.pi, n_puntos)
        phi = np.arccos(rng.uniform(-1, 1, n_puntos))

        sin_phi = np.sin(phi)
        puntos = np.vstack((sin_phi * np.cos(theta),
                            sin_phi * np.sin(theta),
                            np.cos(phi))).T

        puntos_con_ruido = puntos + rng.normal(0, sigma, size=puntos.shape)
        distancias = squareform(pdist(puntos_con_ruido, metric='euclidean'))
        return puntos_con_ruido, distancias

    puntos, D = puntos_esfera_ruido_dist(100, sigma=0.02, random_state=123)

    print("Construyendo filtración con gudhi...")
    rips = gudhi.RipsComplex(distance_matrix=D, max_edge_length=1)
    simplex_tree = rips.create_simplex_tree(max_dimension=3)
    print(f"Complejo: {simplex_tree.num_simplices()} símplices, {simplex_tree.num_vertices()} vértices")

    filt_vals, betti, counts, time_comp = compute_betti_and_cores(simplex_tree, max_dim=3, n_samples=15)
    plot_results(simplex_tree, filt_vals, betti, counts, time_comp)
