diff --git a/docs/source/user_guide/benchmarks/index.rst b/docs/source/user_guide/benchmarks/index.rst index a637cabd..7a42e816 100644 --- a/docs/source/user_guide/benchmarks/index.rst +++ b/docs/source/user_guide/benchmarks/index.rst @@ -7,3 +7,4 @@ Benchmarks surfaces nebs + supramolecular diff --git a/docs/source/user_guide/benchmarks/supramolecular.rst b/docs/source/user_guide/benchmarks/supramolecular.rst new file mode 100644 index 00000000..0c3a44a4 --- /dev/null +++ b/docs/source/user_guide/benchmarks/supramolecular.rst @@ -0,0 +1,53 @@ +============== +Supramolecular +============== + +PLA15 +====== + +Summary +------- + +Performance in predicting protein–ligand active-site interaction energies for the +PLA15 set of 15 complexes. Systems range from 259 to 584 atoms and contain complete +active sites. Ligands contain 37–95 atoms with net charges of −1, 0, or +1 and all contain +aromatic heterocycles. Five ligands contain either divalent or tetrahedral sulfur atoms, and +four and three of them contain F and Cl atoms, respectively. + +Metrics +------- + +Total MAE + +For each complex, the interaction energy is calculated by taking the difference in energy between the protein-ligand complex and the sum of the individual protein and ligand energies. The MAE is computed by comparing predicted interaction energies to reference interaction energies across all 15 systems. + +Pearson's r² + +The squared Pearson correlation coefficient between predicted and reference interaction energies, measuring the proportion of variance in the reference values explained by the model predictions. + +Ion-Ion MAE + +For each complex where both protein and ligand fragments have non-zero charges, the interaction energy error is calculated. This metric reports the MAE for these ion-ion interaction systems. + +Ion-Neutral MAE + +For each complex where one fragment (protein or ligand) has a non-zero charge and the other is neutral, the interaction energy error is calculated. 
This metric reports the MAE for these ion-neutral interaction systems. + + +Computational cost +------------------ + +low: tests are likely to take minutes to run on CPU. + +Data availability +----------------- + +Input structures: + +* K. Kříž and J. Řezáč, ‘protein ligand - Benchmarking of Semiempirical Quantum-Mechanical Methods on Systems Relevant to Computer-Aided Drug Design’, J. Chem. Inf. Model., vol. 60, no. 3, pp. 1453–1460, Mar. 2020, doi: 10.1021/acs.jcim.9b01171. +* Structures download found in SI + +Reference data: + +* The Supporting Information also provides the interaction energies. + * The benchmark interaction energies are based on a combination of explicitly correlated MP2-F12 calculations and a DLPNO-CCSD(T) correction diff --git a/mlip_testing/analysis/supramolecular/PLA15/analyse_PLA15.py b/mlip_testing/analysis/supramolecular/PLA15/analyse_PLA15.py new file mode 100644 index 00000000..eb10eda0 --- /dev/null +++ b/mlip_testing/analysis/supramolecular/PLA15/analyse_PLA15.py @@ -0,0 +1,463 @@ +"""Analyse PLA15 benchmark.""" + +from __future__ import annotations + +import pytest + +from mlip_testing.analysis.utils.decorators import build_table, plot_parity +from mlip_testing.analysis.utils.utils import mae +from mlip_testing.app import APP_ROOT +from mlip_testing.calcs import CALCS_ROOT +from mlip_testing.calcs.models.models import MODELS + +CALC_PATH = CALCS_ROOT / "supramolecular" / "PLA15" / "outputs" +OUT_PATH = APP_ROOT / "data" / "supramolecular" / "PLA15" + + +def get_system_identifiers() -> list[str]: + """ + Get list of PLA15 system identifiers. + + Returns + ------- + list[str] + List of system identifiers from structure files. 
+ """ + from ase.io import read + + system_identifiers = [] + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + for xyz_file in xyz_files: + atoms = read(xyz_file) + system_identifiers.append( + atoms.info.get("identifier", f"system_{xyz_file.stem}") + ) + break + return system_identifiers + + +def get_atom_counts() -> list[int]: + """ + Get complex atom counts for PLA15. + + Returns + ------- + list[int] + List of complex atom counts from structure files. + """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + atom_counts = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + atom_counts.append(len(atoms)) + return atom_counts + return [] + + +def get_charges() -> list[int]: + """ + Get complex charges for PLA15. + + Returns + ------- + list[int] + List of complex charges from structure files. + """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + charges = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + charges.append(atoms.info.get("complex_charge", 0)) + return charges + return [] + + +def get_protein_atom_counts() -> list[int]: + """ + Get protein atom counts for PLA15. + + Returns + ------- + list[int] + List of protein atom counts from structure files. 
+ """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + protein_counts = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + protein_counts.append(atoms.info.get("protein_atoms", 0)) + return protein_counts + return [] + + +def get_ligand_atom_counts() -> list[int]: + """ + Get ligand atom counts for PLA15. + + Returns + ------- + list[int] + List of ligand atom counts from structure files. + """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + ligand_counts = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + ligand_counts.append(atoms.info.get("ligand_atoms", 0)) + return ligand_counts + return [] + + +def get_protein_charges() -> list[int]: + """ + Get protein fragment charges for PLA15. + + Returns + ------- + list[int] + List of protein charges from structure files. + """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + protein_charges = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + protein_charges.append(atoms.info.get("protein_charge", 0)) + return protein_charges + return [] + + +def get_ligand_charges() -> list[int]: + """ + Get ligand charges for PLA15. + + Returns + ------- + list[int] + List of ligand charges from structure files. 
+ """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + ligand_charges = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + ligand_charges.append(atoms.info.get("ligand_charge", 0)) + return ligand_charges + return [] + + +def get_interaction_types() -> list[str]: + """ + Get interaction types for PLA15. + + Returns + ------- + list[str] + List of interaction types from structure files. + """ + from ase.io import read + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + if model_dir.exists(): + xyz_files = sorted(model_dir.glob("*.xyz")) + if xyz_files: + interaction_types = [] + for xyz_file in xyz_files: + atoms = read(xyz_file) + interaction_types.append( + atoms.info.get("interaction_type", "unknown") + ) + return interaction_types + return [] + + +@pytest.fixture +@plot_parity( + filename=OUT_PATH / "figure_interaction_energies.json", + title="PLA15 Protein-Ligand Interaction Energies", + x_label="Predicted interaction energy / kcal/mol", + y_label="Reference interaction energy / kcal/mol", + hoverdata={ + "System": get_system_identifiers(), + "Complex Atoms": get_atom_counts(), + "Protein Atoms": get_protein_atom_counts(), + "Ligand Atoms": get_ligand_atom_counts(), + "Total Charge": get_charges(), + "Protein Charge": get_protein_charges(), + "Ligand Charge": get_ligand_charges(), + "Interaction Type": get_interaction_types(), + }, +) +def interaction_energies() -> dict[str, list]: + """ + Get interaction energies for all PLA15 systems. + + Returns + ------- + dict[str, list] + Dictionary of reference and predicted interaction energies. 
+ """ + from ase.io import read + + results = {"ref": []} | {mlip: [] for mlip in MODELS} + ref_stored = False + + for model_name in MODELS: + model_dir = CALC_PATH / model_name + + if not model_dir.exists(): + results[model_name] = [] + continue + + xyz_files = sorted(model_dir.glob("*.xyz")) + if not xyz_files: + results[model_name] = [] + continue + + model_energies = [] + ref_energies = [] + + for xyz_file in xyz_files: + atoms = read(xyz_file) + model_energies.append(atoms.info["E_int_model_kcal"]) + if not ref_stored: + ref_energies.append(atoms.info["E_int_ref_kcal"]) + + results[model_name] = model_energies + + # Store reference energies (only once) + if not ref_stored: + results["ref"] = ref_energies + ref_stored = True + + # Copy individual structure files to app data directory + structs_dir = OUT_PATH / model_name + structs_dir.mkdir(parents=True, exist_ok=True) + + # Copy individual structure files + import shutil + + for i, xyz_file in enumerate(xyz_files): + shutil.copy(xyz_file, structs_dir / f"{i}.xyz") + + return results + + +@pytest.fixture +def pla15_r2(interaction_energies) -> dict[str, float]: + """ + Get Pearson's r² for interaction energies. + + Parameters + ---------- + interaction_energies + Dictionary of reference and predicted interaction energies. + + Returns + ------- + dict[str, float] + Dictionary of Pearson's r² values for all models. + """ + from scipy.stats import pearsonr + + results = {} + for model_name in MODELS: + if interaction_energies[model_name]: + r, _ = pearsonr( + interaction_energies["ref"], interaction_energies[model_name] + ) + results[model_name] = r**2 + else: + results[model_name] = 0.0 + return results + + +@pytest.fixture +def pla15_mae(interaction_energies) -> dict[str, float]: + """ + Get mean absolute error for interaction energies (overall). + + Parameters + ---------- + interaction_energies + Dictionary of reference and predicted interaction energies. 
+ + Returns + ------- + dict[str, float] + Dictionary of predicted interaction energy errors for all models. + """ + results = {} + for model_name in MODELS: + if interaction_energies[model_name]: + results[model_name] = mae( + interaction_energies["ref"], interaction_energies[model_name] + ) + else: + results[model_name] = float("nan") + return results + + +@pytest.fixture +def pla15_ion_ion_mae(interaction_energies) -> dict[str, float]: + """ + Get mean absolute error for ion-ion interactions. + + Parameters + ---------- + interaction_energies + Dictionary of reference and predicted interaction energies. + + Returns + ------- + dict[str, float] + Dictionary of predicted interaction energy errors for ion-ion systems. + """ + # Get interaction types for filtering + interaction_types = get_interaction_types() + ion_ion_indices = [ + i for i, itype in enumerate(interaction_types) if itype == "ion-ion" + ] + + results = {} + for model_name in MODELS: + if interaction_energies[model_name] and ion_ion_indices: + ref_ion_ion = [interaction_energies["ref"][i] for i in ion_ion_indices] + pred_ion_ion = [ + interaction_energies[model_name][i] for i in ion_ion_indices + ] + results[model_name] = mae(ref_ion_ion, pred_ion_ion) + else: + results[model_name] = float("nan") + return results + + +@pytest.fixture +def pla15_ion_neutral_mae(interaction_energies) -> dict[str, float]: + """ + Get mean absolute error for ion-neutral interactions. + + Parameters + ---------- + interaction_energies + Dictionary of reference and predicted interaction energies. + + Returns + ------- + dict[str, float] + Dictionary of predicted interaction energy errors for ion-neutral systems. 
+ """ + # Get interaction types for filtering + interaction_types = get_interaction_types() + ion_neutral_indices = [ + i for i, itype in enumerate(interaction_types) if itype == "ion-neutral" + ] + + results = {} + for model_name in MODELS: + if interaction_energies[model_name] and ion_neutral_indices: + ref_ion_neutral = [ + interaction_energies["ref"][i] for i in ion_neutral_indices + ] + pred_ion_neutral = [ + interaction_energies[model_name][i] for i in ion_neutral_indices + ] + results[model_name] = mae(ref_ion_neutral, pred_ion_neutral) + else: + results[model_name] = float("nan") + return results + + +@pytest.fixture +@build_table( + filename=OUT_PATH / "pla15_metrics_table.json", + metric_tooltips={ + "Model": "Name of the model", + "MAE": "Mean Absolute Error for all systems (kcal/mol)", + "R²": "Pearson's r² (squared correlation coefficient)", + "Ion-Ion MAE": "MAE for ion-ion interactions (kcal/mol)", + "Ion-Neutral MAE": "MAE for ion-neutral interactions (kcal/mol)", + }, +) +def metrics( + pla15_mae: dict[str, float], + pla15_r2: dict[str, float], + pla15_ion_ion_mae: dict[str, float], + pla15_ion_neutral_mae: dict[str, float], +) -> dict[str, dict]: + """ + Get all PLA15 metrics. + + Parameters + ---------- + pla15_mae + Mean absolute errors for all systems. + pla15_r2 + R² values for all systems. + pla15_ion_ion_mae + Mean absolute errors for ion-ion interactions. + pla15_ion_neutral_mae + Mean absolute errors for ion-neutral interactions. + + Returns + ------- + dict[str, dict] + Metric names and values for all models. + """ + return { + "MAE": pla15_mae, + "R²": pla15_r2, + "Ion-Ion MAE": pla15_ion_ion_mae, + "Ion-Neutral MAE": pla15_ion_neutral_mae, + } + + +def test_pla15(metrics: dict[str, dict]) -> None: + """ + Run PLA15 test. + + Parameters + ---------- + metrics + All PLA15 metrics. 
+ """ + return diff --git a/mlip_testing/app/supramolecular/PLA15/app_PLA15.py b/mlip_testing/app/supramolecular/PLA15/app_PLA15.py new file mode 100644 index 00000000..87bec1b4 --- /dev/null +++ b/mlip_testing/app/supramolecular/PLA15/app_PLA15.py @@ -0,0 +1,88 @@ +"""Run PLA15 app.""" + +from __future__ import annotations + +from pathlib import Path + +from dash import Dash +from dash.html import Div + +from mlip_testing.app import APP_ROOT +from mlip_testing.app.base_app import BaseApp +from mlip_testing.app.utils.build_callbacks import ( + plot_from_table_column, + struct_from_scatter, +) +from mlip_testing.app.utils.load import read_plot +from mlip_testing.calcs.models.models import MODELS + +BENCHMARK_NAME = Path(__file__).name.removeprefix("app_").removesuffix(".py") +DATA_PATH = APP_ROOT / "data" / "supramolecular" / "PLA15" + + +class PLA15App(BaseApp): + """PLA15 benchmark app layout and callbacks.""" + + def register_callbacks(self) -> None: + """Register callbacks to app.""" + scatter = read_plot( + DATA_PATH / "figure_interaction_energies.json", + id=f"{BENCHMARK_NAME}-figure", + ) + + structs_dir = DATA_PATH / list(MODELS.keys())[0] + # Assets dir will be parent directory - individual files for each system + structs = [ + f"assets/supramolecular/PLA15/{list(MODELS.keys())[0]}/{i}.xyz" + for i in range(len(list(structs_dir.glob("*.xyz")))) + ] + + plot_from_table_column( + table_id=self.table_id, + plot_id=f"{BENCHMARK_NAME}-figure-placeholder", + column_to_plot={"MAE": scatter}, + ) + + struct_from_scatter( + scatter_id=f"{BENCHMARK_NAME}-figure", + struct_id=f"{BENCHMARK_NAME}-struct-placeholder", + structs=structs, + mode="struct", + ) + + +def get_app() -> PLA15App: + """ + Get PLA15 benchmark app layout and callback registration. + + Returns + ------- + PLA15App + Benchmark layout and callback registration. 
+ """ + return PLA15App( + name=BENCHMARK_NAME, + title="PLA15", + description=( + "Performance in predicting protein-ligand interaction energies for 15 " + "complete active site complexes." + ), + table_path=DATA_PATH / "pla15_metrics_table.json", + extra_components=[ + Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), + Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), + ], + ) + + +if __name__ == "__main__": + # Create Dash app + full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent) + + # Construct layout and register callbacks + pla15_app = get_app() + full_app.layout = pla15_app.layout + pla15_app.register_callbacks() + + # Run app + full_app.run(port=8055, debug=True) diff --git a/mlip_testing/calcs/supramolecular/PLA15/.dvc/.gitignore b/mlip_testing/calcs/supramolecular/PLA15/.dvc/.gitignore new file mode 100644 index 00000000..528f30c7 --- /dev/null +++ b/mlip_testing/calcs/supramolecular/PLA15/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/mlip_testing/calcs/supramolecular/PLA15/.dvc/config b/mlip_testing/calcs/supramolecular/PLA15/.dvc/config new file mode 100644 index 00000000..e69de29b diff --git a/mlip_testing/calcs/supramolecular/PLA15/.dvcignore b/mlip_testing/calcs/supramolecular/PLA15/.dvcignore new file mode 100644 index 00000000..51973055 --- /dev/null +++ b/mlip_testing/calcs/supramolecular/PLA15/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. 
Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/mlip_testing/calcs/supramolecular/PLA15/calc_PLA15.py b/mlip_testing/calcs/supramolecular/PLA15/calc_PLA15.py new file mode 100644 index 00000000..a51f3ee0 --- /dev/null +++ b/mlip_testing/calcs/supramolecular/PLA15/calc_PLA15.py @@ -0,0 +1,449 @@ +"""Run calculations for PLA15 benchmark.""" + +from __future__ import annotations + +from pathlib import Path + +from ase import Atoms, units +from ase.calculators.calculator import Calculator +from ase.io import write +import mlipx +from mlipx.abc import NodeWithCalculator +import numpy as np +from tqdm import tqdm +import zntrack + +from mlip_testing.calcs.models.models import MODELS +from mlip_testing.calcs.utils.utils import chdir, get_benchmark_data + +# Local directory to store output data +OUT_PATH = Path(__file__).parent / "outputs" + +# Constants +KCAL_PER_MOL_TO_EV = units.kcal / units.mol +EV_TO_KCAL_PER_MOL = 1.0 / KCAL_PER_MOL_TO_EV + + +class PLA15Benchmark(zntrack.Node): + """ + Benchmark model for PLA15 dataset. + + Evaluates protein-ligand interaction energies for 15 complete active site complexes. + Each complex consists of protein, ligand, and complex structures from PDB files. + Computes interaction energy = E(complex) - E(protein) - E(ligand) + """ + + model: NodeWithCalculator = zntrack.deps() + model_name: str = zntrack.params() + + @staticmethod + def extract_charge_and_selections( + pdb_path: Path, + ) -> tuple[float, float, float, str, str]: + """ + Extract charge and selection information from PDB REMARK lines. + + Parameters + ---------- + pdb_path : Path + Path to PDB file. + + Returns + ------- + Tuple[float, float, float, str, str] + Total charge, charge A, charge B, selection A, selection B. 
+ """ + total_charge = qa = qb = 0.0 + selection_a = selection_b = "" + + with open(pdb_path) as f: + for line in f: + if not line.startswith("REMARK"): + if line.startswith("ATOM") or line.startswith("HETATM"): + break + continue + + parts = line.split() + if len(parts) < 3: + continue + + tag = parts[1].lower() + + if tag == "charge": + total_charge = float(parts[2]) + elif tag == "charge_a": + qa = float(parts[2]) + elif tag == "charge_b": + qb = float(parts[2]) + elif tag == "selection_a": + selection_a = " ".join(parts[2:]) + elif tag == "selection_b": + selection_b = " ".join(parts[2:]) + + return total_charge, qa, qb, selection_a, selection_b + + @staticmethod + def separate_protein_ligand_simple(pdb_path: Path): + """ + Separate protein and ligand based on residue names. + + Parameters + ---------- + pdb_path : Path + Path to PDB file. + + Returns + ------- + Tuple + All atoms, protein atoms, ligand atoms. + """ + import MDAnalysis as mda # noqa: N813 + + # Load with MDAnalysis + u = mda.Universe(str(pdb_path)) + + # Simple separation: ligand = UNK residues, protein = everything else + protein_atoms = [] + ligand_atoms = [] + + for atom in u.atoms: + if atom.resname.strip().upper() in ["UNK", "LIG", "MOL"]: + ligand_atoms.append(atom) + else: + protein_atoms.append(atom) + + return u.atoms, protein_atoms, ligand_atoms + + @staticmethod + def mda_atoms_to_ase(atom_list, charge: float, identifier: str) -> Atoms: + """ + Convert MDAnalysis atoms to ASE Atoms object. + + Parameters + ---------- + atom_list + List of MDAnalysis atoms. + charge : float + Charge of the fragment. + identifier : str + Identifier for the structure. + + Returns + ------- + Atoms + ASE Atoms object. 
+ """ + if not atom_list: + atoms = Atoms() + atoms.info.update({"charge": charge, "identifier": identifier}) + return atoms + + symbols = [] + positions = [] + + for atom in atom_list: + # Get element symbol + try: + elem = (atom.element or "").strip().title() + except (AttributeError, TypeError): + elem = "" + + if not elem: + # Fallback: first letter of atom name + elem = "".join([c for c in atom.name if c.isalpha()])[:1].title() or "C" + + symbols.append(elem) + positions.append(atom.position) + + atoms = Atoms(symbols=symbols, positions=np.array(positions)) + atoms.info.update({"charge": int(round(charge)), "identifier": identifier}) + return atoms + + @staticmethod + def process_pdb_file(pdb_path: Path) -> dict[str, Atoms]: + """ + Parse PDB file and return complex with separated fragments. + + Parameters + ---------- + pdb_path : Path + Path to PDB file. + + Returns + ------- + Dict[str, Atoms] + Dictionary with 'complex', 'protein', and 'ligand' Atoms objects. + """ + total_charge, charge_a, charge_b, _, _ = ( + PLA15Benchmark.extract_charge_and_selections(pdb_path) + ) + + try: + all_atoms, protein_atoms, ligand_atoms = ( + PLA15Benchmark.separate_protein_ligand_simple(pdb_path) + ) + + if len(ligand_atoms) == 0: + print(f"Warning: No ligand atoms found in {pdb_path.name}") + return {} + + if len(protein_atoms) == 0: + print(f"Warning: No protein atoms found in {pdb_path.name}") + return {} + + base_id = pdb_path.stem + + complex_atoms = PLA15Benchmark.mda_atoms_to_ase( + list(all_atoms), total_charge, base_id + ) + protein_frag = PLA15Benchmark.mda_atoms_to_ase( + protein_atoms, charge_a, base_id + ) + ligand = PLA15Benchmark.mda_atoms_to_ase(ligand_atoms, charge_b, base_id) + + return {"complex": complex_atoms, "protein": protein_frag, "ligand": ligand} + + except (ImportError, AttributeError, ValueError, OSError) as e: + print(f"Warning: Error processing {pdb_path}: {e}") + return {} + + @staticmethod + def parse_pla15_references(path: Path) -> 
dict[str, float]: + """ + Parse PLA15 reference interaction energies from file. + + Parameters + ---------- + path : Path + Path to reference file. + + Returns + ------- + Dict[str, float] + Dictionary mapping system identifier to reference energy in eV. + """ + ref: dict[str, float] = {} + + for line in path.read_text().splitlines(): + line = line.strip() + if not line or line.lower().startswith("no.") or line.startswith("-"): + continue + + parts = line.split() + if len(parts) < 3: + continue + + try: + energy_kcal = float(parts[-1]) + except ValueError: + continue + + # Extract full identifier with residue type + full_identifier = parts[1].replace(".pdb", "") + + # TODO: review this part + error handling + # Extract base identifier by removing residue type suffix + # Format: "1ABC_15_lys" -> "1ABC_15" + identifier_parts = full_identifier.split("_") + if len(identifier_parts) >= 3: + # Last part is residue type (lys, arg, asp, etc.) + base_identifier = "_".join(identifier_parts[:-1]) + else: + # Fallback: use full identifier if format is unexpected + base_identifier = full_identifier + + energy_ev = energy_kcal * KCAL_PER_MOL_TO_EV # Convert to eV + ref[base_identifier] = energy_ev + + return ref + + @staticmethod + def interaction_energy(fragments: dict[str, Atoms], calc: Calculator) -> float: + """ + Calculate interaction energy from fragments. + + Parameters + ---------- + fragments : Dict[str, Atoms] + Dictionary containing 'complex', 'protein', and 'ligand' fragments. + calc : Calculator + ASE calculator for energy calculations. + + Returns + ------- + float + Interaction energy in eV. 
+ """ + fragments["complex"].calc = calc + e_complex = fragments["complex"].get_potential_energy() + fragments["protein"].calc = calc + e_protein = fragments["protein"].get_potential_energy() + fragments["ligand"].calc = calc + e_ligand = fragments["ligand"].get_potential_energy() + return e_complex - e_protein - e_ligand + + @staticmethod + def benchmark_pla15( + calc: Calculator, model_name: str, base_dir: Path + ) -> list[Atoms]: + """ + Benchmark PLA15 dataset. + + Parameters + ---------- + calc : Calculator + ASE calculator for energy calculations. + model_name : str + Name of the model being benchmarked. + base_dir : Path + Base directory containing PLA15 data. + + Returns + ------- + list[Atoms] + List of complex structures. + """ + print(f"Benchmarking PLA15 with {model_name}...") + + pla15_dir = base_dir / "PLA15_pdbs" + pla15_ref_file = pla15_dir / "reference_energies.txt" + + pla15_refs = PLA15Benchmark.parse_pla15_references(pla15_ref_file) + pdb_files = list(pla15_dir.glob("*.pdb")) + + complex_atoms_list = [] + + for pdb_file in tqdm(pdb_files, desc="PLA15"): + identifier = pdb_file.stem + if identifier not in pla15_refs: + continue + + fragments = PLA15Benchmark.process_pdb_file(pdb_file) + if not fragments: + continue + + try: + # Calculate interaction energy + e_int_model = PLA15Benchmark.interaction_energy(fragments, calc) + e_int_ref = pla15_refs[identifier] + + # Calculate errors + error_ev = e_int_model - e_int_ref + error_kcal = error_ev * EV_TO_KCAL_PER_MOL + + # Store additional info in complex atoms + complex_atoms = fragments["complex"] + complex_atoms.info["model"] = model_name + complex_atoms.info["E_int_model_kcal"] = ( + e_int_model * EV_TO_KCAL_PER_MOL + ) + complex_atoms.info["E_int_ref_kcal"] = e_int_ref * EV_TO_KCAL_PER_MOL + complex_atoms.info["E_int_model_ev"] = e_int_model + complex_atoms.info["E_int_ref_ev"] = e_int_ref + complex_atoms.info["error_kcal"] = error_kcal + complex_atoms.info["error_ev"] = error_ev + 
complex_atoms.info["identifier"] = identifier + complex_atoms.info["dataset"] = "PLA15" + complex_atoms.info["complex_atoms"] = len(complex_atoms) + complex_atoms.info["protein_atoms"] = len(fragments["protein"]) + complex_atoms.info["ligand_atoms"] = len(fragments["ligand"]) + complex_atoms.info["complex_charge"] = complex_atoms.info["charge"] + complex_atoms.info["protein_charge"] = fragments["protein"].info[ + "charge" + ] + complex_atoms.info["ligand_charge"] = fragments["ligand"].info["charge"] + + # Classify interaction type based on fragment charges + protein_charge = fragments["protein"].info["charge"] + ligand_charge = fragments["ligand"].info["charge"] + + if protein_charge != 0 and ligand_charge != 0: + interaction_type = "ion-ion" + elif protein_charge != 0 or ligand_charge != 0: + interaction_type = "ion-neutral" + else: + interaction_type = "neutral-neutral" + + complex_atoms.info["interaction_type"] = interaction_type + + complex_atoms_list.append(complex_atoms) + + # print( + # f" {identifier}: E_int = {e_int_model:.6f} eV " + # f"(ref: {e_int_ref:.6f} eV, error: {error_kcal:.2f} kcal/mol)" + # ) + + except (KeyError, ValueError, RuntimeError) as e: + print(f"Error processing {identifier}: {e}") + continue + + return complex_atoms_list + + def run(self): + """Run PLA15 benchmark calculations.""" + calc = self.model.get_calculator() + + # Get benchmark data + base_dir = ( + get_benchmark_data("protein-ligand-data_PLA15_PLF547.zip") + / "protein-ligand-data_PLA15_PLF547" + ) + + # Run benchmark + complex_atoms = self.benchmark_pla15(calc, self.model_name, base_dir) + + # Write output structures + write_dir = OUT_PATH / self.model_name + write_dir.mkdir(parents=True, exist_ok=True) + + # Save individual complex atoms files for each system + for i, atoms in enumerate(complex_atoms): + atoms_copy = atoms.copy() + # atoms_copy.calc = None + + # Write each system to its own file + system_file = write_dir / f"{i}.xyz" + write(system_file, atoms_copy, 
format="extxyz") + + # Calculate and save MAE if we have results + # if complex_atoms: + # errors = [atoms.info["error_kcal"] for atoms in complex_atoms] + # mae = sum(abs(error) for error in errors) / len(errors) + # mae_data = {"MAE_kcal": float(mae)} + + # with open(write_dir / "mae_results.json", "w") as f: + # json.dump(mae_data, f, indent=2) + + # print(f"MAE for {self.model_name} on PLA15: {mae:.2f} kcal/mol") + + +def build_project(repro: bool = False) -> None: + """ + Build mlipx project. + + Parameters + ---------- + repro + Whether to call dvc repro -f after building. + """ + project = mlipx.Project() + benchmark_node_dict = {} + + for model_name, model in MODELS.items(): + with project.group(model_name): + benchmark = PLA15Benchmark( + model=model, + model_name=model_name, + ) + benchmark_node_dict[model_name] = benchmark + + if repro: + with chdir(Path(__file__).parent): + project.repro(build=True, force=True) + else: + project.build() + + +def test_pla15(): + """Run PLA15 benchmark via pytest.""" + build_project(repro=True) diff --git a/mlip_testing/calcs/supramolecular/PLA15/dvc.lock b/mlip_testing/calcs/supramolecular/PLA15/dvc.lock new file mode 100644 index 00000000..a520b840 --- /dev/null +++ b/mlip_testing/calcs/supramolecular/PLA15/dvc.lock @@ -0,0 +1,61 @@ +schema: '2.0' +stages: + mace_matpes_r2scan_PLA15Benchmark: + cmd: zntrack run calc_PLA15.PLA15Benchmark --name + mace_matpes_r2scan_PLA15Benchmark + params: + params.yaml: + mace_matpes_r2scan_PLA15Benchmark: + model: + _cls: mlipx.nodes.generic_ase.GenericASECalculator + class_name: mace_mp + device: auto + kwargs: + head: matpes_pbe + model: mace-matpes-r2scan-0 + module: mace.calculators + spec: + model_name: mace_matpes_r2scan + outs: + - path: nodes/mace_matpes_r2scan/PLA15Benchmark/node-meta.json + hash: md5 + md5: e7cbd2747fe698844caae0430c291dea + size: 750 + mace_mp_0a_PLA15Benchmark: + cmd: zntrack run calc_PLA15.PLA15Benchmark --name mace_mp_0a_PLA15Benchmark + params: + 
params.yaml: + mace_mp_0a_PLA15Benchmark: + model: + _cls: mlipx.nodes.generic_ase.GenericASECalculator + class_name: mace_mp + device: auto + kwargs: + model: medium + module: mace.calculators + spec: + model_name: mace_mp_0a + outs: + - path: nodes/mace_mp_0a/PLA15Benchmark/node-meta.json + hash: md5 + md5: 18271adf5648d016336e255f8990789b + size: 677 + mace_omat_0_PLA15Benchmark: + cmd: zntrack run calc_PLA15.PLA15Benchmark --name mace_omat_0_PLA15Benchmark + params: + params.yaml: + mace_omat_0_PLA15Benchmark: + model: + _cls: mlipx.nodes.generic_ase.GenericASECalculator + class_name: mace_mp + device: auto + kwargs: + model: medium-omat-0 + module: mace.calculators + spec: + model_name: mace_omat_0 + outs: + - path: nodes/mace_omat_0/PLA15Benchmark/node-meta.json + hash: md5 + md5: 6d756d4dc178f45efc57f03fc8ff9438 + size: 686 diff --git a/mlip_testing/calcs/supramolecular/PLA15/dvc.yaml b/mlip_testing/calcs/supramolecular/PLA15/dvc.yaml new file mode 100644 index 00000000..50971d45 --- /dev/null +++ b/mlip_testing/calcs/supramolecular/PLA15/dvc.yaml @@ -0,0 +1,8 @@ +stages: + mace_mp_0a_PLA15Benchmark: + cmd: zntrack run calc_PLA15.PLA15Benchmark --name mace_mp_0a_PLA15Benchmark + metrics: + - nodes/mace_mp_0a/PLA15Benchmark/node-meta.json: + cache: true + params: + - mace_mp_0a_PLA15Benchmark diff --git a/pyproject.toml b/pyproject.toml index 72328585..7d014c2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "kaleido>=1.0.0", "mlipx<0.2,>=0.1.5", "scikit-learn>=1.7.1", + "mdanalysis>=2.9.0", ] [project.optional-dependencies]