
Commit eb39b09

refactor: modularize epoch validation and training methods
1 parent 277be50 commit eb39b09

File tree

5 files changed: +172 additions, -324 deletions


sbi/inference/trainers/base.py

Lines changed: 73 additions & 25 deletions
@@ -23,6 +23,7 @@
 import torch
 from torch import Tensor
 from torch.distributions import Distribution
+from torch.nn.utils.clip_grad import clip_grad_norm_
 from torch.optim.adam import Adam
 from torch.utils import data
 from torch.utils.data.sampler import SubsetRandomSampler
@@ -321,29 +322,14 @@ def _initialize_neural_network(
     def _get_start_index(self, discard_prior_samples: bool) -> int: ...

     @abstractmethod
-    def _train_epoch(
-        self,
-        train_loader: data.DataLoader,
-        clip_max_norm: Optional[float],
-        loss_kwargs: dict,
-    ) -> float: ...
+    def _get_training_losses(
+        self, batch: Any, loss_kwargs: Dict[str, Any]
+    ) -> Tensor: ...

     @abstractmethod
-    def _validate_epoch(
-        self,
-        val_loader: data.DataLoader,
-        loss_kwargs: dict,
-        validation_kwargs: dict,
-    ) -> float: ...
-
-    @abstractmethod
-    def _summarize_epoch(
-        self,
-        train_loss: float,
-        val_loss: float,
-        epoch_start_time: float,
-        summarization_kwargs: dict,
-    ) -> None: ...
+    def _get_validation_losses(
+        self, batch: Any, loss_kwargs: Dict[str, Any]
+    ) -> Tensor: ...

     @abstractmethod
     def _get_potential_function(
@@ -919,15 +905,13 @@ def _train(
         clip_max_norm: Optional[float],
         show_train_summary: bool,
         loss_kwargs: Optional[Dict[str, Any]] = None,
-        validation_kwargs: Optional[Dict[str, Any]] = None,
         summarization_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Main training pipeline using a config object."""

        if loss_kwargs is None:
            loss_kwargs = {}
-        if validation_kwargs is None:
-            validation_kwargs = {}
+
        if summarization_kwargs is None:
            summarization_kwargs = {}

@@ -950,7 +934,7 @@ def _train(

            # Calculate validation performance.
            self._neural_net.eval()
-            val_loss = self._validate_epoch(val_loader, loss_kwargs, validation_kwargs)
+            val_loss = self._validate_epoch(val_loader, loss_kwargs)

            self._summarize_epoch(
                train_loss, val_loss, epoch_start_time, summarization_kwargs
@@ -978,6 +962,70 @@ def _train(

        return deepcopy(self._neural_net)

+    def _train_epoch(
+        self,
+        train_loader: data.DataLoader,
+        clip_max_norm: Optional[float],
+        loss_kwargs: Dict[str, Any],
+    ) -> float:
+        assert self._neural_net is not None
+
+        train_loss_sum = 0
+        for batch in train_loader:
+            self.optimizer.zero_grad()
+            train_losses = self._get_training_losses(batch, loss_kwargs=loss_kwargs)
+            train_loss = torch.mean(train_losses)
+            train_loss_sum += train_losses.sum().item()
+
+            train_loss.backward()
+            if clip_max_norm is not None:
+                clip_grad_norm_(
+                    self._neural_net.parameters(),
+                    max_norm=clip_max_norm,
+                )
+            self.optimizer.step()
+
+        train_loss_average = train_loss_sum / (
+            len(train_loader) * train_loader.batch_size  # type: ignore
+        )
+
+        return train_loss_average
+
+    def _validate_epoch(
+        self,
+        val_loader: data.DataLoader,
+        loss_kwargs: Dict[str, Any],
+    ) -> float:
+        val_loss_sum = 0
+        with torch.no_grad():
+            for batch in val_loader:
+                val_losses = self._get_validation_losses(
+                    batch=batch,
+                    loss_kwargs=loss_kwargs,
+                )
+                val_loss_sum += val_losses.sum().item()
+
+        # Take mean over all validation samples.
+        val_loss = val_loss_sum / (
+            len(val_loader) * val_loader.batch_size  # type: ignore
+        )
+
+        return val_loss
+
+    def _summarize_epoch(
+        self,
+        train_loss: float,
+        val_loss: float,
+        epoch_start_time: float,
+        summarization_kwargs: Dict[str, Any],
+    ) -> None:
+        self._summary["training_loss"].append(train_loss)
+
+        self._val_loss = val_loss
+        # Log validation loss for every epoch.
+        self._summary["validation_loss"].append(self._val_loss)
+        self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
+
     def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
         """Return whether the training converged yet and save best model state so far.
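To make the new split concrete, here is a small self-contained sketch (toy code for illustration, not part of the commit and not sbi's API): the epoch loop, gradient clipping, and loss averaging live in one place, while a per-batch hook mirroring the new _get_training_losses only returns per-sample losses. Validation follows the same pattern, with the base class wrapping the hook in torch.no_grad().

# Toy model of the refactored template method (illustrative only).
import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader, TensorDataset

net = nn.Linear(3, 1)
optimizer = optim.Adam(net.parameters(), lr=1e-3)

def get_training_losses(batch) -> Tensor:
    # The subclass hook: return one loss per sample in the batch.
    theta, x = batch
    return (net(theta).squeeze(-1) - x) ** 2  # shape: (batch_size,)

def train_epoch(loader: DataLoader, clip_max_norm: float = 5.0) -> float:
    # The base-class loop: backprop per batch, then average over all samples.
    loss_sum = 0.0
    for batch in loader:
        optimizer.zero_grad()
        losses = get_training_losses(batch)
        losses.mean().backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=clip_max_norm)
        optimizer.step()
        loss_sum += losses.sum().item()
    return loss_sum / (len(loader) * loader.batch_size)

loader = DataLoader(TensorDataset(torch.randn(64, 3), torch.randn(64)), batch_size=8)
print(train_epoch(loader))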

sbi/inference/trainers/nle/nle_base.py

Lines changed: 14 additions & 68 deletions
@@ -1,16 +1,12 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-import time
 import warnings
 from abc import ABC
 from typing import Any, Dict, Literal, Optional, Tuple, Union

-import torch
 from torch import Tensor
 from torch.distributions import Distribution
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.utils import data
 from torch.utils.tensorboard.writer import SummaryWriter
 from typing_extensions import Self

@@ -240,75 +236,25 @@ def _initialize_neural_network(
        )
        del theta, x

-    def _train_epoch(
-        self,
-        train_loader: data.DataLoader,
-        clip_max_norm: Optional[float],
-        loss_kwargs: Dict[str, Any],
-    ) -> float:
-        train_loss_sum = 0
-        for batch in train_loader:
-            self.optimizer.zero_grad()
-            theta_batch, x_batch = (
-                batch[0].to(self._device),
-                batch[1].to(self._device),
-            )
-            # Evaluate on x with theta as context.
-            train_losses = self._loss(theta=theta_batch, x=x_batch)
-            train_loss = torch.mean(train_losses)
-            train_loss_sum += train_losses.sum().item()
-
-            train_loss.backward()
-            if clip_max_norm is not None:
-                clip_grad_norm_(
-                    self._neural_net.parameters(),
-                    max_norm=clip_max_norm,
-                )
-            self.optimizer.step()
-
-        train_loss_average = train_loss_sum / (
-            len(train_loader) * train_loader.batch_size  # type: ignore
+    def _get_training_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
+        theta_batch, x_batch = (
+            batch[0].to(self._device),
+            batch[1].to(self._device),
        )
+        # Evaluate on x with theta as context.
+        train_losses = self._loss(theta=theta_batch, x=x_batch)

-        return train_loss_average
+        return train_losses

-    def _validate_epoch(
-        self,
-        val_loader: data.DataLoader,
-        loss_kwargs: Dict[str, Any],
-        validation_kwargs: Dict[str, Any],
-    ) -> float:
-        val_loss_sum = 0
-        with torch.no_grad():
-            for batch in val_loader:
-                theta_batch, x_batch = (
-                    batch[0].to(self._device),
-                    batch[1].to(self._device),
-                )
-                # Evaluate on x with theta as context.
-                val_losses = self._loss(theta=theta_batch, x=x_batch)
-                val_loss_sum += val_losses.sum().item()
-
-        # Take mean over all validation samples.
-        val_loss = val_loss_sum / (
-            len(val_loader) * val_loader.batch_size  # type: ignore
+    def _get_validation_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
+        theta_batch, x_batch = (
+            batch[0].to(self._device),
+            batch[1].to(self._device),
        )
+        # Evaluate on x with theta as context.
+        val_losses = self._loss(theta=theta_batch, x=x_batch)

-        return val_loss
-
-    def _summarize_epoch(
-        self,
-        train_loss: float,
-        val_loss: float,
-        epoch_start_time: float,
-        summarization_kwargs: Dict[str, Any],
-    ) -> None:
-        self._summary["training_loss"].append(train_loss)
-
-        self._val_loss = val_loss
-        # Log validation loss for every epoch.
-        self._summary["validation_loss"].append(self._val_loss)
-        self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
+        return val_losses

     def build_posterior(
         self,
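For orientation (toy code, not from the diff): under the new contract an NLE batch is a (theta, x) pair and the hook returns one loss per sample; the base class sums these and divides by the number of samples. A minimal stand-in, with a dummy per-sample loss in place of self._loss(theta=..., x=...):

import torch
from torch import Tensor

def get_nle_style_losses(batch, device: str = "cpu") -> Tensor:
    # Mirrors the shape of the NLE hook above: move the (theta, x) pair to
    # the device and return per-sample losses (here a dummy squared error).
    theta, x = batch[0].to(device), batch[1].to(device)
    return (theta.sum(dim=1) - x.squeeze(-1)) ** 2  # shape: (batch_size,)

batch = (torch.randn(8, 2), torch.randn(8, 1))
print(get_nle_style_losses(batch).shape)  # torch.Size([8])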

sbi/inference/trainers/npe/npe_base.py

Lines changed: 16 additions & 70 deletions
@@ -1,16 +1,12 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
 from warnings import warn

-import torch
 from torch import Tensor, ones
 from torch.distributions import Distribution
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.utils import data
 from torch.utils.tensorboard.writer import SummaryWriter
 from typing_extensions import Self

@@ -403,78 +399,28 @@ def _initialize_neural_network(
        # Move entire net to device for training.
        self._neural_net.to(self._device)

-    def _train_epoch(
-        self,
-        train_loader: data.DataLoader,
-        clip_max_norm: Optional[float],
-        loss_kwargs: dict,
-    ) -> float:
-        train_loss_sum = 0
-        for batch in train_loader:
-            self.optimizer.zero_grad()
-            # Get batches on current device.
-            theta_batch, x_batch, masks_batch = (
-                batch[0].to(self._device),
-                batch[1].to(self._device),
-                batch[2].to(self._device),
-            )
-
-            train_losses = self._loss(theta_batch, x_batch, masks_batch, **loss_kwargs)
-            train_loss = torch.mean(train_losses)
-            train_loss_sum += train_losses.sum().item()
-
-            train_loss.backward()
-            if clip_max_norm is not None:
-                clip_grad_norm_(self._neural_net.parameters(), max_norm=clip_max_norm)
-            self.optimizer.step()
-
-        train_loss_average = train_loss_sum / (
-            len(train_loader) * train_loader.batch_size  # type: ignore
+    def _get_training_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
+        # Get batches on current device.
+        theta_batch, x_batch, masks_batch = (
+            batch[0].to(self._device),
+            batch[1].to(self._device),
+            batch[2].to(self._device),
        )

-        return train_loss_average
+        train_losses = self._loss(theta_batch, x_batch, masks_batch, **loss_kwargs)

-    def _validate_epoch(
-        self,
-        val_loader: data.DataLoader,
-        loss_kwargs: dict,
-        validation_kwargs: dict,
-    ) -> float:
-        val_loss_sum = 0
-
-        with torch.no_grad():
-            for batch in val_loader:
-                theta_batch, x_batch, masks_batch = (
-                    batch[0].to(self._device),
-                    batch[1].to(self._device),
-                    batch[2].to(self._device),
-                )
-                # Take negative loss here to get validation log_prob.
-                val_losses = self._loss(
-                    theta_batch, x_batch, masks_batch, **loss_kwargs
-                )
-                val_loss_sum += val_losses.sum().item()
+        return train_losses

-        # Take mean over all validation samples.
-        val_loss = val_loss_sum / (
-            len(val_loader) * val_loader.batch_size  # type: ignore
+    def _get_validation_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
+        theta_batch, x_batch, masks_batch = (
+            batch[0].to(self._device),
+            batch[1].to(self._device),
+            batch[2].to(self._device),
        )
+        # Take negative loss here to get validation log_prob.
+        val_losses = self._loss(theta_batch, x_batch, masks_batch, **loss_kwargs)

-        return val_loss
-
-    def _summarize_epoch(
-        self,
-        train_loss: float,
-        val_loss: float,
-        epoch_start_time: float,
-        summarization_kwargs: dict,
-    ) -> None:
-        self._summary["training_loss"].append(train_loss)
-
-        self._val_loss = val_loss
-        # Log validation loss for every epoch.
-        self._summary["validation_loss"].append(self._val_loss)
-        self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
+        return val_losses

     def build_posterior(
         self,
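The NPE variant differs only in that each batch carries a third element, the prior-sample mask, which is forwarded to the loss together with any loss_kwargs. A toy call sketch (illustrative tensors and a stand-in loss, not sbi code):

import torch
from torch import Tensor

def get_npe_style_losses(batch, loss_kwargs, device: str = "cpu") -> Tensor:
    # NPE batches are (theta, x, masks); here a dummy per-sample loss is
    # zeroed wherever the mask is off, standing in for self._loss(...).
    theta, x, masks = (t.to(device) for t in batch)
    per_sample = (theta - x).pow(2).sum(dim=1)
    return torch.where(masks.squeeze(-1).bool(), per_sample, torch.zeros_like(per_sample))

batch = (torch.randn(8, 2), torch.randn(8, 2), torch.ones(8, 1))
print(get_npe_style_losses(batch, loss_kwargs={}).shape)  # torch.Size([8])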

0 commit comments
