From 56ddfda921c1d3d4c722a2c2562da2f113e30272 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Pinard?= Date: Fri, 4 Aug 2023 19:18:27 +0200 Subject: [PATCH] Add type hints, 3.8 style --- .../faster_rcnn/evaluation/coco_evaluator.py | 4 +- pytorch_accelerated/callbacks.py | 156 +++++++------ pytorch_accelerated/finetuning.py | 47 ++-- .../schedulers/cosine_scheduler.py | 10 +- pytorch_accelerated/tracking.py | 28 +-- pytorch_accelerated/trainer.py | 208 ++++++++++-------- 6 files changed, 259 insertions(+), 194 deletions(-) diff --git a/examples/vision/faster_rcnn/evaluation/coco_evaluator.py b/examples/vision/faster_rcnn/evaluation/coco_evaluator.py index 1748302..e38fea5 100644 --- a/examples/vision/faster_rcnn/evaluation/coco_evaluator.py +++ b/examples/vision/faster_rcnn/evaluation/coco_evaluator.py @@ -61,7 +61,7 @@ def __init__(self, iou_threshold: float = None, verbose=False): self.verbose = verbose self.silencer = Silencer() - def compute(self, targets_json: dict, predictions_json: List[dict]): + def compute(self, targets_json: dict, predictions_json: List[Dict]): """ Calculate mAP from COCO-formatted dictionaries containing predictions and targets. @@ -150,7 +150,7 @@ def create_targets_coco_json_from_df( @classmethod def create_predictions_coco_json_from_df( cls, predictions_df: pd.DataFrame - ) -> List[dict]: + ) -> List[Dict]: """ Create a COCO-formatted predictions list of dictionaries that can be used for evaluation. diff --git a/pytorch_accelerated/callbacks.py b/pytorch_accelerated/callbacks.py index 290fcbb..0f65e4a 100644 --- a/pytorch_accelerated/callbacks.py +++ b/pytorch_accelerated/callbacks.py @@ -9,6 +9,10 @@ import numpy as np import torch from pytorch_accelerated.utils import ModelEma +from typing import TYPE_CHECKING, Optional, List, Any, Dict, Callable, Iterable, Type + +if TYPE_CHECKING: + from pytorch_accelerated import Trainer from torch import nn from tqdm import tqdm @@ -58,25 +62,25 @@ def on_init_end(self, trainer, **kwargs): """ pass - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): """ Event called at the start of training run. """ pass - def on_train_epoch_start(self, trainer, **kwargs): + def on_train_epoch_start(self, trainer: "Trainer", **kwargs): """ Event called at the beginning of a training epoch. """ pass - def on_train_step_start(self, trainer, **kwargs): + def on_train_step_start(self, trainer: "Trainer", **kwargs): """ Event called at the beginning of a training step. """ pass - def on_train_step_end(self, trainer, batch, batch_output, **kwargs): + def on_train_step_end(self, trainer: "Trainer", batch, batch_output, **kwargs): """ Event called at the end of a training step. @@ -85,25 +89,25 @@ def on_train_step_end(self, trainer, batch, batch_output, **kwargs): """ pass - def on_train_epoch_end(self, trainer, **kwargs): + def on_train_epoch_end(self, trainer: "Trainer", **kwargs): """ Event called at the end of a training epoch. """ pass - def on_eval_epoch_start(self, trainer, **kwargs): + def on_eval_epoch_start(self, trainer: "Trainer", **kwargs): """ Event called at the beginning of an evaluation epoch. """ pass - def on_eval_step_start(self, trainer, **kwargs): + def on_eval_step_start(self, trainer: "Trainer", **kwargs): """ Event called at the beginning of a evaluation step. 
""" pass - def on_eval_step_end(self, trainer, batch, batch_output, **kwargs): + def on_eval_step_end(self, trainer: "Trainer", batch, batch_output, **kwargs): """ Event called at the end of an evaluation step. @@ -112,43 +116,43 @@ def on_eval_step_end(self, trainer, batch, batch_output, **kwargs): """ pass - def on_eval_epoch_end(self, trainer, **kwargs): + def on_eval_epoch_end(self, trainer: "Trainer", **kwargs): """ Event called at the end of evaluation. """ pass - def on_training_run_epoch_end(self, trainer, **kwargs): + def on_training_run_epoch_end(self, trainer: "Trainer", **kwargs): """ Event called during a training run after both training and evaluation epochs have been completed. """ pass - def on_training_run_end(self, trainer, **kwargs): + def on_training_run_end(self, trainer: "Trainer", **kwargs): """ Event called at the end of training run. """ pass - def on_evaluation_run_start(self, trainer, **kwargs): + def on_evaluation_run_start(self, trainer: "Trainer", **kwargs): """ Event called at the start of an evaluation run. """ pass - def on_evaluation_run_end(self, trainer, **kwargs): + def on_evaluation_run_end(self, trainer: "Trainer", **kwargs): """ Event called at the end of an evaluation run. """ pass - def on_stop_training_error(self, trainer, **kwargs): + def on_stop_training_error(self, trainer: "Trainer", **kwargs): """ Event called when a stop training error is raised """ pass - def __getattr__(self, item): + def __getattr__(self, item) -> Callable[..., None]: try: return super().__getattr__(item) except AttributeError: @@ -161,12 +165,12 @@ class CallbackHandler: This class calls the callbacks in the order that they are given. """ - def __init__(self, callbacks): - self.callbacks = [] + def __init__(self, callbacks: Iterable[TrainerCallback]): + self.callbacks: List[TrainerCallback] = [] self.add_callbacks(callbacks) self._enabled = True - def add_callbacks(self, callbacks): + def add_callbacks(self, callbacks: Iterable[TrainerCallback]): """ Add a list of callbacks to the callback handler @@ -175,7 +179,7 @@ def add_callbacks(self, callbacks): for cb in callbacks: self.add_callback(cb) - def add_callback(self, callback): + def add_callback(self, callback: TrainerCallback): """ Add a callbacks to the callback handler @@ -187,22 +191,23 @@ def add_callback(self, callback): existing_callbacks = "\n".join(cb for cb in self.callback_list) raise ValueError( - f"You attempted to add multiple instances of the callback {cb_class} to a single Trainer" - f" The list of callbacks already present is\n: {existing_callbacks}" + f"You attempted to add multiple instances of the callback {cb_class} to" + " a single Trainer The list of callbacks already present is\n:" + f" {existing_callbacks}" ) self.callbacks.append(cb) - def __iter__(self): + def __iter__(self) -> Iterable[TrainerCallback]: return self.callbacks - def clear_callbacks(self): + def clear_callbacks(self) -> None: self.callbacks = [] @property - def callback_list(self): + def callback_list(self) -> List[str]: return [cb.__class__.__name__ for cb in self.callbacks] - def call_event(self, event, *args, **kwargs): + def call_event(self, event: str, *args, **kwargs): """ For each callback which has been registered, sequentially call the method corresponding to the given event. 
@@ -214,7 +219,8 @@ def call_event(self, event, *args, **kwargs): if self._enabled: for callback in self.callbacks: try: - getattr(callback, event)( + event_function = getattr(callback, event) + event_function( *args, **kwargs, ) @@ -233,7 +239,7 @@ class LogMetricsCallback(TrainerCallback): This can be subclassed to create loggers for different platforms by overriding the :meth:`~LogMetricsCallback.log_metrics` method. """ - def on_train_epoch_end(self, trainer, **kwargs): + def on_train_epoch_end(self, trainer: "Trainer", **kwargs): metric_names = [ metric for metric in trainer.run_history.get_metric_names() @@ -242,7 +248,7 @@ def on_train_epoch_end(self, trainer, **kwargs): self._log_latest_metrics(trainer, metric_names) - def on_eval_epoch_end(self, trainer, **kwargs): + def on_eval_epoch_end(self, trainer: "Trainer", **kwargs): metric_names = [ metric for metric in trainer.run_history.get_metric_names() @@ -250,17 +256,17 @@ def on_eval_epoch_end(self, trainer, **kwargs): ] self._log_latest_metrics(trainer, metric_names) - def _log_latest_metrics(self, trainer, metric_names): + def _log_latest_metrics(self, trainer: "Trainer", metric_names: Iterable[str]): latest_metrics = self._get_latest_metrics(trainer, metric_names) self.log_metrics(trainer, latest_metrics) - def _get_latest_metrics(self, trainer, metric_names): + def _get_latest_metrics(self, trainer: "Trainer", metric_names: Iterable[str]): return { metric_name: trainer.run_history.get_latest_metric(metric_name) for metric_name in metric_names } - def log_metrics(self, trainer, metrics: dict): + def log_metrics(self, trainer: "Trainer", metrics: Dict[str, Any]): for metric_name, metric_value in metrics.items(): trainer.print(f"\n{metric_name}: {metric_value}") @@ -273,29 +279,29 @@ class ProgressBarCallback(TrainerCallback): def __init__(self): self.pbar = None - def on_train_epoch_start(self, trainer, **kwargs): + def on_train_epoch_start(self, trainer: "Trainer", **kwargs): self.pbar = tqdm( total=len(trainer._train_dataloader), disable=not trainer._accelerator.is_local_main_process, ) - def on_train_step_end(self, trainer, **kwargs): + def on_train_step_end(self, trainer: "Trainer", **kwargs): self.pbar.update(1) - def on_train_epoch_end(self, trainer, **kwargs): + def on_train_epoch_end(self, trainer: "Trainer", **kwargs): self.pbar.close() time.sleep(0.01) - def on_eval_epoch_start(self, trainer, **kwargs): + def on_eval_epoch_start(self, trainer: "Trainer", **kwargs): self.pbar = tqdm( total=len(trainer._eval_dataloader), disable=not trainer._accelerator.is_local_main_process, ) - def on_eval_step_end(self, trainer, **kwargs): + def on_eval_step_end(self, trainer: "Trainer", **kwargs): self.pbar.update(1) - def on_eval_epoch_end(self, trainer, **kwargs): + def on_eval_epoch_end(self, trainer: "Trainer", **kwargs): self.pbar.close() time.sleep(0.01) @@ -306,20 +312,20 @@ class PrintProgressCallback(TrainerCallback): as well as at the start of each epoch.
""" - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): trainer.print("\nStarting training run") - def on_train_epoch_start(self, trainer, **kwargs): + def on_train_epoch_start(self, trainer: "Trainer", **kwargs): trainer.print(f"\nStarting epoch {trainer.run_history.current_epoch}") time.sleep(0.01) - def on_training_run_end(self, trainer, **kwargs): + def on_training_run_end(self, trainer: "Trainer", **kwargs): trainer.print("Finishing training run") - def on_evaluation_run_start(self, trainer, **kwargs): + def on_evaluation_run_start(self, trainer: "Trainer", **kwargs): trainer.print("\nStarting evaluation run") - def on_evaluation_run_end(self, trainer, **kwargs): + def on_evaluation_run_end(self, trainer: "Trainer", **kwargs): trainer.print("Finishing evaluation run") @@ -331,8 +337,8 @@ class SaveBestModelCallback(TrainerCallback): def __init__( self, - save_path="best_model.pt", - watch_metric="eval_loss_epoch", + save_path: str = "best_model.pt", + watch_metric: str = "eval_loss_epoch", greater_is_better: bool = False, reset_on_train: bool = True, save_optimizer: bool = True, @@ -358,7 +364,7 @@ def on_training_run_start(self, args, **kwargs): if self.reset_on_train: self.best_metric = None - def on_training_run_epoch_end(self, trainer, **kwargs): + def on_training_run_epoch_end(self, trainer: "Trainer", **kwargs): current_metric = trainer.run_history.get_latest_metric(self.watch_metric) if self.best_metric is None: self.best_metric = current_metric @@ -380,9 +386,10 @@ def on_training_run_epoch_end(self, trainer, **kwargs): save_optimizer=self.save_optimizer, ) - def on_training_run_end(self, trainer, **kwargs): + def on_training_run_end(self, trainer: "Trainer", **kwargs): trainer.print( - f"Loading checkpoint with {self.watch_metric}: {self.best_metric} from epoch {self.best_metric_epoch}" + f"Loading checkpoint with {self.watch_metric}: {self.best_metric} from" + f" epoch {self.best_metric_epoch}" ) trainer.load_checkpoint(self.save_path) @@ -396,7 +403,7 @@ def __init__( self, early_stopping_patience: int = 1, early_stopping_threshold: float = 0.01, - watch_metric="eval_loss_epoch", + watch_metric: str = "eval_loss_epoch", greater_is_better: bool = False, reset_on_train: bool = True, ): @@ -422,7 +429,7 @@ def on_training_run_start(self, args, **kwargs): self.best_metric = None self.early_stopping_patience_counter = 0 - def on_training_run_epoch_end(self, trainer, **kwargs): + def on_training_run_epoch_end(self, trainer: "Trainer", **kwargs): current_metric = trainer.run_history.get_latest_metric(self.watch_metric) if self.best_metric is None: self.best_metric = current_metric @@ -447,12 +454,14 @@ def on_training_run_epoch_end(self, trainer, **kwargs): if self.early_stopping_patience_counter >= self.early_stopping_patience: raise StopTrainingError( - f"Stopping training due to no improvement after {self.early_stopping_patience} epochs" + "Stopping training due to no improvement after" + f" {self.early_stopping_patience} epochs" ) - def __print_counter_status(self, trainer): + def __print_counter_status(self, trainer: "Trainer"): trainer.print( - f"Early stopping counter: {self.early_stopping_patience_counter}/{self.early_stopping_patience}" + "Early stopping counter:" + f" {self.early_stopping_patience_counter}/{self.early_stopping_patience}" ) @@ -465,20 +474,20 @@ class TerminateOnNaNCallback(TrainerCallback): def __init__(self): self.triggered = False - def check_for_nan_after_batch(self, batch_output, 
step=None): + def check_for_nan_after_batch(self, batch_output, step: Optional[str] = None): """Test if loss is NaN and interrupts training.""" loss = batch_output["loss"] if torch.isinf(loss).any() or torch.isnan(loss).any(): self.triggered = True raise StopTrainingError(f"Stopping training due to NaN loss in {step} step") - def on_train_step_end(self, trainer, batch_output, **kwargs): + def on_train_step_end(self, trainer: "Trainer", batch_output, **kwargs): self.check_for_nan_after_batch(batch_output, step="training") - def on_eval_step_end(self, trainer, batch_output, **kwargs): + def on_eval_step_end(self, trainer: "Trainer", batch_output, **kwargs): self.check_for_nan_after_batch(batch_output, step="validation") - def on_training_run_end(self, trainer, **kwargs): + def on_training_run_end(self, trainer: "Trainer", **kwargs): if self.triggered: sys.exit("Exiting due to NaN loss") @@ -494,25 +503,25 @@ class MoveModulesToDeviceCallback(TrainerCallback): """ - def _get_modules(self, trainer): + def _get_modules(self, trainer: "Trainer"): return inspect.getmembers(trainer, lambda x: isinstance(x, nn.Module)) - def _move_modules_to_device(self, trainer): + def _move_modules_to_device(self, trainer: "Trainer"): modules = self._get_modules(trainer) for module_name, module in modules: if module_name != "model": module.to(trainer.device) - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): self._move_modules_to_device(trainer) - def on_evaluation_run_start(self, trainer, **kwargs): + def on_evaluation_run_start(self, trainer: "Trainer", **kwargs): self._move_modules_to_device(trainer) class DataLoaderSlice: - def __init__(self, dl, slice_size): + def __init__(self, dl, slice_size: int): self.dl = dl self.slice_size = slice_size @@ -528,10 +537,10 @@ class LimitBatchesCallback(TrainerCallback): A callback that that limits the number of batches used during training and evaluation """ - def __init__(self, num_batches): + def __init__(self, num_batches: int): self.num_batches = num_batches - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): trainer._train_dataloader = DataLoaderSlice( trainer._train_dataloader, self.num_batches ) @@ -539,7 +548,7 @@ def on_training_run_start(self, trainer, **kwargs): trainer._eval_dataloader, self.num_batches ) - def on_evaluation_run_start(self, trainer, **kwargs): + def on_evaluation_run_start(self, trainer: "Trainer", **kwargs): trainer._eval_dataloader = DataLoaderSlice( trainer._eval_dataloader, self.num_batches ) @@ -567,8 +576,8 @@ def __init__( save_path: str = "ema_model.pt", watch_metric: str = "ema_model_eval_loss_epoch", greater_is_better: bool = False, - model_ema=ModelEma, - callbacks=(), + model_ema: Type[nn.Module] = ModelEma, + callbacks: Iterable[TrainerCallback] = (), ): """ :param decay: the amount of decay to use, which determines how much of the previous state will be maintained. 
@@ -593,19 +602,21 @@ def __init__( self.model_ema_cls = model_ema self.callback_handler = CallbackHandler(callbacks) - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): self.ema_model = self.model_ema_cls( trainer._accelerator.unwrap_model(trainer.model), decay=self.decay ) if self.evaluate_during_training: self.ema_model.to(trainer.device) - def on_train_epoch_end(self, trainer, **kwargs): + def on_train_epoch_end(self, trainer: "Trainer", **kwargs): + assert self.ema_model is not None self.ema_model.update(trainer._accelerator.unwrap_model(trainer.model)) - def on_eval_epoch_end(self, trainer, **kwargs): + def on_eval_epoch_end(self, trainer: "Trainer", **kwargs): if self.evaluate_during_training: model = trainer.model + assert self.ema_model is not None trainer.model = self.ema_model.module run_history_prefix = trainer.run_history.metric_name_prefix trainer_callback_handler = trainer.callback_handler @@ -620,8 +631,9 @@ def on_eval_epoch_end(self, trainer, **kwargs): trainer.callback_handler = trainer_callback_handler trainer.run_history.set_metric_name_prefix(run_history_prefix) - def on_training_run_epoch_end(self, trainer, **kwargs): + def on_training_run_epoch_end(self, trainer: "Trainer", **kwargs): model = trainer.model + assert self.ema_model is not None trainer.model = self.ema_model.module if self.evaluate_during_training: @@ -631,7 +643,7 @@ def on_training_run_epoch_end(self, trainer, **kwargs): trainer.model = model - def on_training_run_end(self, trainer, **kwargs): + def on_training_run_end(self, trainer: "Trainer", **kwargs): # Overriding, as we do not want to load the EMA model pass @@ -641,6 +653,6 @@ class ConvertSyncBatchNormCallback(TrainerCallback): A callback which converts all BatchNorm*D layers in the model to :class:`torch.nn.SyncBatchNorm` layers. """ - def on_training_run_start(self, trainer, **kwargs): + def on_training_run_start(self, trainer: "Trainer", **kwargs): if trainer.run_config.is_distributed: trainer.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(trainer.model) diff --git a/pytorch_accelerated/finetuning.py b/pytorch_accelerated/finetuning.py index b2b652c..8ecb339 100644 --- a/pytorch_accelerated/finetuning.py +++ b/pytorch_accelerated/finetuning.py @@ -1,5 +1,5 @@ from collections import namedtuple, defaultdict -from typing import List +from typing import List, Tuple, TYPE_CHECKING, Dict, Iterable import torch @@ -51,7 +51,7 @@ def forward(self, x): of Linear, BatchNorm and ReLU modules. """ - def __init__(self, model, freeze_batch_norms=False): + def __init__(self, model: torch.nn.Module, freeze_batch_norms=False): """ Create a new ModelFreezer instance, which can be used to freeze and unfreeze all, or parts, or a model. When a model is passed to a ModelFreezer instance, all parameters will be unfrozen regardless of their previous state. Subsequent freezing/unfreezing should be @@ -122,7 +122,7 @@ def get_layers(self) -> List[Layer]: return layers - def get_trainable_parameters(self): + def get_trainable_parameters(self) -> List[torch.nn.Parameter]: """ Return a list of all unfrozen model parameters, which will be updated during training. 
@@ -130,7 +130,9 @@ def get_trainable_parameters(self): """ return [param for param in self.model.parameters() if param.requires_grad] - def freeze(self, from_index=0, to_index=-2, set_modules_as_eval=False): + def freeze( + self, from_index: int = 0, to_index: int = -2, set_modules_as_eval: bool = False + ): """ Freeze layer groups corresponding to the specified indexes, which are inclusive. By default, this freezes all layer groups except the final one. @@ -143,7 +145,12 @@ def freeze(self, from_index=0, to_index=-2, set_modules_as_eval=False): from_index, to_index, freeze=True, toggle_train_eval=set_modules_as_eval ) - def unfreeze(self, from_index=-1, to_index=0, set_modules_as_training=True): + def unfreeze( + self, + from_index: int = -1, + to_index: int = 0, + set_modules_as_training: bool = True, + ): """ Unfreeze layer groups corresponding to the specified indexes, which are inclusive. By default, this unfreezes all layer groups. For each layer group, any parameters which have been unfrozen are returned, so that they can be added to an optimizer if needed. @@ -163,10 +170,10 @@ def unfreeze(self, from_index=-1, to_index=0, set_modules_as_training=True): def __freeze_unfreeze( self, - from_layer_group_index, - to_layer_group_index, - freeze=True, - toggle_train_eval=True, + from_layer_group_index: int, + to_layer_group_index: int, + freeze: bool = True, + toggle_train_eval: bool = True, ): modified_parameters = defaultdict(list) set_grad_value = not freeze @@ -201,7 +208,7 @@ def __freeze_unfreeze( for layer_group_idx, params in modified_parameters.items() } - def _convert_idxs(self, from_idx, to_idx): + def _convert_idxs(self, from_idx: int, to_idx: int) -> Tuple[int, int]: from_idx = _convert_idx(from_idx, self.num_groups) to_idx = _convert_idx(to_idx, self.num_groups) @@ -211,7 +218,9 @@ def _convert_idxs(self, from_idx, to_idx): return from_idx, to_idx -def _change_layer_state(layer: Layer, set_grad_value: bool, toggle_train_eval: bool): +def _change_layer_state( + layer: Layer, set_grad_value: bool, toggle_train_eval: bool +) -> List[torch.nn.Parameter]: params = list(layer.module.parameters()) if params: _set_requires_grad(params, value=set_grad_value) @@ -220,22 +229,24 @@ def _change_layer_state(layer: Layer, set_grad_value: bool, toggle_train_eval: b return params -def _module_is_batch_norm(module): +def _module_is_batch_norm(module: torch.nn.Module) -> bool: return isinstance(module, BN_MODULES) -def _convert_idx(idx, num_groups): +def _convert_idx(idx: int, num_groups: int) -> int: if idx < 0: idx = idx + num_groups return idx -def _set_requires_grad(parameters, value=True): +def _set_requires_grad(parameters: Iterable[torch.nn.Parameter], value: bool = True): for param in parameters: param.requires_grad = value -def _get_layer_groups_for_module(module): +def _get_layer_groups_for_module( + module: torch.nn.Module, +) -> Tuple[Dict[int, torch.nn.Module], List[Tuple[int, torch.nn.Module]]]: layers = [] layer_groups = dict() for layer_group, group in enumerate(module.children()): @@ -245,7 +256,11 @@ def _get_layer_groups_for_module(module): return layer_groups, layers -def _recursive_get_layers(module, result, layer_group=0): +def _recursive_get_layers( + module: torch.nn.Module, + result: List[Tuple[int, torch.nn.Module]], + layer_group: int = 0, +): children = list(module.children()) if not children: # is leaf diff --git a/pytorch_accelerated/schedulers/cosine_scheduler.py b/pytorch_accelerated/schedulers/cosine_scheduler.py index f90c05d..f30c44d 100644 --- 
a/pytorch_accelerated/schedulers/cosine_scheduler.py +++ b/pytorch_accelerated/schedulers/cosine_scheduler.py @@ -135,9 +135,11 @@ def get_updated_values(self, num_updates: int): else: # cooldown lrs = [ - self.lr_min_ratio * base_lr - if self.lr_min_ratio is not None - else self.lr_min + ( + self.lr_min_ratio * base_lr + if self.lr_min_ratio is not None + else self.lr_min + ) for base_lr in self.base_lr_values ] @@ -155,7 +157,7 @@ def create_scheduler_fn( warmup_starting_lr=1e-6, warmup_starting_lr_ratio=None, num_cooldown_epochs=0, - ) -> Callable: + ) -> Callable[[torch.optim.Optimizer], "CosineLrScheduler"]: """ An alternative constructor which returns a function that accepts an optimizer and creates an instance of ``CosineLrScheduler``. This is primarily intended to be used with the :class:`~pytorch_accelerated.trainer.Trainer` diff --git a/pytorch_accelerated/tracking.py b/pytorch_accelerated/tracking.py index 4b80f02..bce6d7a 100644 --- a/pytorch_accelerated/tracking.py +++ b/pytorch_accelerated/tracking.py @@ -1,7 +1,7 @@ # Copyright © 2021 Chris Hughes from abc import ABC, abstractmethod from collections import defaultdict -from typing import Iterable +from typing import Any, Iterable class RunHistory(ABC): @@ -19,7 +19,7 @@ def get_metric_names(self) -> Iterable: pass @abstractmethod - def get_metric_values(self, metric_name) -> Iterable: + def get_metric_values(self, metric_name: str) -> Iterable: """ Return all of the values that have been recorded for the given metric. @@ -29,7 +29,7 @@ def get_metric_values(self, metric_name) -> Iterable: pass @abstractmethod - def get_latest_metric(self, metric_name): + def get_latest_metric(self, metric_name: str) -> Any: """ Return the most recent value that has been recorded for the given metric. @@ -39,7 +39,7 @@ def get_latest_metric(self, metric_name): pass @abstractmethod - def set_metric_name_prefix(self, prefix=""): + def set_metric_name_prefix(self, prefix: str = ""): """ Set a prefix which will be prepended to any metric name which is tracked. @@ -49,14 +49,14 @@ def set_metric_name_prefix(self, prefix=""): @property @abstractmethod - def metric_name_prefix(self): + def metric_name_prefix(self) -> str: """ - :return: the prefix which wil be prepended to any metric name + :return: the prefix which will be prepended to any metric name """ pass @abstractmethod - def update_metric(self, metric_name, metric_value): + def update_metric(self, metric_name: str, metric_value: Any): """ Record the value for the given metric. 
@@ -103,10 +103,10 @@ def __init__(self): def get_metric_names(self): return set(self._metrics.keys()) - def get_metric_values(self, metric_name): + def get_metric_values(self, metric_name: str) -> list: return self._metrics[metric_name] - def get_latest_metric(self, metric_name): + def get_latest_metric(self, metric_name: str) -> Any: if len(self._metrics[metric_name]) > 0: return self._metrics[metric_name][-1] else: @@ -114,18 +114,18 @@ def get_latest_metric(self, metric_name): f"No values have been recorded for the metric {metric_name}" ) - def update_metric(self, metric_name, metric_value): + def update_metric(self, metric_name: str, metric_value): self._metrics[f"{self._prefix}{metric_name}"].append(metric_value) - def set_metric_name_prefix(self, prefix=""): + def set_metric_name_prefix(self, prefix: str = ""): self._prefix = prefix @property - def metric_name_prefix(self): + def metric_name_prefix(self) -> str: return self._prefix @property - def current_epoch(self): + def current_epoch(self) -> int: return self._current_epoch def _increment_epoch(self): @@ -149,7 +149,7 @@ def reset(self): self.total_loss = 0 self.running_count = 0 - def update(self, loss_batch_value, batch_size=1): + def update(self, loss_batch_value: float, batch_size: int = 1): self.loss_value = loss_batch_value self.total_loss += loss_batch_value * batch_size self.running_count += batch_size diff --git a/pytorch_accelerated/trainer.py b/pytorch_accelerated/trainer.py index 30a8b62..13b472f 100644 --- a/pytorch_accelerated/trainer.py +++ b/pytorch_accelerated/trainer.py @@ -3,13 +3,14 @@ import os from enum import Enum from functools import partial -from typing import Iterable +from typing import Optional, Union, TypeVar, Any, List, Callable, TYPE_CHECKING +from pathlib import Path import warnings import torch from accelerate import Accelerator, DistributedType from accelerate.utils import set_seed -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from pytorch_accelerated.callbacks import ( CallbackHandler, @@ -20,6 +21,7 @@ ProgressBarCallback, StopTrainingError, TerminateOnNaNCallback, + TrainerCallback, ) from pytorch_accelerated.run_config import TrainerRunConfig from pytorch_accelerated.tracking import InMemoryRunHistory, LossTracker, RunHistory @@ -29,6 +31,9 @@ worker_init_fn, ) +if TYPE_CHECKING: + from pytorch_accelerated.schedulers import StatefulSchedulerBase + DEFAULT_CALLBACKS = ( MoveModulesToDeviceCallback, TerminateOnNaNCallback, @@ -69,11 +74,15 @@ def placeholder_set(cls): return {placeholder.name for placeholder in cls} @staticmethod - def __create_new_enum(original_enum, other, operation): + def __create_new_enum( + original_enum: "TrainerPlaceholderValues", + other, + operation: str, + ) -> Enum: enum_members = {k: v.value for k, v in original_enum._member_map_.items()} - enum_members[ - original_enum.name - ] = f"{enum_members[original_enum.name]}{operation}{other}" + enum_members[original_enum.name] = ( + f"{enum_members[original_enum.name]}{operation}{other}" + ) new_enum = Enum("TrainerPlaceholderValues", enum_members) return new_enum._member_map_[original_enum.name] @@ -85,11 +94,15 @@ def __add__(self, other): def __sub__(self, other): raise NotImplemented( - "Subtraction is not supported, please re-write the expression in terms of addition" + "Subtraction is not supported, please re-write the expression in terms of" + " addition" ) -def replace_trainer_placeholder_values(trainer, instance): +T = TypeVar("T") + + +def 
replace_trainer_placeholder_values(trainer: "Trainer", instance: T) -> T: """If the instance is partial and contains keywords, will replace these, returning a new function.""" if isinstance(instance, partial): @@ -126,10 +139,10 @@ class Trainer: def __init__( self, - model, - loss_func, - optimizer, - callbacks=DEFAULT_CALLBACKS, + model: torch.nn.Module, + loss_func: Callable[[Any, Any], torch.Tensor], + optimizer: torch.optim.Optimizer, + callbacks: List[TrainerCallback] = DEFAULT_CALLBACKS, run_history=None, ): """ @@ -193,7 +206,7 @@ def _create_callback_handler(self): callbacks, ) - def _create_accelerator(self): + def _create_accelerator(self) -> Accelerator: """ Create an instance of :class:`accelerate.Accelerator` which will be used to manage training. """ @@ -201,8 +214,8 @@ def _create_accelerator(self): return Accelerator() def create_train_dataloader( - self, batch_size: int, train_dl_kwargs: dict = None - ) -> Iterable: + self, batch_size: int, train_dl_kwargs: Optional[dict] = None + ) -> DataLoader: """ Create a dataloader to be used during training. This is initialised with the train_dataset and collate function which have been passed to the Trainer. @@ -233,8 +246,8 @@ def create_train_dataloader( ) def create_eval_dataloader( - self, batch_size: int, eval_dl_kwargs: dict = None - ) -> Iterable: + self, batch_size: int, eval_dl_kwargs: Optional[dict] = None + ) -> DataLoader: """ Create a dataloader to be used during evaluation. This is initialised with the eval_dataset and collate function which have been passed to the Trainer. @@ -262,11 +275,12 @@ def create_eval_dataloader( **self._eval_dl_kwargs, ) - def create_scheduler(self): + def create_scheduler(self) -> "StatefulSchedulerBase": """ Create a learning rate scheduler based on the ``create_scheduler_fn`` function which has been passed to the Trainer. :return: a learning rate scheduler instance """ + assert self.create_scheduler_fn is not None scheduler_type = replace_trainer_placeholder_values( self, self.create_scheduler_fn ) @@ -306,7 +320,7 @@ def calculate_train_batch_loss(self, batch) -> dict: "batch_size": yb.size(0), } - def backward_step(self, loss): + def backward_step(self, loss: torch.Tensor): """ Use the accelerator to perform the backward pass on the calculated value of the loss returned by :meth:`~Trainer.calculate_train_batch_loss`. If gradient accumulation is enabled, this loss has been scaled by 1 / accumulation steps. @@ -400,18 +414,20 @@ def evaluation_run_end(self): def train( self, - train_dataset, - num_epochs, - eval_dataset=None, - per_device_batch_size=8, - max_num_train_steps=None, - gradient_accumulation_steps=1, - gradient_clip_value=None, - create_scheduler_fn=None, - train_dataloader_kwargs: dict = None, - eval_dataloader_kwargs: dict = None, - reset_run_history=True, - collate_fn=None, + train_dataset: Dataset, + num_epochs: int, + eval_dataset: Optional[Dataset] = None, + per_device_batch_size: int = 8, + max_num_train_steps: Optional[int] = None, + gradient_accumulation_steps: int = 1, + gradient_clip_value: Optional[float] = None, + create_scheduler_fn: Optional[ + Callable[[torch.optim.Optimizer], "StatefulSchedulerBase"] + ] = None, + train_dataloader_kwargs: Optional[dict] = None, + eval_dataloader_kwargs: Optional[dict] = None, + reset_run_history: bool = True, + collate_fn: Optional[callable] = None, ): """ Start a training run. If an evaluation dataset is provided, this routine will include both training and evaluation epochs. 
@@ -470,10 +486,10 @@ def train( def evaluate( self, - dataset=None, - per_device_batch_size=8, + dataset: Optional[Dataset] = None, + per_device_batch_size: int = 8, dataloader_kwargs: dict = None, - collate_fn=None, + collate_fn: Optional[callable] = None, ): """ Start an evaluation run. @@ -511,7 +527,7 @@ def evaluate( self._run_evaluation() - def get_default_train_dl_kwargs(self, batch_size) -> dict: + def get_default_train_dl_kwargs(self, batch_size: int) -> dict: """ Return the default arguments that will be used by the training dataloader. @@ -523,15 +539,17 @@ def get_default_train_dl_kwargs(self, batch_size) -> dict: "pin_memory": True if torch.cuda.is_available() else False, "batch_size": batch_size, "num_workers": max( - os.cpu_count() // torch.cuda.device_count() - if torch.cuda.is_available() - else os.cpu_count(), + ( + os.cpu_count() // torch.cuda.device_count() + if torch.cuda.is_available() + else os.cpu_count() + ), 1, ), "worker_init_fn": worker_init_fn, } - def get_default_eval_dl_kwargs(self, batch_size) -> dict: + def get_default_eval_dl_kwargs(self, batch_size: int) -> dict: """ Return the default arguments that will be used by the evaluation dataloader. @@ -543,9 +561,11 @@ def get_default_eval_dl_kwargs(self, batch_size) -> dict: "pin_memory": True if torch.cuda.is_available() else False, "batch_size": batch_size, "num_workers": max( - os.cpu_count() // torch.cuda.device_count() - if torch.cuda.is_available() - else os.cpu_count(), + ( + os.cpu_count() // torch.cuda.device_count() + if torch.cuda.is_available() + else os.cpu_count() + ), 1, ), "worker_init_fn": worker_init_fn, @@ -582,33 +602,35 @@ def _prepare_model_optimizer_and_dataloaders(self): if self._eval_dataloader is not None: components.append(self._eval_dataloader) - prepared_components = self._accelerator.prepare(*components) - - self.model = prepared_components[0] - self.optimizer = prepared_components[1] + model: torch.nn.Module + optimizer: torch.optim.Optimizer + data_loaders: List[DataLoader] + model, optimizer, *data_loaders = self._accelerator.prepare(*components) + self.model = model + self.optimizer = optimizer if self._train_dataloader is not None: - self._train_dataloader = prepared_components[2] + self._train_dataloader = data_loaders[0] self._train_dataloader.batch_sampler.even_batches = True if self._eval_dataloader is not None: - self._eval_dataloader = prepared_components[3] + self._eval_dataloader = data_loaders[1] self._eval_dataloader.batch_sampler.even_batches = ( self._pad_uneven_eval_batches ) elif self._eval_dataloader is not None: - self._eval_dataloader = prepared_components[2] + self._eval_dataloader = data_loaders[0] self._eval_dataloader.batch_sampler.even_batches = ( self._pad_uneven_eval_batches ) def _create_run_config( self, - per_device_batch_size, - num_epochs, - gradient_accumulation_steps, - max_num_train_steps, - gradient_clip_value, + per_device_batch_size: int, + num_epochs: int, + gradient_accumulation_steps: int, + max_num_train_steps: Optional[int], + gradient_clip_value: Optional[float], ) -> TrainerRunConfig: """ Create an instance of :class:`~pytorch_accelerated.run_config.TrainerRunConfig` representing the current state of the trainer.
@@ -653,18 +675,23 @@ def _create_run_config( "eval_per_device_batch_size": eval_per_device_batch_size, "eval_dl_kwargs": self._eval_dl_kwargs, "gradient_accumulation_steps": gradient_accumulation_steps, - "train_total_batch_size": train_per_device_batch_size - * self._accelerator.num_processes - * gradient_accumulation_steps, - "eval_total_batch_size": eval_per_device_batch_size - * self._accelerator.num_processes, + "train_total_batch_size": ( + train_per_device_batch_size + * self._accelerator.num_processes + * gradient_accumulation_steps + ), + "eval_total_batch_size": ( + eval_per_device_batch_size * self._accelerator.num_processes + ), "num_update_steps_per_epoch": num_update_steps_per_epoch, "max_num_train_steps": max_num_train_steps, "is_local_process_zero": self._accelerator.is_local_main_process, "is_world_process_zero": self._accelerator.is_main_process, - "is_distributed": True - if self._accelerator.distributed_type != DistributedType.NO - else False, + "is_distributed": ( + True + if self._accelerator.distributed_type != DistributedType.NO + else False + ), "mixed_precision": self._accelerator.mixed_precision, "gradient_clip_value": gradient_clip_value, "num_processes": self._accelerator.num_processes, @@ -678,9 +705,10 @@ def _check_eval_batch_size(self): if self.run_config.eval_total_batch_size > len(self.eval_dataset): raise ValueError( - f"The total batch size {self.run_config.eval_total_batch_size} \ - across all processes is bigger than eval dataset size {len(self.eval_dataset)}. \ - This can be resolved by lowering the batch size" + f"The total batch size {self.run_config.eval_total_batch_size} " + " across all processes is bigger than eval dataset size" + f" {len(self.eval_dataset)}. This can be resolved by" + " lowering the batch size" ) n_samples_last_batch = ( @@ -694,11 +722,13 @@ def _check_eval_batch_size(self): ) if 0 < n_samples_last_batch < min_samples_last_batch: warnings.warn( - f"The per device batch size {self.run_config.eval_per_device_batch_size} with the " - f"eval dataset size {len(self.eval_dataset)} and the number of processes " - f"{self.run_config.num_processes} will cause at least one process to have no " - "samples on the last batch, which would lead to a `Trainer.gather` to freeze " - "indefinitely. This can be resolved by setting a different batch size" + "The per device batch size" + f" {self.run_config.eval_per_device_batch_size} with the eval dataset" + f" size {len(self.eval_dataset)} and the number of processes" + f" {self.run_config.num_processes} will cause at least one process to" + " have no samples on the last batch, which would lead to a" + " `Trainer.gather` to freeze indefinitely. This can be resolved by" + " setting a different batch size" ) elif ( min_samples_last_batch @@ -706,12 +736,13 @@ def _check_eval_batch_size(self): < self.run_config.eval_total_batch_size ): warnings.warn( - f"The per device batch size {self.run_config.eval_per_device_batch_size} with the " - f"eval dataset size {len(self.eval_dataset)} and the number of processes " - f"{self.run_config.num_processes} will cause one process to have a smaller number " - "of samples on the last batch than the rest, which would lead to a " - "`Trainer.gather` to freeze indefinitely. This can be resolved by passing a " - "`padding_value` to the `Trainer.gather`." 
+ "The per device batch size" + f" {self.run_config.eval_per_device_batch_size} with the eval dataset" + f" size {len(self.eval_dataset)} and the number of processes" + f" {self.run_config.num_processes} will cause one process to have a" + " smaller number of samples on the last batch than the rest, which" + " would lead to a `Trainer.gather` to freeze indefinitely. This can be" + " resolved by passing a `padding_value` to the `Trainer.gather`." ) def _run_training(self): @@ -770,7 +801,7 @@ def _run_evaluation(self): self, ) - def _run_train_epoch(self, train_dl): + def _run_train_epoch(self, train_dl: DataLoader): """ The method responsible for the behaviour of each training epoch. @@ -836,7 +867,7 @@ def _perform_forward_and_backward_passes(self, batch): ) self.backward_step(batch_output["loss"]) - def _update_loss_tracker(self, batch_loss, batch_size): + def _update_loss_tracker(self, batch_loss: torch.Tensor, batch_size: int): """ Update the loss calculated for each batch using the internal loss tracker. During each epoch, losses are tracked in individual processes. @@ -846,7 +877,7 @@ def _update_loss_tracker(self, batch_loss, batch_size): batch_size, ) - def _add_epoch_loss_to_run_history(self, metric_name): + def _add_epoch_loss_to_run_history(self, metric_name: str): """ Update the run history with the average of all batch losses calculated during the epoch across all processes. """ @@ -873,7 +904,7 @@ def _clip_gradients(self): self.model.parameters(), clip_value=self.run_config.gradient_clip_value ) - def _run_eval_epoch(self, valid_dl, is_training: bool = True): + def _run_eval_epoch(self, valid_dl: DataLoader, is_training: bool = True): """ The method responsible for the behaviour of each evaluation epoch. @@ -917,7 +948,7 @@ def _run_eval_epoch(self, valid_dl, is_training: bool = True): self, ) - def gather(self, tensor, padding_value=None): + def gather(self, tensor, padding_value: Optional[float] = None) -> torch.Tensor: """ Gather the values in `tensor` across all processes and concatenate them on the first dimension. This can be useful to regroup the predictions from all processes when doing evaluation. @@ -956,7 +987,11 @@ def print(self, *args, **kwargs): print(*args, **kwargs) def save_checkpoint( - self, save_path, checkpoint_kwargs=None, save_optimizer=True, save_per_node=True + self, + save_path: Union[str, Path], + checkpoint_kwargs=None, + save_optimizer=True, + save_per_node=True, ): """ Save the model, optimizer and specified args as a checkpoint file. @@ -992,7 +1027,7 @@ def save_checkpoint( save_path, ) - def load_checkpoint(self, checkpoint_path, load_optimizer=True): + def load_checkpoint(self, checkpoint_path: Union[str, Path], load_optimizer=True): """ Load the model and optimizer from a checkpoint file. @@ -1005,9 +1040,10 @@ def load_checkpoint(self, checkpoint_path, load_optimizer=True): if load_optimizer and "optimizer_state_dict" in checkpoint: if self.optimizer is None: raise ValueError( - "You are trying to load an optimizer from a checkpoint, but no optimizer" - "has been set in the Trainer. Either pass the correct optimizer instance when" - "creating the trainer, or specify load_optimizer=False when loading the checkpoint." + "You are trying to load an optimizer from a checkpoint, but no" + " optimizerhas been set in the Trainer. Either pass the correct" + " optimizer instance whencreating the trainer, or specify" + " load_optimizer=False when loading the checkpoint." ) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])