
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# ============================
# Construicción de la matriz D
# ============================
def build_boundary_matrix(st):
    simplices = list(st.get_filtration())
    simplices.sort(key=lambda x: (x[1], len(x[0])))  
    # orden de filtración y después x dimensión (esto asegura que 
    #sigma esté antes que tau si b(sigma)< b(tau) o b(sigma)=b(tau) y sigma es cara de tau)

    N = len(simplices)
    simp_to_idx = {tuple(s[0]): i for i, s in enumerate(simplices)}
    D = np.zeros((N, N), dtype=int)
    
    #hacemos el calculo con coeficientes en Q, por eso ponemos los signos y no reducimos modulo p

    def boundary(simplex):
        faces = []
        for i in range(len(simplex)):
            face = simplex[:i] + simplex[i+1:] #esta es la cara i-esima (que se obtiene sacando el vértice i)
            sign = (-1) ** i
            faces.append((tuple(face), sign))
        return faces

    for j, (sigma, _) in enumerate(simplices):
        sigma = tuple(sigma)
        if len(sigma) > 1:
            for face, sign in boundary(list(sigma)):
                i = simp_to_idx[face]
                D[i, j] = sign

    return D, simplices

# ============================
# Reducción de columnas (cada p tiene que ser el low(q) a lo sumo para una columna q), las columnas con ceros tiene low=-1
# las filas y columnas en python quedan numeradas a partir de 0, así que ponemos low=-1 si son todos ceros y no hay pivote
# ============================
def reduce_matrix(D):
    def low(col):
        rows = np.where(col != 0)[0]
        return rows[-1] if len(rows) else -1

    Dred = D.copy()
    N = D.shape[1]

    for j in range(N):
        while True:
            lows = [low(Dred[:, r]) for r in range(j)]
            lj = low(Dred[:, j])
            if lj in lows and lj != -1:
                r = [r for r in range(j) if low(Dred[:, r]) == lj][0]
                factor = Dred[lj, j] // Dred[lj, r]
                Dred[:, j] -= factor * Dred[:, r]
            else:
                break
    return Dred

# ============================
# Extraer barcodes
# ============================
def extract_barcodes(Dred, simplices):
    def low(col):
        rows = np.where(col != 0)[0]
        return rows[-1] if len(rows) else -1

    barcodes = defaultdict(list)
    N = Dred.shape[1]

    low_of = {}
    for j in range(N):
        lj = low(Dred[:, j])
        if lj != -1:
            low_of[j] = lj

    # Barras finitas
    for j, lj in low_of.items():
        birth_simplex, birth_filtration = simplices[lj]
        death_simplex, death_filtration = simplices[j]
        dim = len(birth_simplex) - 1
        barcodes[dim].append((birth_filtration, death_filtration))

    # Barras infinitas
    pivots = set(low_of.values())
    for j in range(N):
        if low(Dred[:, j]) == -1 and j not in pivots:
            simplex, filt = simplices[j]
            dim = len(simplex) - 1
            barcodes[dim].append((filt, np.inf))

    return barcodes

# ============================
# Gráfico de barcodes 
# ============================
def plot_barcodes(barcodes, simplices):
    colors = {0: "red", 1: "blue", 2: "green", 3: "purple"}
    fig, ax = plt.subplots(figsize=(8, 5))

    max_filt = int(max(f for _, f in simplices))
    y_offset = 0
    block_gap = 2  # separación extra entre bloques

    for dim, intervals in sorted(barcodes.items()):
        # ordenar barras de más larga a más corta
        intervals_sorted = sorted(intervals, key=lambda x: (max_filt+2 if x[1]==np.inf else x[1]) - x[0], reverse=True)
        for (b, d) in intervals_sorted:
            end = d if d != np.inf else max_filt + 1
            ax.hlines(y_offset, b, end,
                      colors=colors.get(dim, "black"),
                      lw=2)
            y_offset += 1
        # Etiqueta de dimensión al costado izquierdo
        if intervals_sorted:
            mid_y = y_offset - len(intervals_sorted)/2
            ax.text(-0.8, mid_y, f"H{dim}", va="center", ha="right", 
                    fontsize=12, color=colors.get(dim,"black"), weight="bold")
        y_offset += block_gap

       # configurar ticks eje x
    xticks = list(range(0, max_filt+1)) + [max_filt+1]
    xlabels = [str(v) for v in range(0, max_filt+1)] + ["∞"]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)

    # sacar eje y
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.set_ylabel("")

    ax.set_xlabel("Valor filtración")
    ax.set_title("Códigos de barra")

    # sin recuadro de leyenda
    ax.legend().remove() if ax.get_legend() else None


    plt.tight_layout()
    plt.show()

# =======================================
# codigo para crear filtracion sin usar gudhi (si se inserta un simplex y no está alguna de sus caras, se insertan automáticamente al mismo tiempo que el simplex)
# ===========================================

class Filtration:
    def __init__(self):
        # Guardamos simplices como {tuple(simplex): filtracion}
        self.simplices = {}

    def insert(self, simplex, filtration):
        """
        Inserta el simplex y todas sus caras si no existen.
        No duplica los ya insertados.
        """
        simplex = tuple(sorted(simplex))  # ordenamos para consistencia

        # Si ya existe, actualizar filtración solo si es menor
        if simplex in self.simplices:
            self.simplices[simplex] = min(self.simplices[simplex], filtration)
        else:
            self.simplices[simplex] = filtration

        # Insertar recursivamente todas las caras
        if len(simplex) > 1:
            for i in range(len(simplex)):
                face = simplex[:i] + simplex[i+1:]
                self.insert(face, filtration)  # filtración de la cara <= filtración del simplex

    def get_filtration(self):
        # Retorna lista de (simplex, filtración), convertimos a lista de listas
        return [(list(s), f) for s, f in self.simplices.items()]




# ============================
# EJEMPLO DE USO (abajo se puede cambiar la filtracion por otra)
# ============================
if __name__ == "__main__":
    
    #creamos la filtración usando Filtration(), insertando cada simplex de la filtración
    #al insertar un simplex, inserta automáticamente en ese momento las caras que no estén todavía insertadas

    st = Filtration()

    # Nivel 0 (K0)
    st.insert([0], filtration=0)
    st.insert([1], filtration=0)
    st.insert([2], filtration=0)

    # K1
    st.insert([0,1], filtration=1)
    st.insert([0,2], filtration=1)
    st.insert([1,2], filtration=1)
    st.insert([3], filtration=1)

    # K2
    st.insert([1,3], filtration=2)
    st.insert([2,3], filtration=2)
    
    # K3
    st.insert([1,2,3], filtration=3)

    # K4
    st.insert([1,2,4], filtration=4)
    st.insert([2,3,4], filtration=4)
    st.insert([1,3,4], filtration=4)

    D, simplices = build_boundary_matrix(st)
    Dred = reduce_matrix(D)
    barcodes = extract_barcodes(Dred, simplices)
    plot_barcodes(barcodes, simplices)

