Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nle/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries


Expand All @@ -36,7 +36,7 @@ def __init__(
density_estimator: Union[str, Callable] = "mnle",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize MNLE.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nle/nle_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.distributions import Distribution

from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries


Expand All @@ -24,7 +24,7 @@ def __init__(
density_estimator: Union[str, Callable] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize Neural Likelihood Estimation.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/npe/mnpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from sbi.inference.trainers.npe.npe_c import NPE_C
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries


Expand All @@ -34,7 +34,7 @@ def __init__(
density_estimator: Union[str, Callable] = "mnpe",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize Mixed Neural Posterior Estimation (MNPE).
Expand Down
12 changes: 7 additions & 5 deletions sbi/inference/trainers/npe/npe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from torch.distributions import Distribution, MultivariateNormal

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils import torchutils
from sbi.utils.sbiutils import (
batched_mixture_mv,
Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(
num_components: int = 10,
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NPE-A [1].
Expand Down Expand Up @@ -231,7 +233,7 @@ def train(

def correct_for_proposal(
self,
density_estimator: Optional[TorchModule] = None,
density_estimator: Optional[torch.nn.Module] = None,
) -> "NPE_A_MDN":
r"""Build mixture of Gaussians that approximates the posterior.

Expand Down Expand Up @@ -285,7 +287,7 @@ def correct_for_proposal(

def build_posterior(
self,
density_estimator: Optional[TorchModule] = None,
density_estimator: Optional[torch.nn.Module] = None,
prior: Optional[Distribution] = None,
**kwargs,
) -> "DirectPosterior":
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/npe/npe_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sbi.utils as utils
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries


Expand All @@ -36,7 +36,7 @@ def __init__(
density_estimator: Union[str, Callable] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NPE-B.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/npe/npe_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils import (
batched_mixture_mv,
batched_mixture_vmv,
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
density_estimator: Union[str, Callable] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NPE-C.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nre/bnre.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from sbi.inference.trainers.nre.nre_a import NRE_A
from sbi.inference.trainers.nre.nre_base import RatioEstimatorBuilder
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite

Expand All @@ -32,7 +32,7 @@ def __init__(
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Balanced neural ratio estimation (BNRE).
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nre/nre_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RatioEstimatorBuilder,
RatioEstimatorTrainer,
)
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite

Expand All @@ -29,7 +29,7 @@ def __init__(
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NRE_A.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nre/nre_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RatioEstimatorBuilder,
RatioEstimatorTrainer,
)
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite

Expand All @@ -29,7 +29,7 @@ def __init__(
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NRE_B.
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/trainers/nre/nre_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RatioEstimatorBuilder,
RatioEstimatorTrainer,
)
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite

Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[TensorboardSummaryWriter] = None,
summary_writer: Optional[TensorBoardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Initialize NRE-C.
Expand Down
17 changes: 6 additions & 11 deletions sbi/sbi_types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import NewType, Optional, Sequence, Tuple, TypeVar, Union
from typing import Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch
from pyro.distributions import TransformedDistribution # type: ignore
from torch import Tensor
from torch.distributions import Distribution
from torch.distributions.transforms import Transform
from torch.nn import Module
from torch.utils.tensorboard.writer import SummaryWriter
from typing_extensions import TypeAlias

Expand All @@ -28,24 +27,20 @@
]
]

# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
# TorchTransform = NewType("torch Transform", Transform)
TorchModule = NewType("Module", Module)
TorchDistribution = NewType("torch Distribution", Distribution)
# Define alias types for better readability in type hints and checking.
# See PEP 613 for the reason why we need to use TypeAlias here.
TensorBoardSummaryWriter: TypeAlias = SummaryWriter
TorchDistribution: TypeAlias = Distribution
TorchTransform: TypeAlias = Transform
PyroTransformedDistribution: TypeAlias = TransformedDistribution
TorchTensor = NewType("Tensor", Tensor)
TorchTensor: TypeAlias = Tensor

__all__ = [
"Array",
"Shape",
"OneOrMore",
"ScalarFloat",
"TensorboardSummaryWriter",
"TorchModule",
"TensorBoardSummaryWriter",
"TorchTransform",
"transform_types",
"TorchDistribution",
Expand Down