Skip to content

Commit 1455be1

Browse files
committed
PLA15 calc, analysis and app initial commit + edit pyproject.toml to add MDAnalysis package
1 parent 9f3baf9 commit 1455be1

File tree

9 files changed

+870
-0
lines changed

9 files changed

+870
-0
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""Analyse PLA15 benchmark."""
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
7+
from mlip_testing.analysis.utils.decorators import build_table, plot_parity
8+
from mlip_testing.analysis.utils.utils import mae
9+
from mlip_testing.app import APP_ROOT
10+
from mlip_testing.calcs import CALCS_ROOT
11+
from mlip_testing.calcs.models.models import MODELS
12+
13+
CALC_PATH = CALCS_ROOT / "supramolecular" / "PLA15" / "outputs"
14+
OUT_PATH = APP_ROOT / "data" / "supramolecular" / "PLA15"
15+
16+
17+
def get_system_identifiers() -> list[str]:
18+
"""
19+
Get list of PLA15 system identifiers.
20+
21+
Returns
22+
-------
23+
list[str]
24+
List of system identifiers from structure files.
25+
"""
26+
from ase.io import read
27+
28+
system_identifiers = []
29+
for model_name in MODELS:
30+
model_dir = CALC_PATH / model_name
31+
if model_dir.exists():
32+
xyz_files = sorted(model_dir.glob("*.xyz"))
33+
if xyz_files:
34+
for xyz_file in xyz_files:
35+
atoms = read(xyz_file)
36+
system_identifiers.append(
37+
atoms.info.get("identifier", f"system_{xyz_file.stem}")
38+
)
39+
break
40+
return system_identifiers
41+
42+
43+
def get_atom_counts() -> list[int]:
44+
"""
45+
Get complex atom counts for PLA15.
46+
47+
Returns
48+
-------
49+
list[int]
50+
List of complex atom counts from structure files.
51+
"""
52+
from ase.io import read
53+
54+
for model_name in MODELS:
55+
model_dir = CALC_PATH / model_name
56+
if model_dir.exists():
57+
xyz_files = sorted(model_dir.glob("*.xyz"))
58+
if xyz_files:
59+
atom_counts = []
60+
for xyz_file in xyz_files:
61+
atoms = read(xyz_file)
62+
atom_counts.append(len(atoms))
63+
return atom_counts
64+
return []
65+
66+
67+
def get_charges() -> list[int]:
68+
"""
69+
Get complex charges for PLA15.
70+
71+
Returns
72+
-------
73+
list[int]
74+
List of complex charges from structure files.
75+
"""
76+
from ase.io import read
77+
78+
for model_name in MODELS:
79+
model_dir = CALC_PATH / model_name
80+
if model_dir.exists():
81+
xyz_files = sorted(model_dir.glob("*.xyz"))
82+
if xyz_files:
83+
charges = []
84+
for xyz_file in xyz_files:
85+
atoms = read(xyz_file)
86+
charges.append(atoms.info.get("complex_charge", 0))
87+
return charges
88+
return []
89+
90+
91+
def get_protein_atom_counts() -> list[int]:
92+
"""
93+
Get protein atom counts for PLA15.
94+
95+
Returns
96+
-------
97+
list[int]
98+
List of protein atom counts from structure files.
99+
"""
100+
from ase.io import read
101+
102+
for model_name in MODELS:
103+
model_dir = CALC_PATH / model_name
104+
if model_dir.exists():
105+
xyz_files = sorted(model_dir.glob("*.xyz"))
106+
if xyz_files:
107+
protein_counts = []
108+
for xyz_file in xyz_files:
109+
atoms = read(xyz_file)
110+
protein_counts.append(atoms.info.get("protein_atoms", 0))
111+
return protein_counts
112+
return []
113+
114+
115+
def get_ligand_atom_counts() -> list[int]:
116+
"""
117+
Get ligand atom counts for PLA15.
118+
119+
Returns
120+
-------
121+
list[int]
122+
List of ligand atom counts from structure files.
123+
"""
124+
from ase.io import read
125+
126+
for model_name in MODELS:
127+
model_dir = CALC_PATH / model_name
128+
if model_dir.exists():
129+
xyz_files = sorted(model_dir.glob("*.xyz"))
130+
if xyz_files:
131+
ligand_counts = []
132+
for xyz_file in xyz_files:
133+
atoms = read(xyz_file)
134+
ligand_counts.append(atoms.info.get("ligand_atoms", 0))
135+
return ligand_counts
136+
return []
137+
138+
139+
@pytest.fixture
140+
@plot_parity(
141+
filename=OUT_PATH / "figure_interaction_energies.json",
142+
title="PLA15 Protein-Ligand Interaction Energies",
143+
x_label="Predicted interaction energy / kcal/mol",
144+
y_label="Reference interaction energy / kcal/mol",
145+
hoverdata={
146+
"System": get_system_identifiers(),
147+
"Complex Atoms": get_atom_counts(),
148+
"Protein Atoms": get_protein_atom_counts(),
149+
"Ligand Atoms": get_ligand_atom_counts(),
150+
"Charge": get_charges(),
151+
},
152+
)
153+
def interaction_energies() -> dict[str, list]:
154+
"""
155+
Get interaction energies for all PLA15 systems.
156+
157+
Returns
158+
-------
159+
dict[str, list]
160+
Dictionary of reference and predicted interaction energies.
161+
"""
162+
from ase.io import read
163+
164+
results = {"ref": []} | {mlip: [] for mlip in MODELS}
165+
ref_stored = False
166+
167+
for model_name in MODELS:
168+
model_dir = CALC_PATH / model_name
169+
170+
if not model_dir.exists():
171+
results[model_name] = []
172+
continue
173+
174+
xyz_files = sorted(model_dir.glob("*.xyz"))
175+
if not xyz_files:
176+
results[model_name] = []
177+
continue
178+
179+
model_energies = []
180+
ref_energies = []
181+
182+
for xyz_file in xyz_files:
183+
atoms = read(xyz_file)
184+
model_energies.append(atoms.info["E_int_model_kcal"])
185+
if not ref_stored:
186+
ref_energies.append(atoms.info["E_int_ref_kcal"])
187+
188+
results[model_name] = model_energies
189+
190+
# Store reference energies (only once)
191+
if not ref_stored:
192+
results["ref"] = ref_energies
193+
ref_stored = True
194+
195+
# Copy individual structure files to app data directory
196+
structs_dir = OUT_PATH / model_name
197+
structs_dir.mkdir(parents=True, exist_ok=True)
198+
199+
# Copy individual structure files
200+
import shutil
201+
202+
for i, xyz_file in enumerate(xyz_files):
203+
shutil.copy(xyz_file, structs_dir / f"{i}.xyz")
204+
205+
return results
206+
207+
208+
@pytest.fixture
209+
def pla15_mae(interaction_energies) -> dict[str, float]:
210+
"""
211+
Get mean absolute error for interaction energies.
212+
213+
Parameters
214+
----------
215+
interaction_energies
216+
Dictionary of reference and predicted interaction energies.
217+
218+
Returns
219+
-------
220+
dict[str, float]
221+
Dictionary of predicted interaction energy errors for all models.
222+
"""
223+
results = {}
224+
for model_name in MODELS:
225+
if interaction_energies[model_name]:
226+
results[model_name] = mae(
227+
interaction_energies["ref"], interaction_energies[model_name]
228+
)
229+
else:
230+
results[model_name] = float("nan")
231+
return results
232+
233+
234+
@pytest.fixture
235+
@build_table(
236+
filename=OUT_PATH / "pla15_metrics_table.json",
237+
metric_tooltips={
238+
"Model": "Name of the model",
239+
"MAE": "Mean Absolute Error for all systems (kcal/mol)",
240+
},
241+
)
242+
def metrics(pla15_mae: dict[str, float]) -> dict[str, dict]:
243+
"""
244+
Get all PLA15 metrics.
245+
246+
Parameters
247+
----------
248+
pla15_mae
249+
Mean absolute errors for all systems.
250+
251+
Returns
252+
-------
253+
dict[str, dict]
254+
Metric names and values for all models.
255+
"""
256+
return {
257+
"MAE": pla15_mae,
258+
}
259+
260+
261+
def test_pla15(metrics: dict[str, dict]) -> None:
262+
"""
263+
Run PLA15 test.
264+
265+
Parameters
266+
----------
267+
metrics
268+
All PLA15 metrics.
269+
"""
270+
return
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Run PLA15 app."""
2+
3+
from __future__ import annotations
4+
5+
from pathlib import Path
6+
7+
from dash import Dash
8+
from dash.html import Div
9+
10+
from mlip_testing.app import APP_ROOT
11+
from mlip_testing.app.base_app import BaseApp
12+
from mlip_testing.app.utils.build_callbacks import (
13+
plot_from_table_column,
14+
struct_from_scatter,
15+
)
16+
from mlip_testing.app.utils.load import read_plot
17+
from mlip_testing.calcs.models.models import MODELS
18+
19+
BENCHMARK_NAME = Path(__file__).name.removeprefix("app_").removesuffix(".py")
20+
DATA_PATH = APP_ROOT / "data" / "supramolecular" / "PLA15"
21+
22+
23+
class PLA15App(BaseApp):
24+
"""PLA15 benchmark app layout and callbacks."""
25+
26+
def register_callbacks(self) -> None:
27+
"""Register callbacks to app."""
28+
scatter = read_plot(
29+
DATA_PATH / "figure_interaction_energies.json",
30+
id=f"{BENCHMARK_NAME}-figure",
31+
)
32+
33+
structs_dir = DATA_PATH / list(MODELS.keys())[0]
34+
# Assets dir will be parent directory - individual files for each system
35+
structs = [
36+
f"assets/supramolecular/PLA15/{list(MODELS.keys())[0]}/{i}.xyz"
37+
for i in range(len(list(structs_dir.glob("*.xyz"))))
38+
]
39+
40+
plot_from_table_column(
41+
table_id=self.table_id,
42+
plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
43+
column_to_plot={"MAE": scatter},
44+
)
45+
46+
struct_from_scatter(
47+
scatter_id=f"{BENCHMARK_NAME}-figure",
48+
struct_id=f"{BENCHMARK_NAME}-struct-placeholder",
49+
structs=structs,
50+
mode="struct",
51+
)
52+
53+
54+
def get_app() -> PLA15App:
55+
"""
56+
Get PLA15 benchmark app layout and callback registration.
57+
58+
Returns
59+
-------
60+
PLA15App
61+
Benchmark layout and callback registration.
62+
"""
63+
return PLA15App(
64+
name=BENCHMARK_NAME,
65+
title="PLA15",
66+
description=(
67+
"Performance in predicting protein-ligand interaction energies for 15 "
68+
"complete active site complexes."
69+
),
70+
table_path=DATA_PATH / "pla15_metrics_table.json",
71+
extra_components=[
72+
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
73+
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
74+
],
75+
)
76+
77+
78+
if __name__ == "__main__":
79+
# Create Dash app
80+
full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent)
81+
82+
# Construct layout and register callbacks
83+
pla15_app = get_app()
84+
full_app.layout = pla15_app.layout
85+
pla15_app.register_callbacks()
86+
87+
# Run app
88+
full_app.run(port=8055, debug=True)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/config.local
2+
/tmp
3+
/cache

mlip_testing/calcs/supramolecular/PLA15/.dvc/config

Whitespace-only changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Add patterns of files dvc should ignore, which could improve
2+
# the performance. Learn more at
3+
# https://dvc.org/doc/user-guide/dvcignore

0 commit comments

Comments
 (0)