Skip to content

Commit b8dbad9

Browse files
committed
generalised metric weights into build_components.py and base_app.py
1 parent 5a61ed2 commit b8dbad9

File tree

6 files changed

+65
-82
lines changed

6 files changed

+65
-82
lines changed

ml_peg/app/base_app.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from dash.development.base_component import Component
99
from dash.html import Div
1010

11-
from ml_peg.app.utils.build_components import build_test_layout
11+
from ml_peg.app.utils.build_components import (
12+
build_metric_weight_components,
13+
build_test_layout,
14+
)
1215
from ml_peg.app.utils.load import rebuild_table
1316

1417

@@ -74,12 +77,18 @@ def build_layout(self) -> Div:
7477
Div component with list all components for app.
7578
"""
7679
# Define all components/placeholders
80+
# Auto-append metric-weight controls for this benchmark table
81+
metric_weights = build_metric_weight_components(self.table)
82+
extra_components = [metric_weights]
83+
if self.extra_components:
84+
extra_components.extend(self.extra_components)
85+
7786
return build_test_layout(
7887
name=self.name,
7988
description=self.description,
8089
docs_url=self.docs_url,
8190
table=self.table,
82-
extra_components=self.extra_components,
91+
extra_components=extra_components,
8392
)
8493

8594
@abstractmethod

ml_peg/app/nebs/li_diffusion/app_li_diffusion.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
6-
75
from dash import Dash
86
from dash.html import Div
97

@@ -13,7 +11,6 @@
1311
plot_from_table_cell,
1412
struct_from_scatter,
1513
)
16-
from ml_peg.app.utils.build_components import build_weight_components
1714
from ml_peg.app.utils.load import read_plot
1815
from ml_peg.calcs.models.models import MODELS
1916

@@ -76,29 +73,12 @@ def get_app() -> LiDiffusionApp:
7673
LiDiffusionApp
7774
Benchmark layout and callback registration.
7875
"""
79-
# Build metric weight components (sliders + inputs) for Li diffusion metrics
80-
with open(DATA_PATH / "li_diffusion_metrics_table.json") as f:
81-
table_json = json.load(f)
82-
metric_columns = [
83-
c["id"]
84-
for c in table_json["columns"]
85-
if c["id"] not in ("MLIP", "Score", "Rank", "id")
86-
]
87-
88-
metric_weights = build_weight_components(
89-
header="Metric weights",
90-
columns=metric_columns,
91-
input_ids=[f"{BENCHMARK_NAME}-{c.replace(' ', '-')}" for c in metric_columns],
92-
table_id=f"{BENCHMARK_NAME}-table",
93-
)
94-
9576
return LiDiffusionApp(
9677
name=BENCHMARK_NAME,
9778
description=("Performance in predicting energy barriers for Li diffision."),
9879
docs_url=DOCS_URL,
9980
table_path=DATA_PATH / "li_diffusion_metrics_table.json",
10081
extra_components=[
101-
metric_weights,
10282
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
10383
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
10484
],

ml_peg/app/surfaces/OC157/app_OC157.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
6-
75
from dash import Dash
86
from dash.html import Div
97
import numpy as np
@@ -14,7 +12,6 @@
1412
plot_from_table_column,
1513
struct_from_scatter,
1614
)
17-
from ml_peg.app.utils.build_components import build_weight_components
1815
from ml_peg.app.utils.load import read_plot
1916
from ml_peg.calcs.models.models import MODELS
2017

@@ -68,22 +65,6 @@ def get_app() -> OC157App:
6865
OC157App
6966
Benchmark layout and callback registration.
7067
"""
71-
# Build metric weight components (sliders + inputs) for OC157 metrics
72-
with open(DATA_PATH / "oc157_metrics_table.json") as f:
73-
table_json = json.load(f)
74-
metric_columns = [
75-
c["id"]
76-
for c in table_json["columns"]
77-
if c["id"] not in ("MLIP", "Score", "Rank", "id")
78-
]
79-
80-
metric_weights = build_weight_components(
81-
header="Metric weights",
82-
columns=metric_columns,
83-
input_ids=[f"{BENCHMARK_NAME}-{c.replace(' ', '-')}" for c in metric_columns],
84-
table_id=f"{BENCHMARK_NAME}-table",
85-
)
86-
8768
return OC157App(
8869
name=BENCHMARK_NAME,
8970
description=(
@@ -93,7 +74,6 @@ def get_app() -> OC157App:
9374
docs_url=DOCS_URL,
9475
table_path=DATA_PATH / "oc157_metrics_table.json",
9576
extra_components=[
96-
metric_weights,
9777
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
9878
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
9979
],

ml_peg/app/surfaces/S24/app_S24.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
6-
75
from dash import Dash
86
from dash.html import Div
97

@@ -13,7 +11,6 @@
1311
plot_from_table_column,
1412
struct_from_scatter,
1513
)
16-
from ml_peg.app.utils.build_components import build_weight_components
1714
from ml_peg.app.utils.load import read_plot
1815
from ml_peg.calcs.models.models import MODELS
1916

@@ -63,22 +60,6 @@ def get_app() -> S24App:
6360
S24App
6461
Benchmark layout and callback registration.
6562
"""
66-
# Build metric weight components (sliders + inputs) for S24 metrics
67-
with open(DATA_PATH / "s24_metrics_table.json") as f:
68-
table_json = json.load(f)
69-
metric_columns = [
70-
c["id"]
71-
for c in table_json["columns"]
72-
if c["id"] not in ("MLIP", "Score", "Rank", "id")
73-
]
74-
75-
metric_weights = build_weight_components(
76-
header="Metric weights",
77-
columns=metric_columns,
78-
input_ids=[f"{BENCHMARK_NAME}-{c.replace(' ', '-')}" for c in metric_columns],
79-
table_id=f"{BENCHMARK_NAME}-table",
80-
)
81-
8263
return S24App(
8364
name=BENCHMARK_NAME,
8465
description=(
@@ -88,7 +69,6 @@ def get_app() -> S24App:
8869
docs_url=DOCS_URL,
8970
table_path=DATA_PATH / "s24_metrics_table.json",
9071
extra_components=[
91-
metric_weights,
9272
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
9373
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
9474
],

ml_peg/app/surfaces/elemental_slab_oxygen_adsorption/app_elemental_slab_oxygen_adsorption.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
6-
75
from dash import Dash
86
from dash.html import Div
97

@@ -13,7 +11,6 @@
1311
plot_from_table_column,
1412
struct_from_scatter,
1513
)
16-
from ml_peg.app.utils.build_components import build_weight_components
1714
from ml_peg.app.utils.load import read_plot
1815
from ml_peg.calcs.models.models import MODELS
1916

@@ -63,22 +60,6 @@ def get_app() -> ElementalSlabOxygenAdsorptionApp:
6360
ElementalSlabOxygenAdsorptionApp
6461
Benchmark layout and callback registration.
6562
"""
66-
# Build metric weight components (sliders + inputs) for metrics
67-
with open(DATA_PATH / "elemental_slab_oxygen_adsorption_metrics_table.json") as f:
68-
table_json = json.load(f)
69-
metric_columns = [
70-
c["id"]
71-
for c in table_json["columns"]
72-
if c["id"] not in ("MLIP", "Score", "Rank", "id")
73-
]
74-
75-
metric_weights = build_weight_components(
76-
header="Metric weights",
77-
columns=metric_columns,
78-
input_ids=[f"{BENCHMARK_NAME}-{c.replace(' ', '-')}" for c in metric_columns],
79-
table_id=f"{BENCHMARK_NAME}-table",
80-
)
81-
8263
return ElementalSlabOxygenAdsorptionApp(
8364
name=BENCHMARK_NAME,
8465
description=(
@@ -88,7 +69,6 @@ def get_app() -> ElementalSlabOxygenAdsorptionApp:
8869
docs_url=DOCS_URL,
8970
table_path=DATA_PATH / "elemental_slab_oxygen_adsorption_metrics_table.json",
9071
extra_components=[
91-
metric_weights,
9272
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
9373
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
9474
],

ml_peg/app/utils/build_components.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,57 @@ def build_test_layout(
205205
layout_contents.extend(extra_components)
206206

207207
return Div(layout_contents)
208+
209+
210+
def build_metric_weight_components(
211+
table: DataTable, header: str = "Metric weights"
212+
) -> Div:
213+
"""
214+
Build metric-weight sliders and inputs for a benchmark table.
215+
216+
Parameters
217+
----------
218+
table
219+
Benchmark results DataTable.
220+
header
221+
Header label shown above the sliders. Default is "Metric weights".
222+
223+
Returns
224+
-------
225+
Div
226+
Div containing sliders, inputs, reset button and weight store.
227+
"""
228+
# Identify metric columns (exclude reserved columns)
229+
reserved = {"MLIP", "Score", "Rank", "id"}
230+
metric_columns = [
231+
col["id"] for col in table.columns if col.get("id") not in reserved
232+
]
233+
234+
if not metric_columns:
235+
return Div()
236+
237+
# Use table id to generate unique input ids
238+
def _safe(col: str) -> str:
239+
"""
240+
Make a safe suffix by replacing spaces with hyphens.
241+
242+
Parameters
243+
----------
244+
col
245+
Column name to sanitise.
246+
247+
Returns
248+
-------
249+
str
250+
Sanitised column suffix.
251+
"""
252+
return col.replace(" ", "-")
253+
254+
input_ids = [f"{table.id}-{_safe(col)}" for col in metric_columns]
255+
256+
return build_weight_components(
257+
header=header,
258+
columns=metric_columns,
259+
input_ids=input_ids,
260+
table_id=table.id,
261+
)

0 commit comments

Comments
 (0)