diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py
index a4ac82c2..aa961d78 100644
--- a/libmultilabel/nn/nn_utils.py
+++ b/libmultilabel/nn/nn_utils.py
@@ -131,6 +131,7 @@ def init_trainer(
     limit_val_batches=1.0,
     limit_test_batches=1.0,
     save_checkpoints=True,
+    is_tune_mode=False,
 ):
     """Initialize a torch lightning trainer.

@@ -146,6 +147,7 @@
         limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
         limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
         save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
+        is_tune_mode (bool): Whether parameter search is running or not. Defaults to False.

     Returns:
         lightning.trainer: A torch lightning trainer.
@@ -163,7 +165,19 @@
         strict=False,
     )
     callbacks = [early_stopping_callback]
-    if save_checkpoints:
+
+    if is_tune_mode:
+        callbacks += [
+            ModelCheckpoint(
+                dirpath=checkpoint_dir,
+                filename="best_model",
+                save_top_k=1,
+                save_weights_only=True,
+                monitor=val_metric,
+                mode="min" if val_metric == "Loss" else "max",
+            )
+        ]
+    elif save_checkpoints:
         callbacks += [
             ModelCheckpoint(
                 dirpath=checkpoint_dir,
diff --git a/search_params.py b/search_params.py
index 545a6530..4c2dad4c 100644
--- a/search_params.py
+++ b/search_params.py
@@ -42,7 +42,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
         datasets=datasets,
         classes=classes,
         word_dict=word_dict,
-        save_checkpoints=True,
+        save_checkpoints=False,
+        is_tune_mode=True,
     )
     val_score = trainer.train()
     return {f"val_{config.val_metric}": val_score}
diff --git a/torch_trainer.py b/torch_trainer.py
index 8dc259b5..737db0b9 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -33,6 +33,7 @@ def __init__(
         word_dict: dict = None,
         embed_vecs=None,
         save_checkpoints: bool = True,
+        is_tune_mode: bool = False,
     ):
         self.run_name = config.run_name
         self.checkpoint_dir = config.checkpoint_dir
@@ -119,6 +120,7 @@ def __init__(
             limit_val_batches=config.limit_val_batches,
             limit_test_batches=config.limit_test_batches,
             save_checkpoints=save_checkpoints,
+            is_tune_mode=is_tune_mode,
         )
         callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
         self.checkpoint_callback = callbacks[0] if callbacks else None
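
For reference, a minimal usage sketch of the new flag. Only the arguments that appear in this patch (`checkpoint_dir`, `val_metric`, `save_checkpoints`, `is_tune_mode`) are passed; any other `init_trainer` arguments are assumed to keep their defaults, and the directory and metric name below are placeholders.

```python
# Hypothetical caller sketch, not part of this patch: opting into
# tune-mode checkpointing. Path and metric name are placeholders.
from libmultilabel.nn.nn_utils import init_trainer

trainer = init_trainer(
    checkpoint_dir="runs/tune_trial_0",  # placeholder directory
    val_metric="Micro-F1",               # placeholder metric name
    save_checkpoints=False,              # mirrors the search_params.py change
    is_tune_mode=True,                   # checked before save_checkpoints
)
# In tune mode a single ModelCheckpoint is attached: save_top_k=1 and
# save_weights_only=True, monitoring val_metric with mode "min" when the
# metric is "Loss" and "max" otherwise; the "last" checkpoint that a
# normal save_checkpoints run would also write is skipped.
```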