Skip to content

Commit 292e0c0

Browse files
xiangchenjhuXiang Chen
andauthored
Separate the filtering process out from task (#80)
* add seperate filter task * add filter class and implement * filter code update from mac * update test and remove unused variable * adjust comment of function * update based on PR comments * remove unnecessary filter Class --------- Co-authored-by: Xiang Chen <xchen286@ssec01.idies.jhu.edu>
1 parent 4ff0ae5 commit 292e0c0

File tree

5 files changed

+150
-59
lines changed

5 files changed

+150
-59
lines changed

bluephos/bluephos_pipeline.py

Lines changed: 76 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
from bluephos.tasks.optimizegeometries import OptimizeGeometriesTask
1111
from bluephos.tasks.smiles2sdf import Smiles2SDFTask
1212
from bluephos.tasks.dft import DFTTask
13+
from bluephos.tasks.filter_pipeline import (
14+
FilterNNInTask,
15+
FilterNNOutTask,
16+
FilterXTBInTask,
17+
FilterXTBOutTask,
18+
FilterDFTInTask,
19+
FilterDFTOutTask,
20+
)
1321

1422

1523
def ligand_pair_generator(halides_file, acids_file):
@@ -36,24 +44,32 @@ def ligand_pair_generator(halides_file, acids_file):
3644
yield pd.DataFrame([ligand_pair])
3745

3846

39-
def rerun_candidate_generator(input_dir, t_nn, t_ste):
47+
def check_column(df, column_name, condition_func, default=True):
48+
"""
49+
Helper function to check if a column exists and apply a condition.
50+
51+
"""
52+
if column_name not in df.columns:
53+
return pd.Series([default] * len(df), index=df.index)
54+
else:
55+
return condition_func(df[column_name])
56+
57+
58+
def rerun_candidate_generator(input_dir, t_nn, t_ste, t_dft):
4059
"""
4160
Generates candidate DataFrames from parquet files in the input directory.
4261
4362
Core Algorithm:
44-
- If the absolute value of 'z' is less than t_nn,
45-
- and 'ste' is None or its absolute value is less than t_ste,
46-
- and 'dft_energy_diff' is None,
47-
This row is then added to a new DataFrame and yielded for re-run.
48-
49-
Additional Context:
50-
1. All valid ligand pairs should already have run through the NN process and have a 'z' score.
51-
2. If a row's 'ste' is None, then it's 'dft_energy_diff' should also be None.
63+
- A row is selected for re-run if:
64+
1. 'z' is None or its absolute value is less than t_nn,
65+
2. and 'ste' is None or its absolute value is less than t_ste,
66+
3. and 'dft_energy_diff' is None.
5267
5368
Args:
5469
input_dir (str): Directory containing input parquet files.
5570
t_nn (float): Threshold for 'z' score.
5671
t_ste (float): Threshold for 'ste'.
72+
t_dft (float): (Optional) Threshold for 'dft_energy_diff' (not currently used)
5773
5874
Yields:
5975
DataFrame: A single-row DataFrame containing candidate data.
@@ -63,10 +79,9 @@ def rerun_candidate_generator(input_dir, t_nn, t_ste):
6379
df["ste"] = df["ste"].replace({None: np.nan})
6480

6581
filtered = df[
66-
(df["z"].notnull())
67-
& (df["z"].abs() < t_nn)
68-
& ((df["ste"].isnull()) | (df["ste"].abs() < t_ste))
69-
& (df["dft_energy_diff"].isna())
82+
check_column(df, "z", lambda col: col.isnull() | (col.abs() < t_nn))
83+
& check_column(df, "ste", lambda col: col.isnull() | (col.abs() < t_ste))
84+
& check_column(df, "dft_energy_diff", lambda col: col.isna())
7085
]
7186
for _, row in filtered.iterrows():
7287
yield row.to_frame().transpose()
@@ -78,15 +93,48 @@ def ligand_smiles_reader_generator(ligand_smiles):
7893
yield row.to_frame().transpose()
7994

8095

81-
def get_generator(ligand_smiles, halides, acids, input_dir, t_nn, t_ste):
96+
def get_generator(ligand_smiles, halides, acids, input_dir, t_nn, t_ste, t_dft):
8297
"""
8398
Get the appropriate generator based on the input directory presence.
8499
"""
85100
if ligand_smiles:
86101
return lambda: ligand_smiles_reader_generator(ligand_smiles)
87102
elif not input_dir:
88103
return lambda: ligand_pair_generator(halides, acids)
89-
return lambda: rerun_candidate_generator(input_dir, t_nn, t_ste)
104+
return lambda: rerun_candidate_generator(input_dir, t_nn, t_ste, t_dft)
105+
106+
107+
def build_pipeline_graph(input_dir: str, ligand_smiles: str):
108+
"""
109+
Construct the pipeline graph based on input conditions.
110+
111+
Args:
112+
input_dir (str): Directory containing input parquet files.
113+
ligand_smiles (str): Path to the ligand SMILES CSV file.
114+
115+
Returns:
116+
list: A list of task tuples representing the pipeline graph.
117+
"""
118+
full_pipeline = [
119+
(GenerateLigandTableTask, Smiles2SDFTask), # Generate ligands, then convert SMILES to SDF
120+
(Smiles2SDFTask, NNTask), # Use SMILES to run NN prediction
121+
(NNTask, FilterNNOutTask), # NN filter "out" goes to sink
122+
(NNTask, FilterNNInTask), # NN filter "in" continues to the next task
123+
(FilterNNInTask, OptimizeGeometriesTask), # Optimize geometries for filtered ligands
124+
(OptimizeGeometriesTask, FilterXTBOutTask), # XTB filter "out" goes to sink
125+
(OptimizeGeometriesTask, FilterXTBInTask), # XTB filter "in" continues to the next task
126+
(FilterXTBInTask, DFTTask), # Run DFT calculation for filtered ligands
127+
(DFTTask, FilterDFTOutTask), # DFT filter "out" goes to sink
128+
(DFTTask, FilterDFTInTask), # DFT filter "in" could be processed further
129+
]
130+
131+
if ligand_smiles:
132+
return full_pipeline[1:] # from NNTask (Case 1: Input as ligand SMILES CSV file)
133+
134+
if input_dir:
135+
return full_pipeline[2:] # from Smiles2SDFTask (Case 2: Use parquet files for rerun)
136+
137+
return full_pipeline # from GenerateLigandTableTask (Case 3: Input as halides and acids CSV files)
90138

91139

92140
def get_pipeline(
@@ -99,38 +147,19 @@ def get_pipeline(
99147
input_dir=None, # Directory containing input parquet files(rerun). Defaults to None.
100148
dft_package="orca", # DFT package to use. Defaults to "orca".
101149
xtb=True, # Enable xTb optimize geometries task. Defaults to True.
102-
t_nn=1.5, # Threshold for 'z' score. Defaults to None
103-
t_ste=1.9, # Threshold for 'ste'. Defaults to None
150+
t_nn=1.5, # Threshold for 'z' score.
151+
t_ste=1.9, # Threshold for 'ste'.
152+
t_dft=2.5, # Threshold for 'dft'.
104153
):
105154
"""
106155
Set up and return the BluePhos discovery pipeline executor
107156
Returns:
108157
RayStreamGraphExecutor: An executor for the BluePhos discovery pipeline
109158
"""
110-
steps = (
111-
[
112-
GenerateLigandTableTask,
113-
Smiles2SDFTask,
114-
NNTask,
115-
OptimizeGeometriesTask,
116-
DFTTask,
117-
]
118-
if not (input_dir or ligand_smiles) # input as halides and acids CSV files
119-
else [
120-
NNTask,
121-
OptimizeGeometriesTask,
122-
DFTTask,
123-
]
124-
if not ligand_smiles # input as parquet files
125-
else [
126-
Smiles2SDFTask,
127-
NNTask,
128-
OptimizeGeometriesTask,
129-
DFTTask,
130-
] # input as ligand smiles CSV file
131-
)
132-
generator = get_generator(ligand_smiles, halides, acids, input_dir, t_nn, t_ste)
133-
pipeline_executor = RayStreamGraphExecutor(graph=steps, generator=generator)
159+
160+
generator = get_generator(ligand_smiles, halides, acids, input_dir, t_nn, t_ste, t_dft)
161+
pipeline_graph = build_pipeline_graph(input_dir, ligand_smiles)
162+
pipeline_executor = RayStreamGraphExecutor(graph=pipeline_graph, generator=generator)
134163

135164
context_dict = {
136165
"ligand_smiles": ligand_smiles,
@@ -143,6 +172,7 @@ def get_pipeline(
143172
"xtb": xtb,
144173
"t_nn": t_nn,
145174
"t_ste": t_ste,
175+
"t_dft": t_dft,
146176
}
147177

148178
for key, value in context_dict.items():
@@ -161,7 +191,10 @@ def get_pipeline(
161191
ap.add_argument("--input_dir", required=False, help="Directory containing input parquet files")
162192
ap.add_argument("--t_nn", type=float, required=False, default=1.5, help="Threshold for 'z' score (default: 1.5)")
163193
ap.add_argument("--t_ste", type=float, required=False, default=1.9, help="Threshold for 'ste' (default: 1.9)")
164-
ap.add_argument("--no_xtb", action="store_false", dest="xtb", help="Disable xTB optimization (default: enabled)")
194+
ap.add_argument("--t_dft", type=float, required=False, default=2.5, help="Threshold for 'dft' (default: 2.5)")
195+
ap.add_argument(
196+
"--no_xtb", action="store_false", dest="xtb", default=True, help="Disable xTB optimization (default: enabled)"
197+
)
165198

166199
ap.add_argument(
167200
"--dft_package",
@@ -186,6 +219,7 @@ def get_pipeline(
186219
args.xtb,
187220
args.t_nn,
188221
args.t_ste,
222+
args.t_dft,
189223
),
190224
args,
191225
)

bluephos/tasks/dft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def get_dft_calculator(dft_package, n_cpus):
4343
# Process each row of the DataFrame to perform DFT calculations
4444
def process_dataframe(row, t_ste, dft_calculator):
4545
mol_id = row["ligand_identifier"]
46-
ste = row["ste"]
4746
energy_diff = row["dft_energy_diff"]
4847

49-
if ste is None or abs(ste) >= t_ste or energy_diff is not None:
50-
logger.info(f"Skipping DFT on molecule {mol_id} based on z or t_ste conditions.")
51-
return row
48+
# Skip DFT processing if enery_diff already existed (re-run condition)
49+
if energy_diff is not None:
50+
logger.info(f"Skipping DFT on molecule {mol_id} because dft_energy_diff existed (re-run).")
51+
return row # Return the row unchanged
5252

5353
if row["xyz"] not in ["failed", None]:
5454
base_name = row["ligand_identifier"]
@@ -97,5 +97,5 @@ def dft_run(df: pd.DataFrame, t_ste: float, dft_package: str) -> pd.DataFrame:
9797
"dft_package": "dft_package", # Either "ase" or "orca"
9898
},
9999
batch_size=1,
100-
num_cpus=32,
100+
num_cpus=8,
101101
)

bluephos/tasks/filter_pipeline.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pandas as pd
2+
from dplutils.pipeline import PipelineTask
3+
4+
5+
# Helper function to create both filter_in and filter_out tasks
6+
def filter(df: pd.DataFrame, column: str, threshold: float, filter_in=True) -> pd.DataFrame:
7+
filtered = df[column] < threshold
8+
if filter_in:
9+
return df[filtered]
10+
return df[~filtered]
11+
12+
13+
# Dynamically create filter_in and filter_out tasks for NN
14+
FilterNNInTask = PipelineTask(
15+
"filter_nn_in",
16+
filter,
17+
kwargs={"column": "z"},
18+
context_kwargs={"threshold": "t_nn"},
19+
# filter_in
20+
)
21+
22+
FilterNNOutTask = PipelineTask(
23+
"filter_nn_out",
24+
filter,
25+
kwargs={"column": "z", "filter_in": False},
26+
context_kwargs={"threshold": "t_nn"},
27+
# filter_out
28+
)
29+
30+
# Dynamically create filter_in and filter_out tasks for XTB
31+
FilterXTBInTask = PipelineTask(
32+
"filter_xtb_in",
33+
filter,
34+
kwargs={"column": "ste"},
35+
context_kwargs={"threshold": "t_ste"},
36+
# filter_in
37+
)
38+
39+
FilterXTBOutTask = PipelineTask(
40+
"filter_xtb_out",
41+
filter,
42+
kwargs={"column": "ste", "filter_in": False},
43+
context_kwargs={"threshold": "t_ste"},
44+
# filter_out
45+
)
46+
47+
# Dynamically create filter_in and filter_out tasks for DFT
48+
FilterDFTInTask = PipelineTask(
49+
"filter_dft_in",
50+
filter,
51+
kwargs={"column": "dft_energy_diff"},
52+
context_kwargs={"threshold": "t_dft"},
53+
# filter_in
54+
)
55+
56+
FilterDFTOutTask = PipelineTask(
57+
"filter_dft_out",
58+
filter,
59+
kwargs={"column": "dft_energy_diff", "filter_in": False},
60+
context_kwargs={"threshold": "t_dft"},
61+
# filter_out
62+
)

bluephos/tasks/optimizegeometries.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,16 @@ def calculate_ste(mol):
3131
return None
3232

3333

34-
def optimize(row, t_nn, xtb):
34+
def optimize(row, xtb):
3535
mol_id = row["ligand_identifier"]
36-
z = row["z"]
3736
ste = row["ste"]
3837

3938
# Log the values of z and ste for debugging
4039
logger.info(f"Processing molecule {mol_id} ...")
4140

42-
# Skip processing based on conditions
43-
if z is None or abs(z) >= t_nn or ste is not None:
44-
logger.info(f"Skipping xTB optimization on molecule {mol_id} based on z or t_ste conditions.")
41+
# Skip processing if ste already existed (re-run condition)
42+
if ste is not None:
43+
logger.info(f"Skipping xTB optimization on molecule {mol_id} because t_ste existed (re-run).")
4544
return row # Return the row unchanged
4645

4746
mol = row["structure"]
@@ -100,13 +99,13 @@ def optimize(row, t_nn, xtb):
10099
return row # Return the updated row
101100

102101

103-
def optimize_geometries(df: pd.DataFrame, t_nn: float, xtb: bool) -> pd.DataFrame:
102+
def optimize_geometries(df: pd.DataFrame, xtb: bool) -> pd.DataFrame:
104103
for col in ["xyz", "ste"]:
105104
if col not in df.columns:
106105
df[col] = None
107106

108107
# Apply the optimize function to each row
109-
df = df.apply(optimize, axis=1, t_nn=t_nn, xtb=xtb)
108+
df = df.apply(optimize, axis=1, xtb=xtb)
110109

111110
return df
112111

@@ -115,7 +114,6 @@ def optimize_geometries(df: pd.DataFrame, t_nn: float, xtb: bool) -> pd.DataFram
115114
"optimize_geometries",
116115
optimize_geometries,
117116
context_kwargs={
118-
"t_nn": "t_nn",
119117
"xtb": "xtb",
120118
},
121119
batch_size=1,

tests/test_optimizegeometries_task.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,11 @@ def test_optimize(mock_optimize_geometry, mock_octahedral_embed, setup_dataframe
3131
mock_octahedral_embed.return_value = None # Does not need to return anything
3232
mock_optimize_geometry.return_value = None # Does not need to return anything
3333

34-
# Define a mock t_nn argument
35-
mock_t_nn = 1.5 # Replace with a suitable value for t_nn
36-
3734
# Define a mock xtb argument
3835
mock_xtb = True
3936

4037
# Run optimize
41-
output_dataframe = optimize_geometries(setup_dataframe, mock_t_nn, mock_xtb)
38+
output_dataframe = optimize_geometries(setup_dataframe, mock_xtb)
4239

4340
# Check if XYZ data was added or set to failed
4441
assert output_dataframe.loc[0, "xyz"] is not None

0 commit comments

Comments
 (0)