|
13 | 13 | from torch.distributions import Distribution, MultivariateNormal
|
14 | 14 |
|
15 | 15 | from sbi.inference.posteriors.direct_posterior import DirectPosterior
|
16 |
| -from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer |
| 16 | +from sbi.inference.trainers.npe.npe_base import ( |
| 17 | + PosteriorEstimatorTrainer, |
| 18 | +) |
17 | 19 | from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
|
18 |
| -from sbi.sbi_types import TensorboardSummaryWriter, TorchModule |
| 20 | +from sbi.sbi_types import TensorBoardSummaryWriter |
19 | 21 | from sbi.utils import torchutils
|
20 | 22 | from sbi.utils.sbiutils import (
|
21 | 23 | batched_mixture_mv,
|
@@ -51,7 +53,7 @@ def __init__(
|
51 | 53 | num_components: int = 10,
|
52 | 54 | device: str = "cpu",
|
53 | 55 | logging_level: Union[int, str] = "WARNING",
|
54 |
| - summary_writer: Optional[TensorboardSummaryWriter] = None, |
| 56 | + summary_writer: Optional[TensorBoardSummaryWriter] = None, |
55 | 57 | show_progress_bars: bool = True,
|
56 | 58 | ):
|
57 | 59 | r"""Initialize NPE-A [1].
|
@@ -231,7 +233,7 @@ def train(
|
231 | 233 |
|
232 | 234 | def correct_for_proposal(
|
233 | 235 | self,
|
234 |
| - density_estimator: Optional[TorchModule] = None, |
| 236 | + density_estimator: Optional[torch.nn.Module] = None, |
235 | 237 | ) -> "NPE_A_MDN":
|
236 | 238 | r"""Build mixture of Gaussians that approximates the posterior.
|
237 | 239 |
|
@@ -285,7 +287,7 @@ def correct_for_proposal(
|
285 | 287 |
|
286 | 288 | def build_posterior(
|
287 | 289 | self,
|
288 |
| - density_estimator: Optional[TorchModule] = None, |
| 290 | + density_estimator: Optional[torch.nn.Module] = None, |
289 | 291 | prior: Optional[Distribution] = None,
|
290 | 292 | **kwargs,
|
291 | 293 | ) -> "DirectPosterior":
|
|
0 commit comments