from Bio.PDB import PDBParser, PDBList
import numpy as np
from ripser import ripser

import matplotlib.pyplot as plt
import os
import gzip
import shutil


from persim import plot_diagrams, bottleneck
from scipy.spatial import distance_matrix
import time  # <-- agregado para medir tiempos

 
def get_pdb_file(pdb_id, pdb_dir="pdb_files"):
    """Return the path of a local ``<pdb_id>.pdb`` file, downloading it if needed.

    The file is cached in *pdb_dir* (created if missing) under the lowercase
    name ``<pdb_id>.pdb``. If that file already exists, no download is
    attempted. Otherwise the entry is fetched with
    ``PDBList().retrieve_pdb_file(..., file_format="pdb")``. The result is
    gunzipped if the returned path ends in ``.gz``, then moved to the cached
    name.

    Args:
        pdb_id: PDB identifier (case-insensitive; it is lowercased).
        pdb_dir: directory used as the local cache.

    Returns:
        Path (str) of the local ``.pdb`` file.

    Raises:
        FileNotFoundError: if the download did not produce a file on disk.
    """
    os.makedirs(pdb_dir, exist_ok=True)
    pdb_id = pdb_id.lower()
    pdb_path = os.path.join(pdb_dir, f"{pdb_id}.pdb")

    # Cache hit: reuse the existing local .pdb instead of downloading again.
    if os.path.isfile(pdb_path):
        print(f"✅ Archivo {pdb_id}.pdb encontrado localmente.")
        return pdb_path

    print(f"⬇️  Descargando {pdb_id} desde el servidor PDB...")
    downloaded = PDBList().retrieve_pdb_file(pdb_id, pdir=pdb_dir, file_format="pdb")

    # NOTE(review): depending on the Biopython version, a failed download may
    # yield None or a path that was never written. Either way, stop here with
    # a clear message instead of an AttributeError / shutil.move error below.
    if not downloaded or not os.path.isfile(downloaded):
        raise FileNotFoundError(
            f"No se pudo descargar {pdb_id!r} desde el servidor PDB en {pdb_dir!r}"
        )

    # The downloaded file may still be gzip-compressed (e.g. pdb1c26.ent.gz).
    if downloaded.endswith(".gz"):
        decompressed_path = downloaded[:-3]  # drop the ".gz" suffix
        with gzip.open(downloaded, "rb") as f_in, open(decompressed_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.remove(downloaded)
    else:
        decompressed_path = downloaded

    # Rename to the canonical cache name, e.g. 1c26.pdb.
    shutil.move(decompressed_path, pdb_path)

    print(f"✅ Archivo {pdb_id}.pdb descargado y listo en {pdb_dir}")
    return pdb_path

# ----------- Coordinate extraction -----------
def extract_coordinates(pdb_file, only_CA=True):
    """Load a PDB file and return its atomic coordinates as an ``(n, 3)`` array.

    Args:
        pdb_file: path to a PDB-format file readable by ``PDBParser``.
        only_CA: if True, keep only alpha-carbon atoms (atom name ``"CA"``
            whose element is not calcium). If False, keep every atom yielded
            by ``structure.get_atoms()``.

    Returns:
        ``np.ndarray`` of shape ``(n, 3)`` with one row per kept atom, in the
        order ``get_atoms()`` yields them. When no atom is kept, the shape is
        ``(0, 3)`` rather than ``(0,)``, so ``result.shape[1]`` is always valid.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    # NOTE(review): Structure.get_atoms() walks every model, so multi-model
    # files (e.g. NMR ensembles) contribute one copy of each atom per model;
    # confirm that is intended for the point cloud.
    coords = []
    for atom in structure.get_atoms():
        if only_CA:
            # Calcium ions are also named "CA" but carry element "CA"; the
            # alpha carbon's element is "C", so exclude the ions explicitly.
            if atom.get_name() == "CA" and atom.element != "CA":
                coords.append(atom.coord)
        else:
            coords.append(atom.coord)
    # reshape leaves an (n, 3) result unchanged and turns an empty one into (0, 3).
    return np.array(coords).reshape(-1, 3)

# ----------- Download (if needed) and load the structure -----------
pdb_path = get_pdb_file(pdb_id="1i2t")

# All-atom point cloud: only_CA=False keeps every atom, not just alpha carbons.
points = extract_coordinates(pdb_file=pdb_path, only_CA=False)

def farthest_point_sampling(X):
    """Build a greedy (farthest-point) permutation and its exit times.

    Starting from point 0, each step picks the point farthest from the points
    already chosen. Its exit time ``t_p`` is that distance (the insertion
    radius used in the net-tower definition). The starting point has no
    insertion radius, so it is assigned the largest exit time among the
    others (0.0 when it is the only point).

    Args:
        X: ``(n, d)`` array of points.

    Returns:
        ``(perm, t_out)``:
        - ``perm`` is an int array with the indices of ``X`` in greedy order.
          It is a true permutation even when ``X`` contains duplicate points.
        - ``t_out[i]`` is the exit time of point ``i``, in the original order.
        Both arrays are empty when ``X`` has no rows.
    """
    n = X.shape[0]
    if n == 0:
        return np.array([], dtype=int), np.zeros(0)

    perm = [0]
    t = np.zeros(n)
    t[0] = np.inf
    # Distances to the current subsample are computed one row at a time, so
    # memory stays O(n) instead of materializing the full n x n matrix.
    min_dist = distance_matrix(X[:1], X)[0]
    # Chosen points are marked with -inf so argmax can never pick them again.
    # Leaving them at 0 lets argmax re-select index 0 once only duplicates
    # (distance 0) remain.
    min_dist[0] = -np.inf
    for _ in range(1, n):
        idx = int(np.argmax(min_dist))
        perm.append(idx)
        t[idx] = min_dist[idx]
        min_dist = np.minimum(min_dist, distance_matrix(X[idx:idx + 1], X)[0])
        min_dist[idx] = -np.inf

    # t is already indexed by original point; only the root needs a finite value.
    t_out = t.copy()
    t_out[perm[0]] = np.max(t[np.isfinite(t)], initial=0.0)
    return np.array(perm), t_out

def compare_rips_sparsified(X, gamma, maxdim=2):
    """Compare persistent homology of a point cloud and of a greedy subsample.

    Computes Vietoris-Rips persistence diagrams with ``ripser`` for:
      - the full cloud ``X``;
      - the subsample ``N_gamma = {p : t_p > gamma}``, where ``t_p`` are the
        exit times returned by ``farthest_point_sampling``.
    Plots both diagrams side by side with a summary: point counts, bottleneck
    distance per homology dimension, and timings.

    Args:
        X: ``(n, d)`` array of points.
        gamma: exit-time threshold that defines the subsample.
        maxdim: highest homology dimension passed to ``ripser``.

    Returns:
        dict with the following keys:
        - ``gamma``, ``n_full``, ``n_sparse``.
        - ``bottleneck_H0``..``bottleneck_H2``, which are ``None`` for
          dimensions above *maxdim*.
        - An extra ``bottleneck_H<d>`` for each computed dimension above 2.
        - ``tiempo_full``: full ripser run.
        - ``tiempo_sparse_total``: subsample selection plus ripser.
        - ``tiempo_fps``: greedy permutation.

    Raises:
        ValueError: if no point has exit time greater than *gamma*.
    """
    # The sparse pipeline cannot run without the greedy permutation, so its
    # cost is measured and reported separately as "tiempo_fps".
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0_fps = time.perf_counter()
    _, t_p = farthest_point_sampling(X)
    tiempo_fps = time.perf_counter() - t0_fps

    # Fail fast, before the expensive full computation, if N_gamma would be empty.
    if not np.any(t_p > gamma):
        raise ValueError(
            f"Ningún punto tiene t_p > gamma ({gamma}); la submuestra N_γ quedaría vacía."
        )

    t0_full = time.perf_counter()
    res_full = ripser(X, maxdim=maxdim)
    tiempo_full = time.perf_counter() - t0_full

    t0_sparse_total = time.perf_counter()
    sel = np.where(t_p > gamma)[0]  # subsample N_gamma = {p : t_p > gamma}
    Xs = X[sel]
    res_sparse = ripser(Xs, maxdim=maxdim)
    tiempo_sparse_total = time.perf_counter() - t0_sparse_total

    dgms_full = res_full['dgms']
    dgms_sparse = res_sparse['dgms']

    # One bottleneck distance per homology dimension that was actually computed.
    bottlenecks = {
        dim: bottleneck(dgm_full, dgm_sparse)
        for dim, (dgm_full, dgm_sparse) in enumerate(zip(dgms_full, dgms_sparse))
    }

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_diagrams(dgms_full, ax=axes[0], show=False, title="Rips completo")
    plot_diagrams(dgms_sparse, ax=axes[1], show=False, title=f"Sparse Rips (γ={gamma:.3f})")

    total_pts = X.shape[0]
    sparse_pts = Xs.shape[0]
    # Format only the dimensions that exist. Formatting None with ":.4f"
    # raised TypeError whenever maxdim < 2.
    bottleneck_line = "  ".join(f"H{dim} = {value:.4f}" for dim, value in bottlenecks.items())
    text = (
        f"Nube original: {total_pts} puntos\n"
        f"Submuestra (N_γ): {sparse_pts} puntos\n"
        f"Bottleneck {bottleneck_line}\n"
        f"Tiempo full = {tiempo_full:.3f}s   Tiempo submuestra = {tiempo_sparse_total:.3f}s"
        f"   Tiempo FPS = {tiempo_fps:.3f}s"
    )
    fig.text(0.5, 0.02, text, ha='center', va='bottom', fontsize=11)
    fig.tight_layout(rect=[0, 0.08, 1, 1])
    plt.show()

    results = {"gamma": gamma, "n_full": total_pts, "n_sparse": sparse_pts}
    # Keys bottleneck_H0..H2 are always present (None when not computed);
    # higher dimensions get extra keys only when maxdim > 2.
    for dim in range(max(3, len(bottlenecks))):
        results[f"bottleneck_H{dim}"] = bottlenecks.get(dim)
    results["tiempo_full"] = tiempo_full
    results["tiempo_sparse_total"] = tiempo_sparse_total
    results["tiempo_fps"] = tiempo_fps
    return results

if __name__ == "__main__":
    X = points
    n_points, n_dims = X.shape[0], X.shape[1]
    print(f" tipo all-atom: {n_points} puntos en R^{n_dims}")

    # Only the exit times are needed here; the permutation itself is discarded.
    t_p = farthest_point_sampling(X)[1]

    # gamma = 50th percentile of the strictly positive exit times.
    positive_exit_times = t_p[t_p > 0]
    gamma = np.percentile(positive_exit_times, q=50)

    results = compare_rips_sparsified(X, gamma)

    print("\n--- Resultados ---")
    for key, value in results.items():
        print(f"{key}: {value}")
