
Commit 27b0824

refactor: modularize inference training method (#1651)
* refactor: add abstract methods and training methods to base class
* refactor(nle): modularize trainer code
* refactor(npe): modularize trainer function
* refactor(nre): modularize trainer method
* refactor(vector_field): modularize trainer function
* refactor: add method for getting starting index
* refactor: rename _update_summary function
* chore: update import
* refactor: update train method modularization focusing on the training loop
* refactor: modularize epoch validation and training methods
* doc: update training method docstring
* chore: add Any return type
* chore(vector_field): update dictionary argument resolution
* chore: skip validation times from loss_kwargs dictionary
* chore: add docstrings and update method arrangements
* chore: update return types for train methods
* test(npe): expect AssertionError instead of AttributeError
* chore: move self._val_loss to _run_training_loop method
* refactor: combine loss methods under _get_losses method
* docs: add docstrings for missing methods
* refactor: add generic type to NeuralInference base class
* chore: update nre train method return type
* chore: update _get_losses method parameter type annotation
1 parent 28f3deb commit 27b0824

File tree

10 files changed (+791 −569 lines)


sbi/inference/trainers/base.py

Lines changed: 214 additions & 2 deletions
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
+import time
 import warnings
 from abc import ABC, abstractmethod
 from copy import deepcopy
@@ -11,9 +12,11 @@
     Any,
     Callable,
     Dict,
+    Generic,
     List,
     Literal,
     Optional,
+    Sequence,
     Tuple,
     Union,
 )
@@ -22,6 +25,8 @@
 import torch
 from torch import Tensor
 from torch.distributions import Distribution
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
 from torch.utils.tensorboard.writer import SummaryWriter
@@ -47,6 +52,7 @@
 from sbi.neural_nets.estimators.base import (
     ConditionalDensityEstimator,
     ConditionalEstimator,
+    ConditionalEstimatorType,
     ConditionalVectorFieldEstimator,
 )
 from sbi.sbi_types import TorchTransform
@@ -158,7 +164,7 @@ def infer(
     return posterior
 
 
-class NeuralInference(ABC):
+class NeuralInference(ABC, Generic[ConditionalEstimatorType]):
     """Abstract base class for neural inference methods."""
 
     def __init__(
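
The base class is now generic in the estimator type it trains, via the `ConditionalEstimatorType` type variable imported above. As a rough, illustrative sketch (the subclass name is invented and not part of this commit; the import paths are the ones used in this file), a trainer for density estimators would bind the parameter like this, so that a type checker infers the concrete return type of `train()`:

# Illustrative only: "MyDensityTrainer" is a made-up name.
from sbi.inference.trainers.base import NeuralInference
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class MyDensityTrainer(NeuralInference[ConditionalDensityEstimator]):
    # Subclasses implement the abstract hooks declared below; with this binding,
    # train() is typed to return a ConditionalDensityEstimator.
    ...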
@@ -307,7 +313,20 @@ def train(
         discard_prior_samples: bool = False,
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
-    ) -> NeuralPosterior: ...
+    ) -> ConditionalEstimatorType: ...
+
+    @abstractmethod
+    def _initialize_neural_network(
+        self, retrain_from_scratch: bool, start_idx: int
+    ) -> None: ...
+
+    @abstractmethod
+    def _get_start_index(self, discard_prior_samples: bool) -> int: ...
+
+    @abstractmethod
+    def _get_losses(
+        self, batch: Sequence[Tensor], loss_kwargs: Dict[str, Any]
+    ) -> Tensor: ...
 
     @abstractmethod
     def _get_potential_function(
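
For orientation, a minimal sketch of how a concrete trainer might implement the three new hooks. Only the method names and signatures come from this diff; the bodies, the `MyTrainer` name, and the `_build_neural_net_from_round` helper are assumptions for illustration.

from typing import Any, Dict, Sequence

from torch import Tensor

from sbi.inference.trainers.base import NeuralInference
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class MyTrainer(NeuralInference[ConditionalDensityEstimator]):
    def _get_start_index(self, discard_prior_samples: bool) -> int:
        # Skip the prior (round-0) simulations in later rounds, if requested.
        return int(discard_prior_samples and self._round > 0)

    def _initialize_neural_network(
        self, retrain_from_scratch: bool, start_idx: int
    ) -> None:
        if self._neural_net is None or retrain_from_scratch:
            # Hypothetical helper that builds the estimator from the data of
            # rounds >= start_idx; the real builders live in the concrete trainers.
            self._neural_net = self._build_neural_net_from_round(start_idx)

    def _get_losses(
        self, batch: Sequence[Tensor], loss_kwargs: Dict[str, Any]
    ) -> Tensor:
        # A batch is a sequence of tensors, e.g. (theta, x); return one loss
        # per sample so the training loop can average and sum them.
        theta, x = (t.to(self._device) for t in batch[:2])
        return self._loss(theta, x, **loss_kwargs)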
@@ -872,6 +891,199 @@ def _create_posterior(
         )
         return posterior
 
+    def _run_training_loop(
+        self,
+        train_loader: data.DataLoader,
+        val_loader: data.DataLoader,
+        max_num_epochs: int,
+        stop_after_epochs: int,
+        learning_rate: float,
+        resume_training: bool,
+        clip_max_norm: Optional[float],
+        show_train_summary: bool,
+        loss_kwargs: Optional[Dict[str, Any]] = None,
+        summarization_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> ConditionalEstimatorType:
+        """
+        Run the main training loop for the neural network, including epoch-wise
+        training, validation, and convergence checking.
+
+        Args:
+            train_loader: Dataloader for training.
+            val_loader: Dataloader for validation.
+            learning_rate: Learning rate for Adam optimizer.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
+            resume_training: Can be used in case training time is limited, e.g. on a
+                cluster. If `True`, the split between train and validation set, the
+                optimizer, the number of epochs, and the best validation log-prob will
+                be restored from the last time `.train()` was called.
+            show_train_summary: Whether to print the number of epochs and validation
+                loss after the training.
+            loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
+            summarization_kwargs: Additional kwargs passed to self._summarize_epoch fn.
+        """
+
+        if loss_kwargs is None:
+            loss_kwargs = {}
+
+        if summarization_kwargs is None:
+            summarization_kwargs = {}
+
+        assert self._neural_net is not None
+
+        # Move entire net to device for training.
+        self._neural_net.to(self._device)
+
+        if not resume_training:
+            self.optimizer = Adam(
+                list(self._neural_net.parameters()),
+                lr=learning_rate,
+            )
+            self.epoch, self.val_loss = 0, float("Inf")
+
+        while self.epoch <= max_num_epochs and not self._converged(
+            self.epoch, stop_after_epochs
+        ):
+            # Train for a single epoch.
+            self._neural_net.train()
+            epoch_start_time = time.time()
+            train_loss = self._train_epoch(train_loader, clip_max_norm, loss_kwargs)
+
+            # Calculate validation performance.
+            self._neural_net.eval()
+
+            self._val_loss = self._validate_epoch(val_loader, loss_kwargs)
+
+            self._summarize_epoch(
+                train_loss, self._val_loss, epoch_start_time, summarization_kwargs
+            )
+
+            self.epoch += 1
+            self._maybe_show_progress(self._show_progress_bars, self.epoch)
+
+        self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs)
+
+        # Update summary.
+        self._summary["epochs_trained"].append(self.epoch)
+        self._summary["best_validation_loss"].append(self._best_val_loss)
+
+        # Update TensorBoard and summary dict.
+        self._summarize(round_=self._round)
+
+        # Update description for progress bar.
+        if show_train_summary:
+            print(self._describe_round(self._round, self._summary))
+
+        # Avoid keeping the gradients in the resulting network, which can
+        # cause memory leakage when benchmarking.
+        self._neural_net.zero_grad(set_to_none=True)
+
+        return deepcopy(self._neural_net)
+
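To see how the pieces fit together: a concrete `train()` built on top of `_run_training_loop` might look roughly as follows. This is a sketch inside a hypothetical subclass, not code from this commit; the default values and the `get_dataloaders` helper (assumed to split the stored simulations into train/validation loaders) are stand-ins for whatever the concrete NPE/NLE/NRE trainers actually use.

def train(
    self,
    training_batch_size: int = 200,
    learning_rate: float = 5e-4,
    validation_fraction: float = 0.1,
    stop_after_epochs: int = 20,
    max_num_epochs: int = 2**31 - 1,
    clip_max_norm: Optional[float] = 5.0,
    resume_training: bool = False,
    discard_prior_samples: bool = False,
    retrain_from_scratch: bool = False,
    show_train_summary: bool = False,
) -> ConditionalEstimatorType:
    # Decide which rounds of simulations to train on.
    start_idx = self._get_start_index(discard_prior_samples)
    # Assumed base-class helper; not part of this diff.
    train_loader, val_loader = self.get_dataloaders(
        start_idx, training_batch_size, validation_fraction, resume_training
    )
    # Build the network lazily (or rebuild it if requested).
    self._initialize_neural_network(retrain_from_scratch, start_idx)
    # Delegate the epoch loop, validation, and convergence checks to the base class.
    return self._run_training_loop(
        train_loader=train_loader,
        val_loader=val_loader,
        max_num_epochs=max_num_epochs,
        stop_after_epochs=stop_after_epochs,
        learning_rate=learning_rate,
        resume_training=resume_training,
        clip_max_norm=clip_max_norm,
        show_train_summary=show_train_summary,
    )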
+    def _train_epoch(
+        self,
+        train_loader: data.DataLoader,
+        clip_max_norm: Optional[float],
+        loss_kwargs: Dict[str, Any],
+    ) -> float:
+        """
+        Perform a single training epoch over the provided training data.
+
+        Args:
+            train_loader: Dataloader for training.
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
+            loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
+
+        Returns:
+            The average training loss over all samples in the epoch.
+        """
+
+        assert self._neural_net is not None
+
+        train_loss_sum = 0
+        for batch in train_loader:
+            self.optimizer.zero_grad()
+            train_losses = self._get_losses(batch=batch, loss_kwargs=loss_kwargs)
+            train_loss = torch.mean(train_losses)
+            train_loss_sum += train_losses.sum().item()
+
+            train_loss.backward()
+            if clip_max_norm is not None:
+                clip_grad_norm_(
+                    self._neural_net.parameters(),
+                    max_norm=clip_max_norm,
+                )
+            self.optimizer.step()
+
+        train_loss_average = train_loss_sum / (
+            len(train_loader) * train_loader.batch_size  # type: ignore
+        )
+
+        return train_loss_average
+
+    def _validate_epoch(
+        self,
+        val_loader: data.DataLoader,
+        loss_kwargs: Dict[str, Any],
+    ) -> float:
+        """
+        Perform a single validation epoch over the provided validation data.
+
+        Args:
+            val_loader: Dataloader for validation.
+            loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
+
+        Returns:
+            The average validation loss over all samples in the epoch.
+        """
+
+        val_loss_sum = 0
+        with torch.no_grad():
+            for batch in val_loader:
+                val_losses = self._get_losses(batch=batch, loss_kwargs=loss_kwargs)
+                val_loss_sum += val_losses.sum().item()
+
+        # Take mean over all validation samples.
+        val_loss = val_loss_sum / (
+            len(val_loader) * val_loader.batch_size  # type: ignore
+        )
+
+        return val_loss
+
+    def _summarize_epoch(
+        self,
+        train_loss: float,
+        val_loss: float,
+        epoch_start_time: float,
+        summarization_kwargs: Dict[str, Any],
+    ) -> None:
+        """
+        Update internal summaries after a single training epoch.
+
+        Records training and validation losses, as well as the duration of the epoch,
+        in the `self._summary` dictionary.
+
+        Args:
+            train_loss: The average training loss for the epoch.
+            val_loss: The average validation loss for the epoch.
+            epoch_start_time: Timestamp when the epoch started, used to compute
+                duration.
+            summarization_kwargs: Additional keyword arguments for customizing
+                the summarization.
+        """
+
+        self._summary["training_loss"].append(train_loss)
+        # Log validation loss for every epoch.
+        self._summary["validation_loss"].append(val_loss)
+        self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
+
     def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
         """Return whether the training converged yet and save best model state so far.

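After training, the quantities that `_run_training_loop` and `_summarize_epoch` record can be read back from the trainer's summary. The keys below are the ones written in this diff; `trainer` is a placeholder for an instance of any concrete subclass.

# Hypothetical usage: "trainer" stands for a concrete NeuralInference subclass.
estimator = trainer.train()

summary = trainer._summary
print(summary["epochs_trained"][-1])        # epochs run in the last call to train()
print(summary["best_validation_loss"][-1])  # best validation loss of that run
print(summary["training_loss"][-3:])        # last three per-epoch training losses
print(summary["validation_loss"][-3:])      # matching per-epoch validation losses
print(summary["epoch_durations_sec"][-1])   # wall-clock seconds of the last epoch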