10
10
from bluephos .tasks .optimizegeometries import OptimizeGeometriesTask
11
11
from bluephos .tasks .smiles2sdf import Smiles2SDFTask
12
12
from bluephos .tasks .dft import DFTTask
13
+ from bluephos .tasks .filter_pipeline import (
14
+ FilterNNInTask ,
15
+ FilterNNOutTask ,
16
+ FilterXTBInTask ,
17
+ FilterXTBOutTask ,
18
+ FilterDFTInTask ,
19
+ FilterDFTOutTask ,
20
+ )
13
21
14
22
15
23
def ligand_pair_generator (halides_file , acids_file ):
@@ -36,24 +44,32 @@ def ligand_pair_generator(halides_file, acids_file):
36
44
yield pd .DataFrame ([ligand_pair ])
37
45
38
46
39
- def rerun_candidate_generator (input_dir , t_nn , t_ste ):
47
def check_column(df, column_name, condition_func, default=True):
    """
    Build a boolean row mask from one column, tolerating a missing column.

    Args:
        df (DataFrame): Table whose rows are being filtered.
        column_name (str): Name of the column to test.
        condition_func (callable): Called with ``df[column_name]``; its return
            value is used as the mask when the column exists.
        default (bool): Mask value assigned to every row when the column is absent.

    Returns:
        Series: The result of ``condition_func`` if the column exists; otherwise
        a boolean Series aligned to ``df.index`` and filled with ``default``.
    """
    if column_name not in df.columns:
        # Broadcast the scalar with an explicit bool dtype: building the Series
        # from a list leaves the dtype to inference, which is not boolean when
        # the frame has no rows.
        return pd.Series(default, index=df.index, dtype=bool)
    return condition_func(df[column_name])
56
+
57
+
58
+ def rerun_candidate_generator (input_dir , t_nn , t_ste , t_dft ):
40
59
"""
41
60
Generates candidate DataFrames from parquet files in the input directory.
42
61
43
62
Core Algorithm:
44
- - If the absolute value of 'z' is less than t_nn,
45
- - and 'ste' is None or its absolute value is less than t_ste,
46
- - and 'dft_energy_diff' is None,
47
- This row is then added to a new DataFrame and yielded for re-run.
48
-
49
- Additional Context:
50
- 1. All valid ligand pairs should already have run through the NN process and have a 'z' score.
51
- 2. If a row's 'ste' is None, then it's 'dft_energy_diff' should also be None.
63
+ - A row is selected for re-run if:
64
+ 1. 'z' is None or its absolute value is less than t_nn,
65
+ 2. and 'ste' is None or its absolute value is less than t_ste,
66
+ 3. and 'dft_energy_diff' is None.
52
67
53
68
Args:
54
69
input_dir (str): Directory containing input parquet files.
55
70
t_nn (float): Threshold for 'z' score.
56
71
t_ste (float): Threshold for 'ste'.
72
+ t_dft (float): Threshold for 'dft_energy_diff'. Required by the signature, but not currently used in the filtering.
57
73
58
74
Yields:
59
75
DataFrame: A single-row DataFrame containing candidate data.
@@ -63,10 +79,9 @@ def rerun_candidate_generator(input_dir, t_nn, t_ste):
63
79
df ["ste" ] = df ["ste" ].replace ({None : np .nan })
64
80
65
81
filtered = df [
66
- (df ["z" ].notnull ())
67
- & (df ["z" ].abs () < t_nn )
68
- & ((df ["ste" ].isnull ()) | (df ["ste" ].abs () < t_ste ))
69
- & (df ["dft_energy_diff" ].isna ())
82
+ check_column (df , "z" , lambda col : col .isnull () | (col .abs () < t_nn ))
83
+ & check_column (df , "ste" , lambda col : col .isnull () | (col .abs () < t_ste ))
84
+ & check_column (df , "dft_energy_diff" , lambda col : col .isna ())
70
85
]
71
86
for _ , row in filtered .iterrows ():
72
87
yield row .to_frame ().transpose ()
@@ -78,15 +93,48 @@ def ligand_smiles_reader_generator(ligand_smiles):
78
93
yield row .to_frame ().transpose ()
79
94
80
95
81
def get_generator(ligand_smiles, halides, acids, input_dir, t_nn, t_ste, t_dft=2.5):
    """
    Pick the candidate generator factory that matches the supplied inputs.

    Precedence:
        1. ``ligand_smiles`` is truthy -> read ligands from that source.
        2. ``input_dir`` is falsy      -> build ligand pairs from halides and acids.
        3. otherwise                   -> re-run candidates found in ``input_dir``.

    Args:
        ligand_smiles: Ligand SMILES input; takes priority over everything else.
        halides: Halides input, used only in case 2.
        acids: Acids input, used only in case 2.
        input_dir: Directory of previous results, used only in case 3.
        t_nn (float): Threshold forwarded to the rerun generator.
        t_ste (float): Threshold forwarded to the rerun generator.
        t_dft (float): Threshold forwarded to the rerun generator. Defaults to
            2.5 (the same default get_pipeline uses) so that callers written
            against the earlier six-argument signature keep working.

    Returns:
        callable: A zero-argument function that creates the chosen generator
        each time it is called.
    """
    if ligand_smiles:
        return lambda: ligand_smiles_reader_generator(ligand_smiles)
    if not input_dir:
        return lambda: ligand_pair_generator(halides, acids)
    return lambda: rerun_candidate_generator(input_dir, t_nn, t_ste, t_dft)
105
+
106
+
107
def build_pipeline_graph(input_dir: str, ligand_smiles: str):
    """
    Construct the pipeline graph based on input conditions.

    The graph is a list of ``(upstream_task, downstream_task)`` edges. Which
    edges are included depends on where the input data enters the pipeline:

    1. ``ligand_smiles`` given: start at Smiles2SDFTask.
    2. ``input_dir`` given (rerun from parquet files): start at NNTask.
    3. neither given (halides and acids CSV files): start at
       GenerateLigandTableTask, i.e. the full pipeline.

    ``ligand_smiles`` takes precedence when both are supplied. Any falsy value
    (``None`` or an empty string) counts as "not given".

    Args:
        input_dir (str): Directory containing input parquet files.
        ligand_smiles (str): Path to the ligand SMILES CSV file.

    Returns:
        list: A list of task tuples representing the pipeline graph.
    """
    # Edges from NN prediction onward; every entry point shares these.
    # Stages are named rather than sliced by index so that adding an edge
    # cannot silently shift where each case starts.
    nn_onward = [
        (NNTask, FilterNNOutTask),  # NN filter "out" goes to sink
        (NNTask, FilterNNInTask),  # NN filter "in" continues to the next task
        (FilterNNInTask, OptimizeGeometriesTask),  # Optimize geometries for filtered ligands
        (OptimizeGeometriesTask, FilterXTBOutTask),  # XTB filter "out" goes to sink
        (OptimizeGeometriesTask, FilterXTBInTask),  # XTB filter "in" continues to the next task
        (FilterXTBInTask, DFTTask),  # Run DFT calculation for filtered ligands
        (DFTTask, FilterDFTOutTask),  # DFT filter "out" goes to sink
        (DFTTask, FilterDFTInTask),  # DFT filter "in" could be processed further
    ]
    # SMILES input still needs conversion to SDF before NN prediction.
    smiles_onward = [(Smiles2SDFTask, NNTask)] + nn_onward

    if ligand_smiles:
        return smiles_onward  # Case 1: input as ligand SMILES CSV file

    if input_dir:
        return nn_onward  # Case 2: use parquet files for rerun

    # Case 3: input as halides and acids CSV files; generate the ligand table first.
    return [(GenerateLigandTableTask, Smiles2SDFTask)] + smiles_onward
90
138
91
139
92
140
def get_pipeline (
@@ -99,38 +147,19 @@ def get_pipeline(
99
147
input_dir = None , # Directory containing input parquet files(rerun). Defaults to None.
100
148
dft_package = "orca" , # DFT package to use. Defaults to "orca".
101
149
xtb = True , # Enable xTb optimize geometries task. Defaults to True.
102
- t_nn = 1.5 , # Threshold for 'z' score. Defaults to None
103
- t_ste = 1.9 , # Threshold for 'ste'. Defaults to None
150
+ t_nn = 1.5 , # Threshold for 'z' score.
151
+ t_ste = 1.9 , # Threshold for 'ste'.
152
+ t_dft = 2.5 , # Threshold for 'dft'.
104
153
):
105
154
"""
106
155
Set up and return the BluePhos discovery pipeline executor
107
156
Returns:
108
157
RayStreamGraphExecutor: An executor for the BluePhos discovery pipeline
109
158
"""
110
- steps = (
111
- [
112
- GenerateLigandTableTask ,
113
- Smiles2SDFTask ,
114
- NNTask ,
115
- OptimizeGeometriesTask ,
116
- DFTTask ,
117
- ]
118
- if not (input_dir or ligand_smiles ) # input as halides and acids CSV files
119
- else [
120
- NNTask ,
121
- OptimizeGeometriesTask ,
122
- DFTTask ,
123
- ]
124
- if not ligand_smiles # input as parquet files
125
- else [
126
- Smiles2SDFTask ,
127
- NNTask ,
128
- OptimizeGeometriesTask ,
129
- DFTTask ,
130
- ] # input as ligand smiles CSV file
131
- )
132
- generator = get_generator (ligand_smiles , halides , acids , input_dir , t_nn , t_ste )
133
- pipeline_executor = RayStreamGraphExecutor (graph = steps , generator = generator )
159
+
160
+ generator = get_generator (ligand_smiles , halides , acids , input_dir , t_nn , t_ste , t_dft )
161
+ pipeline_graph = build_pipeline_graph (input_dir , ligand_smiles )
162
+ pipeline_executor = RayStreamGraphExecutor (graph = pipeline_graph , generator = generator )
134
163
135
164
context_dict = {
136
165
"ligand_smiles" : ligand_smiles ,
@@ -143,6 +172,7 @@ def get_pipeline(
143
172
"xtb" : xtb ,
144
173
"t_nn" : t_nn ,
145
174
"t_ste" : t_ste ,
175
+ "t_dft" : t_dft ,
146
176
}
147
177
148
178
for key , value in context_dict .items ():
@@ -161,7 +191,10 @@ def get_pipeline(
161
191
ap .add_argument ("--input_dir" , required = False , help = "Directory containing input parquet files" )
162
192
ap .add_argument ("--t_nn" , type = float , required = False , default = 1.5 , help = "Threshold for 'z' score (default: 1.5)" )
163
193
ap .add_argument ("--t_ste" , type = float , required = False , default = 1.9 , help = "Threshold for 'ste' (default: 1.9)" )
164
- ap .add_argument ("--no_xtb" , action = "store_false" , dest = "xtb" , help = "Disable xTB optimization (default: enabled)" )
194
+ ap .add_argument ("--t_dft" , type = float , required = False , default = 2.5 , help = "Threshold for 'dft' (default: 2.5)" )
195
+ ap .add_argument (
196
+ "--no_xtb" , action = "store_false" , dest = "xtb" , default = True , help = "Disable xTB optimization (default: enabled)"
197
+ )
165
198
166
199
ap .add_argument (
167
200
"--dft_package" ,
@@ -186,6 +219,7 @@ def get_pipeline(
186
219
args .xtb ,
187
220
args .t_nn ,
188
221
args .t_ste ,
222
+ args .t_dft ,
189
223
),
190
224
args ,
191
225
)
0 commit comments