Skip to content

Commit 67d7e1c

Browse files
authored
refactor: use TypeAlias and consistent naming for sbi types (#1637)
* fix type aliases, refactor renaming
1 parent 085f9c0 commit 67d7e1c

File tree

11 files changed

+31
-34
lines changed

11 files changed

+31
-34
lines changed

sbi/inference/trainers/nle/mnle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
1616
from sbi.neural_nets.estimators import MixedDensityEstimator
17-
from sbi.sbi_types import TensorboardSummaryWriter
17+
from sbi.sbi_types import TensorBoardSummaryWriter
1818
from sbi.utils.sbiutils import del_entries
1919

2020

@@ -36,7 +36,7 @@ def __init__(
3636
density_estimator: Union[str, Callable] = "mnle",
3737
device: str = "cpu",
3838
logging_level: Union[int, str] = "WARNING",
39-
summary_writer: Optional[TensorboardSummaryWriter] = None,
39+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
4040
show_progress_bars: bool = True,
4141
):
4242
r"""Initialize MNLE.

sbi/inference/trainers/nle/nle_a.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.distributions import Distribution
77

88
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
9-
from sbi.sbi_types import TensorboardSummaryWriter
9+
from sbi.sbi_types import TensorBoardSummaryWriter
1010
from sbi.utils.sbiutils import del_entries
1111

1212

@@ -24,7 +24,7 @@ def __init__(
2424
density_estimator: Union[str, Callable] = "maf",
2525
device: str = "cpu",
2626
logging_level: Union[int, str] = "WARNING",
27-
summary_writer: Optional[TensorboardSummaryWriter] = None,
27+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
2828
show_progress_bars: bool = True,
2929
):
3030
r"""Initialize Neural Likelihood Estimation.

sbi/inference/trainers/npe/mnpe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from sbi.inference.trainers.npe.npe_c import NPE_C
1717
from sbi.neural_nets.estimators import MixedDensityEstimator
18-
from sbi.sbi_types import TensorboardSummaryWriter
18+
from sbi.sbi_types import TensorBoardSummaryWriter
1919
from sbi.utils.sbiutils import del_entries
2020

2121

@@ -34,7 +34,7 @@ def __init__(
3434
density_estimator: Union[str, Callable] = "mnpe",
3535
device: str = "cpu",
3636
logging_level: Union[int, str] = "WARNING",
37-
summary_writer: Optional[TensorboardSummaryWriter] = None,
37+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
3838
show_progress_bars: bool = True,
3939
):
4040
r"""Initialize Mixed Neural Posterior Estimation (MNPE).

sbi/inference/trainers/npe/npe_a.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from torch.distributions import Distribution, MultivariateNormal
1414

1515
from sbi.inference.posteriors.direct_posterior import DirectPosterior
16-
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
16+
from sbi.inference.trainers.npe.npe_base import (
17+
PosteriorEstimatorTrainer,
18+
)
1719
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
18-
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
20+
from sbi.sbi_types import TensorBoardSummaryWriter
1921
from sbi.utils import torchutils
2022
from sbi.utils.sbiutils import (
2123
batched_mixture_mv,
@@ -51,7 +53,7 @@ def __init__(
5153
num_components: int = 10,
5254
device: str = "cpu",
5355
logging_level: Union[int, str] = "WARNING",
54-
summary_writer: Optional[TensorboardSummaryWriter] = None,
56+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
5557
show_progress_bars: bool = True,
5658
):
5759
r"""Initialize NPE-A [1].
@@ -231,7 +233,7 @@ def train(
231233

232234
def correct_for_proposal(
233235
self,
234-
density_estimator: Optional[TorchModule] = None,
236+
density_estimator: Optional[torch.nn.Module] = None,
235237
) -> "NPE_A_MDN":
236238
r"""Build mixture of Gaussians that approximates the posterior.
237239
@@ -285,7 +287,7 @@ def correct_for_proposal(
285287

286288
def build_posterior(
287289
self,
288-
density_estimator: Optional[TorchModule] = None,
290+
density_estimator: Optional[torch.nn.Module] = None,
289291
prior: Optional[Distribution] = None,
290292
**kwargs,
291293
) -> "DirectPosterior":

sbi/inference/trainers/npe/npe_b.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sbi.utils as utils
1111
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
1212
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
13-
from sbi.sbi_types import TensorboardSummaryWriter
13+
from sbi.sbi_types import TensorBoardSummaryWriter
1414
from sbi.utils.sbiutils import del_entries
1515

1616

@@ -36,7 +36,7 @@ def __init__(
3636
density_estimator: Union[str, Callable] = "maf",
3737
device: str = "cpu",
3838
logging_level: Union[int, str] = "WARNING",
39-
summary_writer: Optional[TensorboardSummaryWriter] = None,
39+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
4040
show_progress_bars: bool = True,
4141
):
4242
r"""Initialize NPE-B.

sbi/inference/trainers/npe/npe_c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
reshape_to_batch_event,
1717
reshape_to_sample_batch_event,
1818
)
19-
from sbi.sbi_types import TensorboardSummaryWriter
19+
from sbi.sbi_types import TensorBoardSummaryWriter
2020
from sbi.utils import (
2121
batched_mixture_mv,
2222
batched_mixture_vmv,
@@ -70,7 +70,7 @@ def __init__(
7070
density_estimator: Union[str, Callable] = "maf",
7171
device: str = "cpu",
7272
logging_level: Union[int, str] = "WARNING",
73-
summary_writer: Optional[TensorboardSummaryWriter] = None,
73+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
7474
show_progress_bars: bool = True,
7575
):
7676
r"""Initialize NPE-C.

sbi/inference/trainers/nre/bnre.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from sbi.inference.trainers.nre.nre_a import NRE_A
1111
from sbi.inference.trainers.nre.nre_base import RatioEstimatorBuilder
12-
from sbi.sbi_types import TensorboardSummaryWriter
12+
from sbi.sbi_types import TensorBoardSummaryWriter
1313
from sbi.utils.sbiutils import del_entries
1414
from sbi.utils.torchutils import assert_all_finite
1515

@@ -32,7 +32,7 @@ def __init__(
3232
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
3333
device: str = "cpu",
3434
logging_level: Union[int, str] = "warning",
35-
summary_writer: Optional[TensorboardSummaryWriter] = None,
35+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
3636
show_progress_bars: bool = True,
3737
):
3838
r"""Balanced neural ratio estimation (BNRE).

sbi/inference/trainers/nre/nre_a.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
RatioEstimatorBuilder,
1212
RatioEstimatorTrainer,
1313
)
14-
from sbi.sbi_types import TensorboardSummaryWriter
14+
from sbi.sbi_types import TensorBoardSummaryWriter
1515
from sbi.utils.sbiutils import del_entries
1616
from sbi.utils.torchutils import assert_all_finite
1717

@@ -29,7 +29,7 @@ def __init__(
2929
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
3030
device: str = "cpu",
3131
logging_level: Union[int, str] = "warning",
32-
summary_writer: Optional[TensorboardSummaryWriter] = None,
32+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
3333
show_progress_bars: bool = True,
3434
):
3535
r"""Initialize NRE_A.

sbi/inference/trainers/nre/nre_b.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
RatioEstimatorBuilder,
1212
RatioEstimatorTrainer,
1313
)
14-
from sbi.sbi_types import TensorboardSummaryWriter
14+
from sbi.sbi_types import TensorBoardSummaryWriter
1515
from sbi.utils.sbiutils import del_entries
1616
from sbi.utils.torchutils import assert_all_finite
1717

@@ -29,7 +29,7 @@ def __init__(
2929
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
3030
device: str = "cpu",
3131
logging_level: Union[int, str] = "warning",
32-
summary_writer: Optional[TensorboardSummaryWriter] = None,
32+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
3333
show_progress_bars: bool = True,
3434
):
3535
r"""Initialize NRE_B.

sbi/inference/trainers/nre/nre_c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
RatioEstimatorBuilder,
1212
RatioEstimatorTrainer,
1313
)
14-
from sbi.sbi_types import TensorboardSummaryWriter
14+
from sbi.sbi_types import TensorBoardSummaryWriter
1515
from sbi.utils.sbiutils import del_entries
1616
from sbi.utils.torchutils import assert_all_finite
1717

@@ -43,7 +43,7 @@ def __init__(
4343
classifier: Union[str, RatioEstimatorBuilder] = "resnet",
4444
device: str = "cpu",
4545
logging_level: Union[int, str] = "warning",
46-
summary_writer: Optional[TensorboardSummaryWriter] = None,
46+
summary_writer: Optional[TensorBoardSummaryWriter] = None,
4747
show_progress_bars: bool = True,
4848
):
4949
r"""Initialize NRE-C.

0 commit comments

Comments
 (0)