import numpy as np
from itertools import combinations
from sympy import Matrix, zeros, eye

def all_faces(simplices):
    """Return the closure of ``simplices`` under taking nonempty faces.

    Every face is a sorted tuple of vertices. The result is ordered first by
    number of vertices (i.e. dimension) and then lexicographically.
    """
    vertices = {vertex for simplex in simplices for vertex in simplex}
    closure = {(vertex,) for vertex in vertices}
    for simplex in simplices:
        # Faces with at least two vertices; singletons were added above.
        for size in range(2, len(simplex) + 1):
            closure.update(tuple(sorted(combo)) for combo in combinations(simplex, size))
    return sorted(closure, key=lambda face: (len(face), face))

def build_by_dim(faces):
    """Group faces by dimension (number of vertices minus one).

    Returns a plain dict mapping each dimension to the list of faces of that
    dimension, keeping the order in which the faces were given. Keys are
    inserted in order of first appearance.
    """
    grouped = {}
    for face in faces:
        dim = len(face) - 1
        bucket = grouped.get(dim)
        if bucket is None:
            bucket = grouped[dim] = []
        bucket.append(face)
    return grouped

def boundary_matrix(by_dim, k, field=0):
    """Return the matrix of the simplicial boundary map d_k: C_k -> C_{k-1}.

    Rows are indexed by ``by_dim[k-1]`` and columns by ``by_dim[k]``, in list
    order. Each simplex is oriented by its sorted vertex order, and

        d_k [v_0, ..., v_k] = sum_i (-1)^i [v_0, ..., v_{i-1}, v_{i+1}, ..., v_k].

    Args:
        by_dim: mapping dimension -> list of simplex tuples.
        k: dimension of the source chains.
        field: 2 reduces entries mod 2 (coefficients in F_2); any other value
            keeps the signed integer entries.

    Returns:
        A sympy Matrix of shape (len(by_dim[k-1]), len(by_dim[k])). When either
        dimension is missing from ``by_dim``, a zero matrix of the
        corresponding (possibly empty) shape is returned.
    """
    if k not in by_dim or (k - 1) not in by_dim:
        return Matrix.zeros(len(by_dim.get(k - 1, [])), len(by_dim.get(k, [])))

    lower_faces = by_dim[k - 1]
    # Hash lookup instead of list.index() avoids a linear scan per face.
    row_of = {face: i for i, face in enumerate(lower_faces)}
    mat = np.zeros((len(lower_faces), len(by_dim[k])), dtype=int)

    for j, simplex in enumerate(by_dim[k]):
        ordered = tuple(sorted(simplex))
        for i, face in enumerate(combinations(ordered, k)):
            # combinations() emits faces in lexicographic order, so the i-th
            # face omits the vertex at position k - i; its sign is (-1)^(k-i).
            mat[row_of[face], j] += (-1) ** (k - i)

    if field == 2:
        # numpy's % on ints returns a non-negative remainder, so -1 -> 1.
        mat %= 2
    return Matrix(mat.tolist())

def gauss_jordan_f2(A):
    """Gauss-Jordan elimination over F_2.

    Args:
        A: sympy Matrix. Entries are reduced mod 2 before elimination, so
            odd integers (e.g. 3 or -1) are treated as 1.

    Returns:
        (R, pivots): R is a new matrix holding the reduced row echelon form of
        A over F_2, with every integer entry in {0, 1}. ``pivots`` is the list
        of pivot column indices, in increasing order. A itself is not modified.
    """
    # Reduce up front: otherwise an odd pivot such as 3, and rows never
    # touched by elimination, would keep entries outside {0, 1}.
    R = A.applyfunc(lambda x: x % 2)
    m, n = R.rows, R.cols
    pivots = []
    row = 0
    for col in range(n):
        if row >= m:
            break
        # Find the first row at or below `row` with a 1 in this column.
        pivot_row = next((i for i in range(row, m) if R[i, col] == 1), None)
        if pivot_row is None:
            continue

        if pivot_row != row:
            R.row_swap(row, pivot_row)

        # Clear the column in every other row (above and below): over F_2,
        # subtracting the pivot row is the same as adding it mod 2.
        for i in range(m):
            if i != row and R[i, col] == 1:
                R[i, :] = (R[i, :] + R[row, :]) % 2

        pivots.append(col)
        row += 1

    return R, pivots

def nullspace_f2(A):
    """Return a basis of the kernel (null space) of A over F_2.

    Uses the F_2 reduced row echelon form of A mod 2. There is one basis
    vector per free (non-pivot) column ``free``. It has a 1 at index ``free``,
    and at each pivot column ``pc`` (pivot row ``r``) it has the reduced entry
    at (r, free). No sign flip is needed because -x == x over F_2.

    Returns:
        list of A.cols x 1 sympy column vectors, ordered by free column index.
        The list is empty when every column is a pivot column.
    """
    reduced, pivot_cols = gauss_jordan_f2(A.applyfunc(lambda x: x % 2))
    ncols = A.cols
    pivot_set = set(pivot_cols)
    basis = []
    for free in (c for c in range(ncols) if c not in pivot_set):
        v = zeros(ncols, 1)
        v[free] = 1
        for r, pc in enumerate(pivot_cols):
            v[pc] = reduced[r, free] % 2
        basis.append(v)
    return basis

def columnspace_f2(A):
    """Return a basis of the column space (image) of A over F_2.

    The basis is formed by the columns of A mod 2 that are pivot columns of
    its F_2 row-reduced form, in increasing column order.
    """
    reduced_input = A.applyfunc(lambda x: x % 2)
    _, pivot_cols = gauss_jordan_f2(reduced_input)
    return [reduced_input[:, c] for c in pivot_cols]

def compute_homology_with_basis(by_dim, k, field=0):
    """Compute representative cycles for a basis of H_k and the Betti number.

    H_k = Z_k / B_k, where Z_k = ker(d_k) and B_k = im(d_{k+1}). The method
    row-reduces the augmented matrix [B | Z]. A Z-column that is a pivot column
    is independent of B_k and of the earlier Z-columns. Those columns
    therefore give cycles whose classes form a basis of H_k.

    Args:
        by_dim: mapping dimension -> list of simplex tuples.
        k: homology degree.
        field: 2 for coefficients in F_2. Any other value uses sympy's exact
            rational linear algebra (coefficients in Q).

    Returns:
        (H, betti): H is an n_k x betti sympy Matrix. Its columns are cycle
        representatives expressed in the basis ``by_dim[k]``. ``betti``
        equals H.cols.
    """
    n_k = len(by_dim.get(k, []))
    over_f2 = field == 2

    # Cycles Z_k = ker(d_k). In degree 0 every chain is a cycle.
    if k == 0:
        Z_basis = [eye(n_k)[:, i] for i in range(n_k)]
    else:
        d_k = boundary_matrix(by_dim, k, field)
        Z_basis = nullspace_f2(d_k) if over_f2 else d_k.nullspace()

    if not Z_basis:
        # No cycles, so no homology in this degree.
        return Matrix.zeros(n_k, 0), 0

    # Boundaries B_k = im(d_{k+1}). This is empty at the top dimension.
    if k + 1 in by_dim:
        d_kp1 = boundary_matrix(by_dim, k + 1, field)
        B_basis = columnspace_f2(d_kp1) if over_f2 else d_kp1.columnspace()
    else:
        B_basis = []

    num_B = len(B_basis)
    M = Matrix.hstack(*B_basis, *Z_basis)
    if over_f2:
        _, pivots = gauss_jordan_f2(M)
    else:
        _, pivots = M.rref()

    # Pivot columns past the B block index the Z vectors independent of B_k.
    homology_basis = [Z_basis[j - num_B] for j in pivots if j >= num_B]

    if homology_basis:
        H = Matrix.hstack(*homology_basis)
    else:
        H = Matrix.zeros(n_k, 0)
    return H, H.cols

def print_homology(by_dim, field=0):
    """Print the Betti number and generator cycles of H_k for each degree k.

    Degrees run from 0 to the largest dimension present in ``by_dim``. Each
    generator is printed as a formal sum of the k-simplices in ``by_dim[k]``
    that have a nonzero coefficient.

    Args:
        by_dim: mapping dimension -> list of simplex tuples.
        field: 2 for coefficients in F_2. Any other value means Q.
    """
    field_name = "F_2" if field == 2 else "Q"
    print(f"\nHomología con coeficientes en {field_name}:")

    if not by_dim:
        # max() over an empty key set would raise ValueError.
        print("  (Complejo vacío: no hay símplices)")
        return

    for k in range(max(by_dim) + 1):
        hom_basis, betti = compute_homology_with_basis(by_dim, k, field)
        print(f"  Dimensión {k}: número de Betti = {betti}")

        if betti == 0:
            print("    (No hay generadores)")
            continue

        print(f"    Base para H_{k}:")
        basis_simplices = by_dim.get(k, [])

        for i in range(hom_basis.cols):
            terms = []
            # The column has one entry per k-simplex, in by_dim[k] order.
            for coef, simplex in zip(hom_basis[:, i], basis_simplices):
                if coef == 0:
                    continue
                label = f"{list(simplex)}"
                # Over F_2 every nonzero coefficient is 1, so no prefix.
                if field == 2 or coef == 1:
                    terms.append(label)
                elif coef == -1:
                    terms.append(f"-{label}")
                else:
                    terms.append(f"{coef}·{label}")

            body = " + ".join(terms) if terms else "0"
            print(f"      Generador {i+1}: {body}")

def main():
    """Interactive entry point for the homology calculator.

    Reads a list of maximal simplices from stdin and prints the completed
    simplicial complex. It then prints its homology over Q and over F_2.
    Empty or whitespace-only input selects the built-in example, a
    triangulation of the projective plane.
    """
    import ast  # only needed to parse the user's input safely

    print("=== Calculadora de Homología Simplicial ===")
    raw = input("Ingrese la lista de símplices maximales en formato Python (ej: [[0,1],[1,2],[2,0]]):\n")
    if not raw.strip():
        print("Usando ejemplo por defecto: plano proyectivo.")
        simplices_max = [[1,2,5],[1,3,5],[3,4,5],[2,5,6],[4,5,6],[2,3,4],[2,3,6],[1,2,4],[1,4,6],[1,3,6]]
    else:
        # SECURITY: eval() would execute arbitrary code typed at the prompt.
        # literal_eval accepts only Python literals (lists, tuples, numbers...),
        # which is all the documented input format requires.
        try:
            parsed = ast.literal_eval(raw)
            simplices_max = [tuple(sorted(s)) for s in parsed]
        except (ValueError, SyntaxError, TypeError) as err:
            print(f"Entrada inválida: {err}")
            return

    faces = all_faces(simplices_max)
    by_dim = build_by_dim(faces)
    if not by_dim:
        # Nothing to compute; the homology routines expect at least one face.
        print("\nEl complejo simplicial está vacío.")
        return

    print("\nComplejo simplicial completado por caras:")
    for k in sorted(by_dim):
        print(f"  {k}-símplices: {by_dim[k]}")

    print_homology(by_dim, field=0)  # Q
    print_homology(by_dim, field=2)  # F_2


if __name__ == "__main__":
    main()