from Bio.PDB import PDBParser, PDBList
import numpy as np
from ripser import ripser
from persim import plot_diagrams
import matplotlib.pyplot as plt
import os
import gzip
import shutil

# ----------- Download and prepare a PDB file -----------
def get_pdb_file(pdb_id, pdb_dir="pdb_files"):
    """Return the path to a local ``<pdb_id>.pdb`` file, downloading it if needed.

    The identifier is lower-cased and the file is looked up as
    ``<pdb_dir>/<pdb_id>.pdb``. When that file already exists it is returned
    as-is; otherwise the entry is fetched with ``PDBList.retrieve_pdb_file``,
    decompressed if the downloaded name ends in ``.gz``, and moved to the
    standard ``<pdb_id>.pdb`` name.

    Args:
        pdb_id: PDB identifier (case-insensitive), e.g. ``"1c26"``.
        pdb_dir: Directory used as the download target and cache; created
            if it does not exist.

    Returns:
        Path of the ``.pdb`` file inside ``pdb_dir``.

    Raises:
        FileNotFoundError: If the download did not produce a file on disk.
    """
    os.makedirs(pdb_dir, exist_ok=True)
    pdb_id = pdb_id.lower()
    pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")

    # Reuse the local copy instead of downloading again.
    if os.path.isfile(pdb_path):
        print(f"✅ Archivo {pdb_id}.pdb encontrado localmente.")
        return pdb_path

    print(f"⬇️  Descargando {pdb_id} desde el servidor PDB...")
    pdbl = PDBList()
    # Downloads a .ent or .ent.gz file (e.g. pdb1c26.ent.gz) into pdb_dir.
    downloaded = pdbl.retrieve_pdb_file(pdb_id, pdir=pdb_dir, file_format="pdb")

    # NOTE(review): depending on the Biopython version, a failed download may
    # only print a message and still hand back a path (or a falsy value) --
    # confirm against the installed version. Checking explicitly gives a clear
    # error instead of an AttributeError or an opaque shutil.move failure.
    if not downloaded or not os.path.isfile(downloaded):
        raise FileNotFoundError(
            f"No se pudo descargar la estructura {pdb_id} desde el servidor PDB."
        )

    if downloaded.endswith(".gz"):
        # Decompress next to the archive, then drop the archive.
        decompressed_path = downloaded[:-3]  # strip ".gz"
        with gzip.open(downloaded, "rb") as f_in, open(decompressed_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.remove(downloaded)
    else:
        decompressed_path = downloaded

    # Rename to the standard form, e.g. 1c26.pdb.
    shutil.move(decompressed_path, pdb_path)

    print(f"✅ Archivo {pdb_id}.pdb descargado y listo en {pdb_dir}")
    return pdb_path

# ----------- Extract atomic coordinates -----------
def extract_coordinates(pdb_file, only_CA=True):
    """Return the coordinates of the atoms in a PDB file as an ``(N, 3)`` array.

    The file is parsed with ``PDBParser(QUIET=True)`` and every atom yielded
    by ``structure.get_atoms()`` is considered.
    NOTE(review): for multi-model files this presumably includes the atoms of
    every model -- confirm whether only the first model is wanted.

    Args:
        pdb_file: Path to the PDB file to parse.
        only_CA: When true, keep only alpha-carbon atoms (atom name ``"CA"``);
            when false, keep every atom.

    Returns:
        Array with one row per selected atom. An empty selection yields an
        array of shape ``(0, 3)`` so callers can always read ``shape[1]``.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    if only_CA:
        # Calcium ions are also named "CA"; the element distinguishes them
        # from alpha carbons. Only atoms whose element is explicitly calcium
        # are dropped, so atoms without element information are still kept.
        coords = [
            atom.coord
            for atom in structure.get_atoms()
            if atom.get_name() == "CA" and getattr(atom, "element", None) != "CA"
        ]
    else:
        coords = [atom.coord for atom in structure.get_atoms()]

    if not coords:
        # np.array([]) would have shape (0,), which breaks shape[1] lookups.
        return np.empty((0, 3))
    return np.array(coords)

# ----------- Download and read the file -----------
# Fetch PDB entry 1c26 (or reuse the local copy) and get its path.
pdb_path = get_pdb_file("1c26")

# Point cloud built from every atom of the structure.
points = extract_coordinates(pdb_path, only_CA=False)
print(f" all-atom: {points.shape[0]} puntos en R^{points.shape[1]}")

# Point cloud restricted to the C-alpha atoms.
points2 = extract_coordinates(pdb_path, only_CA=True)
print(f" solo C_alpha: {points2.shape[0]} puntos en R^{points2.shape[1]}")

# ----------- Compute persistent homology -----------
# maxdim=2 requests homology up to dimension 2 (H0, H1, H2) for each cloud.
results = ripser(points, maxdim=2)
results2 = ripser(points2, maxdim=2)

# The 'dgms' entry holds the persistence diagrams used for all plots below.
diagrams = results['dgms']
diagrams2 = results2['dgms']

# ----------- Per-dimension color configuration -----------
colors = ['tab:blue', 'tab:orange', 'tab:green']  # H0, H1, H2

# ----------- Create the 2x2 figure -----------
# Top row: persistence diagrams; bottom row: barcodes (drawn further below).
fig, axs = plt.subplots(2, 2, figsize=(12, 10))

# ----------- Persistence diagrams -----------
# show=False keeps persim from displaying the figure before all panels are drawn.
plot_diagrams(diagrams, ax=axs[0, 0], show=False)
axs[0, 0].set_title("Diagrama de Persistencia all-atoms")

plot_diagrams(diagrams2, ax=axs[0, 1], show=False)
axs[0, 1].set_title("Diagrama de Persistencia C-alpha")

# ----------- Barcodes -----------
def plot_barcodes(ax, diagrams, colors):
    """Draw the persistence barcode of ``diagrams`` on ``ax``.

    Each (birth, death) pair becomes one horizontal bar on its own row; rows
    are stacked in diagram order, and bars of dimension ``dim`` use
    ``colors[dim % len(colors)]``. Bars with a non-finite death are drawn up
    to a common right end placed slightly beyond the largest finite endpoint
    found in *any* dimension, so they are never shorter than a finite bar.

    Args:
        ax: Matplotlib axes to draw on.
        diagrams: Sequence of array-likes of (birth, death) rows, one per
            homology dimension.
        colors: Non-empty sequence of Matplotlib colors, cycled by dimension.

    Returns:
        None. The axes are modified in place (bars, labels, title, grid).
    """
    # Right end for bars that never die. It is computed once over all
    # diagrams: reading ax.get_xlim() while drawing only reflects the bars
    # plotted so far, so later dimensions could outgrow the "infinite" bars.
    largest_finite = max(
        (float(v) for dgm in diagrams for v in np.ravel(dgm) if np.isfinite(v)),
        default=0.0,
    )
    if largest_finite > 0:
        inf_end = largest_finite * 1.05
    else:
        # Nothing finite to scale against; fall back to the current axis limit.
        inf_end = ax.get_xlim()[1]

    y_offset = 0
    for dim, dgm in enumerate(diagrams):
        color = colors[dim % len(colors)]
        pairs = np.asarray(dgm, dtype=float).reshape(-1, 2)
        if pairs.shape[0] == 0:
            continue
        births = pairs[:, 0]
        deaths = np.where(np.isfinite(pairs[:, 1]), pairs[:, 1], inf_end)
        rows = np.arange(y_offset, y_offset + pairs.shape[0])
        # One call per dimension instead of one artist per bar.
        ax.hlines(y=rows, xmin=births, xmax=deaths, color=color, linewidth=2)
        y_offset += pairs.shape[0]

    ax.set_xlabel("Filtración (ε)")
    ax.set_ylabel("Ciclos")
    ax.set_title("Códigos de Barra")
    ax.grid(True, linestyle=':', alpha=0.4)

# All-atom barcodes (bottom-left panel).
# The set_title call replaces the generic title set inside plot_barcodes.
plot_barcodes(axs[1, 0], diagrams, colors)
axs[1, 0].set_title("Códigos de Barra all-atoms")

# C-alpha barcodes (bottom-right panel).
plot_barcodes(axs[1, 1], diagrams2, colors)
axs[1, 1].set_title("Códigos de Barra C-alpha")

# ----------- Final layout -----------
# Tighten the spacing between the four panels, then display the figure.
plt.tight_layout()
plt.show()
