Commit 18ebafe

chore: update return types for train methods
1 parent cf0c06d

4 files changed: 9 additions, 15 deletions

sbi/inference/trainers/base.py (6 additions & 3 deletions)

@@ -21,7 +21,7 @@
 from warnings import warn

 import torch
-from torch import Tensor, nn
+from torch import Tensor
 from torch.distributions import Distribution
 from torch.nn.utils.clip_grad import clip_grad_norm_
 from torch.optim.adam import Adam
@@ -311,7 +311,7 @@ def train(
         discard_prior_samples: bool = False,
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
-    ) -> NeuralPosterior: ...
+    ) -> Union[ConditionalEstimator, RatioEstimator]: ...

     @abstractmethod
     def _initialize_neural_network(
@@ -906,7 +906,7 @@ def _run_training_loop(
         show_train_summary: bool,
         loss_kwargs: Optional[Dict[str, Any]] = None,
         summarization_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> nn.Module:
+    ) -> Union[ConditionalEstimator, RatioEstimator]:
         """
         Run the main training loop for the neural network, including epoch-wise
         training, validation, and convergence checking.
@@ -940,6 +940,9 @@ def _run_training_loop(

         assert self._neural_net is not None

+        # Move entire net to device for training.
+        self._neural_net.to(self._device)
+
         if not resume_training:
             self.optimizer = Adam(
                 list(self._neural_net.parameters()),
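
In short, the abstract train() signature and _run_training_loop() in base.py now advertise Union[ConditionalEstimator, RatioEstimator] instead of NeuralPosterior / nn.Module, and the device move happens once in the shared training loop. A minimal caller-side sketch of what the tighter return type means for downstream code; the NPE trainer, BoxUniform prior, and the append_simulations/build_posterior calls are assumed from sbi's public API and are not part of this diff:

    # Hypothetical usage sketch; NPE, BoxUniform, append_simulations, and
    # build_posterior are assumed from sbi's public API, not shown in this diff.
    import torch
    from sbi.inference import NPE
    from sbi.utils import BoxUniform

    prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
    theta = prior.sample((1000,))
    x = theta + 0.1 * torch.randn_like(theta)  # toy "simulator"

    trainer = NPE(prior=prior, device="cpu")
    trainer.append_simulations(theta, x)

    # After this commit, train() is annotated as returning a
    # ConditionalEstimator or RatioEstimator rather than a bare nn.Module,
    # so type checkers see the richer estimator interface.
    density_estimator = trainer.train(max_num_epochs=5)
    posterior = trainer.build_posterior(density_estimator)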

sbi/inference/trainers/npe/npe_base.py (0 additions & 4 deletions)

@@ -319,7 +319,6 @@ def default_calibration_kernel(x):
             calibration_kernel=calibration_kernel,
             force_first_round_loss=force_first_round_loss,
         )
-
         return self._run_training_loop( # type: ignore
             train_loader=train_loader,
             val_loader=val_loader,
@@ -645,9 +644,6 @@ def _initialize_neural_network(

         del theta, x

-        # Move entire net to device for training.
-        self._neural_net.to(self._device)
-
     def _get_training_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
         """
         Compute training losses for a batch of data.

sbi/inference/trainers/nre/nre_base.py (3 additions & 5 deletions)

@@ -6,7 +6,7 @@
 from typing import Any, Dict, Literal, Optional, Protocol, Tuple, Union

 import torch
-from torch import Tensor, eye, nn, ones
+from torch import Tensor, eye, ones
 from torch.distributions import Distribution
 from torch.utils.tensorboard.writer import SummaryWriter
 from typing_extensions import Self
@@ -175,7 +175,7 @@ def train(
         show_train_summary: bool = False,
         dataloader_kwargs: Optional[Dict] = None,
         loss_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> nn.Module:
+    ) -> RatioEstimator:
         r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

         Args:
@@ -237,7 +237,7 @@ def train(

         loss_kwargs["num_atoms"] = num_atoms

-        return self._run_training_loop(
+        return self._run_training_loop( # type: ignore
             train_loader=train_loader,
             val_loader=val_loader,
             max_num_epochs=max_num_epochs,
@@ -428,8 +428,6 @@ def _initialize_neural_network(

         del x, theta

-        self._neural_net.to(self._device)
-
     def _get_training_losses(self, batch: Any, loss_kwargs: Dict[str, Any]) -> Tensor:
         """
         Compute training losses for a batch of data.
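
The added "# type: ignore" on the _run_training_loop call is presumably needed because the base method now returns the union while NRE's train() narrows the result to RatioEstimator. An explicit typing.cast would be the comment-free alternative; the sketch below is illustrative only, and TrainerSketch and its stub classes are hypothetical placeholders, not sbi internals:

    # Illustrative only: narrowing a Union return type without "# type: ignore".
    from typing import Union, cast


    class ConditionalEstimator: ...
    class RatioEstimator: ...


    class TrainerSketch:
        def _run_training_loop(self) -> Union[ConditionalEstimator, RatioEstimator]:
            return RatioEstimator()

        def train(self) -> RatioEstimator:
            # Same effect as the "# type: ignore" in the diff: tell the type
            # checker that this trainer's loop always yields a RatioEstimator.
            return cast(RatioEstimator, self._run_training_loop())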

sbi/inference/trainers/vfpe/base_vf_inference.py (0 additions & 3 deletions)

@@ -667,6 +667,3 @@ def _initialize_neural_network(
         )

         del theta, x
-
-        # Move entire net to device for training.
-        self._neural_net.to(self._device)
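
With this last hunk, none of the NPE, NRE, or VF trainers moves the network in _initialize_neural_network anymore; the single self._neural_net.to(self._device) call now lives in _run_training_loop in base.py. A generic PyTorch sketch of that pattern, with hypothetical names rather than sbi's actual classes:

    # Generic sketch of centralizing the device move; TrainerSketch is hypothetical.
    from typing import Optional

    import torch
    from torch import nn


    class TrainerSketch:
        def __init__(self, device: str = "cpu") -> None:
            self._device = device
            self._neural_net: Optional[nn.Module] = None

        def _initialize_neural_network(self) -> None:
            # Subclasses only build the network; no .to(device) here anymore.
            self._neural_net = nn.Linear(2, 2)

        def _run_training_loop(self) -> nn.Module:
            assert self._neural_net is not None
            # Single shared place where the whole net is moved for training,
            # mirroring the lines added in sbi/inference/trainers/base.py.
            self._neural_net.to(self._device)
            # ... epoch loop, validation, convergence checks ...
            return self._neural_net


    trainer = TrainerSketch(device="cuda" if torch.cuda.is_available() else "cpu")
    trainer._initialize_neural_network()
    net = trainer._run_training_loop()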
