Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion libmultilabel/nn/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def init_trainer(
limit_val_batches=1.0,
limit_test_batches=1.0,
save_checkpoints=True,
is_tune_mode=False,
):
"""Initialize a torch lightning trainer.

Expand All @@ -146,6 +147,7 @@ def init_trainer(
limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
        is_tune_mode (bool): Whether a hyperparameter search (tune) is running. Defaults to False.

Returns:
lightning.trainer: A torch lightning trainer.
Expand All @@ -163,7 +165,19 @@ def init_trainer(
strict=False,
)
callbacks = [early_stopping_callback]
if save_checkpoints:

if is_tune_mode:
callbacks += [
ModelCheckpoint(
dirpath=checkpoint_dir,
filename="best_model",
save_top_k=1,
save_weights_only=True,
monitor=val_metric,
mode="min" if val_metric == "Loss" else "max",
)
]
elif save_checkpoints:
callbacks += [
ModelCheckpoint(
dirpath=checkpoint_dir,
Expand Down
3 changes: 2 additions & 1 deletion search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
datasets=datasets,
classes=classes,
word_dict=word_dict,
save_checkpoints=True,
save_checkpoints=False,
is_tune_mode=True,
)
val_score = trainer.train()
return {f"val_{config.val_metric}": val_score}
Expand Down
2 changes: 2 additions & 0 deletions torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
word_dict: dict = None,
embed_vecs=None,
save_checkpoints: bool = True,
is_tune_mode: bool = False,
):
self.run_name = config.run_name
self.checkpoint_dir = config.checkpoint_dir
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
limit_val_batches=config.limit_val_batches,
limit_test_batches=config.limit_test_batches,
save_checkpoints=save_checkpoints,
is_tune_mode=is_tune_mode,
)
callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
self.checkpoint_callback = callbacks[0] if callbacks else None
Expand Down