@@ -171,13 +171,11 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
171
171
best_config .merge_train_val = False
172
172
173
173
174
- def load_static_data (config , merge_train_val = False ):
174
+ def load_static_data (config ):
175
175
"""Preload static data once for multiple trials.
176
176
177
177
Args:
178
178
config (AttributeDict): Config of the experiment.
179
- merge_train_val (bool, optional): Whether to merge the training and validation data.
180
- Defaults to False.
181
179
182
180
Returns:
183
181
dict: A dict of static data containing datasets, classes, and word_dict.
@@ -187,7 +185,7 @@ def load_static_data(config, merge_train_val=False):
187
185
test_data = config .test_file ,
188
186
val_data = config .val_file ,
189
187
val_size = config .val_size ,
190
- merge_train_val = merge_train_val ,
188
+ merge_train_val = config . merge_train_val ,
191
189
tokenize_text = "lm_weight" not in config .network_config ,
192
190
remove_no_label_data = config .remove_no_label_data ,
193
191
)
@@ -231,7 +229,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
231
229
with open (os .path .join (checkpoint_dir , "params.yml" ), "w" ) as fp :
232
230
yaml .dump (dict (best_config ), fp )
233
231
234
- data = load_static_data (best_config , merge_train_val = best_config . merge_train_val )
232
+ data = load_static_data (best_config )
235
233
236
234
if retrain :
237
235
logging .info (f"Re-training with best config: \n { best_config } " )
0 commit comments