Skip to content

Commit 3b9be6a

Browse files
New file structure for neural networks (#1237)
* New file structure for neural networks * no more imports of the builders from sbi.neural_nets * Some fixups for paths * Rename build_functions to build_nets * rename to net_builders * raise ImportError for embedding nets
1 parent 04c7b04 commit 3b9be6a

29 files changed

+449
-390
lines changed

sbi/inference/fmpe/fmpe_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from sbi import utils as utils
1616
from sbi.inference.base import NeuralInference
1717
from sbi.inference.posteriors.direct_posterior import DirectPosterior
18-
from sbi.neural_nets import ConditionalDensityEstimator, flowmatching_nn
18+
from sbi.neural_nets import flowmatching_nn
19+
from sbi.neural_nets.estimators import ConditionalDensityEstimator
1920
from sbi.utils import (
2021
RestrictedPrior,
2122
handle_invalid_x,

sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from torch.distributions import Distribution
99

1010
from sbi.inference.potentials.base_potential import BasePotential
11-
from sbi.neural_nets.estimators import ConditionalDensityEstimator
11+
from sbi.neural_nets.estimators import (
12+
ConditionalDensityEstimator,
13+
MixedDensityEstimator,
14+
)
1215
from sbi.neural_nets.estimators.shape_handling import (
1316
reshape_to_batch_event,
1417
reshape_to_sample_batch_event,
1518
)
16-
from sbi.neural_nets.mnle import MixedDensityEstimator
1719
from sbi.sbi_types import TorchTransform
1820
from sbi.utils.sbiutils import mcmc_transform
1921

sbi/inference/snle/mnle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
1010
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
1111
from sbi.inference.snle.snle_base import LikelihoodEstimator
12-
from sbi.neural_nets.mnle import MixedDensityEstimator
12+
from sbi.neural_nets.estimators import MixedDensityEstimator
1313
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
1414
from sbi.utils.sbiutils import del_entries
1515
from sbi.utils.user_input_checks import check_prior

sbi/inference/snle/snle_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
1717
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
1818
from sbi.inference.potentials import likelihood_estimator_based_potential
19-
from sbi.neural_nets import ConditionalDensityEstimator, likelihood_nn
19+
from sbi.neural_nets import likelihood_nn
20+
from sbi.neural_nets.estimators import ConditionalDensityEstimator
2021
from sbi.neural_nets.estimators.shape_handling import (
2122
reshape_to_batch_event,
2223
)

sbi/inference/snpe/snpe_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from sbi.inference.posteriors.base_posterior import NeuralPosterior
2525
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
2626
from sbi.inference.potentials import posterior_estimator_based_potential
27-
from sbi.neural_nets import ConditionalDensityEstimator, posterior_nn
27+
from sbi.neural_nets import posterior_nn
28+
from sbi.neural_nets.estimators import ConditionalDensityEstimator
2829
from sbi.neural_nets.estimators.shape_handling import (
2930
reshape_to_batch_event,
3031
reshape_to_sample_batch_event,

sbi/neural_nets/__init__.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
1-
from sbi.neural_nets.classifier import (
2-
build_linear_classifier,
3-
build_mlp_classifier,
4-
build_resnet_classifier,
5-
)
6-
from sbi.neural_nets.embedding_nets import (
7-
CNNEmbedding,
8-
FCEmbedding,
9-
PermutationInvariantEmbedding,
10-
)
11-
from sbi.neural_nets.estimators import ConditionalDensityEstimator, NFlowsFlow
121
from sbi.neural_nets.factory import (
132
classifier_nn,
143
flowmatching_nn,
154
likelihood_nn,
165
posterior_nn,
6+
posterior_score_nn,
177
)
18-
from sbi.neural_nets.flow import (
19-
build_made,
20-
build_maf,
21-
build_maf_rqs,
22-
build_nsf,
23-
build_zuko_maf,
24-
)
25-
from sbi.neural_nets.mdn import build_mdn
26-
from sbi.neural_nets.mnle import MixedDensityEstimator, build_mnle
8+
9+
10+
def __getattr__(name):
11+
if name in ["CNNEmbedding", "FCEmbedding", "PermutationInvariantEmbedding"]:
12+
raise ImportError(
13+
"As of sbi v0.23.0, you have to import embedding networks from "
14+
"`sbi.neural_nets.embedding_nets`. For example, use: "
15+
f"`from sbi.neural_nets.embedding_nets import {name}`"
16+
)
17+
elif name == "classifier_nn":
18+
return classifier_nn
19+
elif name == "flowmatching_nn":
20+
return flowmatching_nn
21+
elif name == "likelihood_nn":
22+
return likelihood_nn
23+
elif name == "posterior_nn":
24+
return posterior_nn
25+
elif name == "posterior_score_nn":
26+
return posterior_score_nn
27+
raise AttributeError(f"Module '{__name__}' has no attribute '{name}'")

0 commit comments

Comments
 (0)