|
1 | 1 | # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
|
2 | 2 | # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
|
3 | 3 |
|
| 4 | +import time |
4 | 5 | import warnings
|
5 | 6 | from abc import ABC, abstractmethod
|
6 | 7 | from copy import deepcopy
|
|
11 | 12 | Any,
|
12 | 13 | Callable,
|
13 | 14 | Dict,
|
| 15 | + Generic, |
14 | 16 | List,
|
15 | 17 | Literal,
|
16 | 18 | Optional,
|
| 19 | + Sequence, |
17 | 20 | Tuple,
|
18 | 21 | Union,
|
19 | 22 | )
|
|
22 | 25 | import torch
|
23 | 26 | from torch import Tensor
|
24 | 27 | from torch.distributions import Distribution
|
| 28 | +from torch.nn.utils.clip_grad import clip_grad_norm_ |
| 29 | +from torch.optim.adam import Adam |
25 | 30 | from torch.utils import data
|
26 | 31 | from torch.utils.data.sampler import SubsetRandomSampler
|
27 | 32 | from torch.utils.tensorboard.writer import SummaryWriter
|
|
47 | 52 | from sbi.neural_nets.estimators.base import (
|
48 | 53 | ConditionalDensityEstimator,
|
49 | 54 | ConditionalEstimator,
|
| 55 | + ConditionalEstimatorType, |
50 | 56 | ConditionalVectorFieldEstimator,
|
51 | 57 | )
|
52 | 58 | from sbi.sbi_types import TorchTransform
|
@@ -158,7 +164,7 @@ def infer(
|
158 | 164 | return posterior
|
159 | 165 |
|
160 | 166 |
|
161 | | -class NeuralInference(ABC): |
| 167 | +class NeuralInference(ABC, Generic[ConditionalEstimatorType]): |
162 | 168 | """Abstract base class for neural inference methods."""
|
163 | 169 |
|
164 | 170 | def __init__(
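
The `Generic[ConditionalEstimatorType]` change above lets each trainer's `train()` advertise the concrete estimator it returns. A minimal sketch of the pattern, assuming `ConditionalEstimatorType` is a `TypeVar` bound to `ConditionalEstimator` (the real definition lives in `sbi.neural_nets.estimators.base` and is only imported here; `MyDensityTrainer` is hypothetical):

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from sbi.neural_nets.estimators.base import (
    ConditionalDensityEstimator,
    ConditionalEstimator,
)

# Assumed definition; in sbi it is imported from
# sbi.neural_nets.estimators.base rather than declared in this file.
ConditionalEstimatorType = TypeVar(
    "ConditionalEstimatorType", bound=ConditionalEstimator
)


class NeuralInference(ABC, Generic[ConditionalEstimatorType]):
    @abstractmethod
    def train(self) -> ConditionalEstimatorType: ...


# A subclass pins the type parameter, so a static checker infers that
# MyDensityTrainer().train() returns a ConditionalDensityEstimator.
class MyDensityTrainer(NeuralInference[ConditionalDensityEstimator]):
    def train(self) -> ConditionalDensityEstimator:
        raise NotImplementedError  # hypothetical, shown for typing only
```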
|
@@ -307,7 +313,20 @@ def train(
|
307 | 313 | discard_prior_samples: bool = False,
|
308 | 314 | retrain_from_scratch: bool = False,
|
309 | 315 | show_train_summary: bool = False,
|
310 | | - ) -> NeuralPosterior: ... |
| 316 | + ) -> ConditionalEstimatorType: ... |
| 317 | + |
| 318 | + @abstractmethod |
| 319 | + def _initialize_neural_network( |
| 320 | + self, retrain_from_scratch: bool, start_idx: int |
| 321 | + ) -> None: ... |
| 322 | + |
| 323 | + @abstractmethod |
| 324 | + def _get_start_index(self, discard_prior_samples: bool) -> int: ... |
| 325 | + |
| 326 | + @abstractmethod |
| 327 | + def _get_losses( |
| 328 | + self, batch: Sequence[Tensor], loss_kwargs: Dict[str, Any] |
| 329 | + ) -> Tensor: ... |
311 | 330 |
|
312 | 331 | @abstractmethod
|
313 | 332 | def _get_potential_function(
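
Together, the new abstract hooks above turn training into a template method: the base class owns the epoch loop (added below), while subclasses only decide which simulations to use, how to build the network, and how to compute per-sample losses. A minimal, hypothetical subclass sketch; `self._build_neural_net` and the `self._loss` signature are assumptions that differ between methods:

```python
from typing import Any, Dict, Sequence

from torch import Tensor


class MyTrainer(NeuralInference[ConditionalDensityEstimator]):
    """Hypothetical trainer illustrating the three new hooks."""

    def _get_start_index(self, discard_prior_samples: bool) -> int:
        # Optionally skip round-0 (prior) simulations in multi-round training.
        return int(discard_prior_samples and self._round > 0)

    def _initialize_neural_network(
        self, retrain_from_scratch: bool, start_idx: int
    ) -> None:
        # Build (or rebuild) the estimator from the stored simulations.
        if self._neural_net is None or retrain_from_scratch:
            theta, x, _ = self.get_simulations(starting_round=start_idx)
            self._neural_net = self._build_neural_net(theta, x)

    def _get_losses(
        self, batch: Sequence[Tensor], loss_kwargs: Dict[str, Any]
    ) -> Tensor:
        # Return one loss per sample; the shared loop averages, clips, steps.
        theta, x = batch[0].to(self._device), batch[1].to(self._device)
        return self._loss(theta, x, **loss_kwargs)
```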
|
@@ -872,6 +891,199 @@ def _create_posterior(
|
872 | 891 | )
|
873 | 892 | return posterior
|
874 | 893 |
|
| 894 | + def _run_training_loop( |
| 895 | + self, |
| 896 | + train_loader: data.DataLoader, |
| 897 | + val_loader: data.DataLoader, |
| 898 | + max_num_epochs: int, |
| 899 | + stop_after_epochs: int, |
| 900 | + learning_rate: float, |
| 901 | + resume_training: bool, |
| 902 | + clip_max_norm: Optional[float], |
| 903 | + show_train_summary: bool, |
| 904 | + loss_kwargs: Optional[Dict[str, Any]] = None, |
| 905 | + summarization_kwargs: Optional[Dict[str, Any]] = None, |
| 906 | + ) -> ConditionalEstimatorType: |
| 907 | + """ |
| 908 | + Run the main training loop for the neural network, including epoch-wise |
| 909 | + training, validation, and convergence checking. |
| 910 | + |
| 911 | + Args: |
| 912 | + train_loader: Dataloader for training. |
| 913 | + val_loader: Dataloader for validation. |
| 914 | + max_num_epochs: Maximum number of epochs to run. If reached, we stop |
| 915 | + training even when the validation loss is still decreasing. Otherwise, |
| 916 | + we train until validation loss increases (see also `stop_after_epochs`). |
| 917 | + stop_after_epochs: The number of epochs to wait for improvement on the |
| 918 | + validation set before terminating training. |
| 919 | + learning_rate: Learning rate for the Adam optimizer. |
| 920 | + resume_training: Can be used in case training time is limited, e.g. on a |
| 921 | + cluster. If `True`, the split between train and validation set, the |
| 922 | + optimizer, the number of epochs, and the best validation loss will |
| 923 | + be restored from the last time `.train()` was called. |
| 924 | + clip_max_norm: Value at which to clip the total gradient norm in order to |
| 925 | + prevent exploding gradients. Use `None` for no clipping. |
| 926 | + show_train_summary: Whether to print the number of epochs and validation |
| 927 | + loss after training. |
| 928 | + loss_kwargs: Additional or updated kwargs to be passed to the `self._loss` fn. |
| 929 | + summarization_kwargs: Additional kwargs passed to the `self._summarize_epoch` fn. |
| 930 | + """ |
| 931 | + |
| 932 | + if loss_kwargs is None: |
| 933 | + loss_kwargs = {} |
| 934 | + |
| 935 | + if summarization_kwargs is None: |
| 936 | + summarization_kwargs = {} |
| 937 | + |
| 938 | + assert self._neural_net is not None |
| 939 | + |
| 940 | + # Move entire net to device for training. |
| 941 | + self._neural_net.to(self._device) |
| 942 | + |
| 943 | + if not resume_training: |
| 944 | + self.optimizer = Adam( |
| 945 | + list(self._neural_net.parameters()), |
| 946 | + lr=learning_rate, |
| 947 | + ) |
| 948 | + self.epoch, self._val_loss = 0, float("Inf") |
| 949 | + |
| 950 | + while self.epoch <= max_num_epochs and not self._converged( |
| 951 | + self.epoch, stop_after_epochs |
| 952 | + ): |
| 953 | + # Train for a single epoch. |
| 954 | + self._neural_net.train() |
| 955 | + epoch_start_time = time.time() |
| 956 | + train_loss = self._train_epoch(train_loader, clip_max_norm, loss_kwargs) |
| 957 | + |
| 958 | + # Calculate validation performance. |
| 959 | + self._neural_net.eval() |
| 960 | + |
| 961 | + self._val_loss = self._validate_epoch(val_loader, loss_kwargs) |
| 962 | + |
| 963 | + self._summarize_epoch( |
| 964 | + train_loss, self._val_loss, epoch_start_time, summarization_kwargs |
| 965 | + ) |
| 966 | + |
| 967 | + self.epoch += 1 |
| 968 | + self._maybe_show_progress(self._show_progress_bars, self.epoch) |
| 969 | + |
| 970 | + self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) |
| 971 | + |
| 972 | + # Update summary. |
| 973 | + self._summary["epochs_trained"].append(self.epoch) |
| 974 | + self._summary["best_validation_loss"].append(self._best_val_loss) |
| 975 | + |
| 976 | + # Update TensorBoard and summary dict. |
| 977 | + self._summarize(round_=self._round) |
| 978 | + |
| 979 | + # Update description for progress bar. |
| 980 | + if show_train_summary: |
| 981 | + print(self._describe_round(self._round, self._summary)) |
| 982 | + |
| 983 | + # Avoid keeping the gradients in the resulting network, which can |
| 984 | + # cause memory leakage when benchmarking. |
| 985 | + self._neural_net.zero_grad(set_to_none=True) |
| 986 | + |
| 987 | + return deepcopy(self._neural_net) |
| 988 | + |
| 989 | + def _train_epoch( |
| 990 | + self, |
| 991 | + train_loader: data.DataLoader, |
| 992 | + clip_max_norm: Optional[float], |
| 993 | + loss_kwargs: Dict[str, Any], |
| 994 | + ) -> float: |
| 995 | + """ |
| 996 | + Perform a single training epoch over the provided training data. |
| 997 | + |
| 998 | + Args: |
| 999 | + train_loader: Dataloader for training. |
| 1000 | + clip_max_norm: Value at which to clip the total gradient norm in order to |
| 1001 | + prevent exploding gradients. Use None for no clipping. |
| 1002 | + loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn. |
| 1003 | + |
| 1004 | + Returns: |
| 1005 | + The average training loss over all samples in the epoch. |
| 1006 | + """ |
| 1007 | + |
| 1008 | + assert self._neural_net is not None |
| 1009 | + |
| 1010 | + train_loss_sum = 0 |
| 1011 | + for batch in train_loader: |
| 1012 | + self.optimizer.zero_grad() |
| 1013 | + train_losses = self._get_losses(batch=batch, loss_kwargs=loss_kwargs) |
| 1014 | + train_loss = torch.mean(train_losses) |
| 1015 | + train_loss_sum += train_losses.sum().item() |
| 1016 | + |
| 1017 | + train_loss.backward() |
| 1018 | + if clip_max_norm is not None: |
| 1019 | + clip_grad_norm_( |
| 1020 | + self._neural_net.parameters(), |
| 1021 | + max_norm=clip_max_norm, |
| 1022 | + ) |
| 1023 | + self.optimizer.step() |
| 1024 | + |
| 1025 | + train_loss_average = train_loss_sum / ( |
| 1026 | + len(train_loader) * train_loader.batch_size # type: ignore |
| 1027 | + ) |
| 1028 | + |
| 1029 | + return train_loss_average |
| 1030 | + |
| 1031 | + def _validate_epoch( |
| 1032 | + self, |
| 1033 | + val_loader: data.DataLoader, |
| 1034 | + loss_kwargs: Dict[str, Any], |
| 1035 | + ) -> float: |
| 1036 | + """ |
| 1037 | + Perform a single validation epoch over the provided validation data. |
| 1038 | + |
| 1039 | + Args: |
| 1040 | + val_loader: Dataloader for validation. |
| 1041 | + loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn. |
| 1042 | + |
| 1043 | + Returns: |
| 1044 | + The average validation loss over all samples in the epoch. |
| 1045 | + """ |
| 1046 | + |
| 1047 | + val_loss_sum = 0 |
| 1048 | + with torch.no_grad(): |
| 1049 | + for batch in val_loader: |
| 1050 | + val_losses = self._get_losses(batch=batch, loss_kwargs=loss_kwargs) |
| 1051 | + val_loss_sum += val_losses.sum().item() |
| 1052 | + |
| 1053 | + # Take mean over all validation samples. |
| 1054 | + val_loss = val_loss_sum / ( |
| 1055 | + len(val_loader) * val_loader.batch_size # type: ignore |
| 1056 | + ) |
| 1057 | + |
| 1058 | + return val_loss |
| 1059 | + |
| 1060 | + def _summarize_epoch( |
| 1061 | + self, |
| 1062 | + train_loss: float, |
| 1063 | + val_loss: float, |
| 1064 | + epoch_start_time: float, |
| 1065 | + summarization_kwargs: Dict[str, Any], |
| 1066 | + ) -> None: |
| 1067 | + """ |
| 1068 | + Update internal summaries after a single training epoch. |
| 1069 | + |
| 1070 | + Records training and validation losses, as well as the duration of the epoch, |
| 1071 | + in the `self._summary` dictionary. |
| 1072 | + |
| 1073 | + Args: |
| 1074 | + train_loss: The average training loss for the epoch. |
| 1075 | + val_loss: The average validation loss for the epoch. |
| 1076 | + epoch_start_time: Timestamp when the epoch started, used to compute |
| 1077 | + duration. |
| 1078 | + summarization_kwargs: Additional keyword arguments for customizing |
| 1079 | + the summarization. |
| 1080 | + """ |
| 1081 | + |
| 1082 | + self._summary["training_loss"].append(train_loss) |
| 1083 | + # Log validation loss for every epoch. |
| 1084 | + self._summary["validation_loss"].append(val_loss) |
| 1085 | + self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time) |
| 1086 | + |
875 | 1087 | def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
876 | 1088 | """Return whether the training converged yet and save best model state so far.
|
877 | 1089 |
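
With these pieces in place, a concrete `train()` implementation reduces to wiring: pick the start index, build the loaders, initialize the network, and delegate to `_run_training_loop`. A sketch of the assumed wiring; `get_dataloaders` follows the existing sbi API, but defaults and the exact keyword set may differ per trainer:

```python
from typing import Optional


def train(
    self,
    training_batch_size: int = 200,
    learning_rate: float = 5e-4,
    validation_fraction: float = 0.1,
    stop_after_epochs: int = 20,
    max_num_epochs: int = 2**31 - 1,
    clip_max_norm: Optional[float] = 5.0,
    resume_training: bool = False,
    discard_prior_samples: bool = False,
    retrain_from_scratch: bool = False,
    show_train_summary: bool = False,
) -> ConditionalEstimatorType:
    # Which simulations to train on (e.g. skip prior samples after round 0).
    start_idx = self._get_start_index(discard_prior_samples)
    # The train/validation split is preserved when resume_training=True.
    train_loader, val_loader = self.get_dataloaders(
        start_idx, training_batch_size, validation_fraction, resume_training
    )
    self._initialize_neural_network(retrain_from_scratch, start_idx)
    # All epoch-wise work (optimizer, clipping, early stopping, summaries)
    # now lives in the shared helper.
    return self._run_training_loop(
        train_loader=train_loader,
        val_loader=val_loader,
        max_num_epochs=max_num_epochs,
        stop_after_epochs=stop_after_epochs,
        learning_rate=learning_rate,
        resume_training=resume_training,
        clip_max_norm=clip_max_norm,
        show_train_summary=show_train_summary,
    )
```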
|
|