Skip to content

Commit 8b41b1a

Browse files
committed
test: move estimator builder test to density_estimator_test file
1 parent e6d1b77 commit 8b41b1a

File tree

2 files changed

+109
-112
lines changed

2 files changed

+109
-112
lines changed

tests/density_estimator_test.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
from __future__ import annotations
55

6-
from typing import Callable, Tuple
6+
from typing import Callable, Dict, Tuple
77

88
import pytest
99
import torch
10-
from torch import eye, zeros
10+
from torch import Tensor, eye, zeros
1111
from torch.distributions import HalfNormal, MultivariateNormal
1212

13+
from sbi.inference import NLE, NPE, NRE
14+
from sbi.inference.trainers.base import NeuralInference
1315
from sbi.neural_nets.embedding_nets import CNNEmbedding
1416
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
1517
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
@@ -36,6 +38,7 @@
3638
build_zuko_unaf,
3739
)
3840
from sbi.neural_nets.net_builders.flow import build_zuko_flow
41+
from sbi.neural_nets.ratio_estimators import RatioEstimator
3942
from sbi.utils.torchutils import BoxUniform
4043

4144
# List of all density estimator builders for testing.
@@ -512,3 +515,106 @@ def test_build_zuko_flow_missing_x_dist_raises_error(which_nf):
512515
z_score_y="transform_to_unconstrained",
513516
x_dist=None, # No distribution provided
514517
)
518+
519+
520+
def build_classifier(theta, x):
521+
net = torch.nn.Linear(theta.shape[1] + x.shape[1], 1)
522+
return RatioEstimator(net=net, theta_shape=theta[0].shape, x_shape=x[0].shape)
523+
524+
525+
def build_estimator(theta, x):
526+
return build_mdn(theta, x)
527+
528+
529+
def build_estimator_missing_args():
530+
pass
531+
532+
533+
def build_estimator_missing_return(theta: Tensor, x: Tensor):
534+
pass
535+
536+
537+
@pytest.mark.parametrize(
538+
("params", "trainer_class"),
539+
[
540+
# Valid builders
541+
pytest.param(dict(classifier=build_classifier), NRE),
542+
pytest.param(dict(density_estimator=build_estimator), NPE),
543+
pytest.param(dict(density_estimator=build_estimator), NLE),
544+
# Invalid builders
545+
pytest.param(
546+
dict(classifier=build_estimator_missing_args),
547+
NRE,
548+
marks=pytest.mark.xfail(
549+
raises=TypeError,
550+
reason="Missing required parameters in classifier builder.",
551+
),
552+
),
553+
pytest.param(
554+
dict(density_estimator=build_estimator_missing_args),
555+
NPE,
556+
marks=pytest.mark.xfail(
557+
raises=TypeError,
558+
reason="Missing required parameters in density estimator builder.",
559+
),
560+
),
561+
pytest.param(
562+
dict(density_estimator=build_estimator_missing_args),
563+
NLE,
564+
marks=pytest.mark.xfail(
565+
raises=TypeError,
566+
reason="Missing required parameters in density estimator builder.",
567+
),
568+
),
569+
pytest.param(
570+
dict(classifier=build_estimator_missing_return),
571+
NRE,
572+
marks=pytest.mark.xfail(
573+
raises=AttributeError,
574+
reason="Missing return of RatioEstimator in classifier builder.",
575+
),
576+
),
577+
pytest.param(
578+
dict(density_estimator=build_estimator_missing_return),
579+
NPE,
580+
marks=pytest.mark.xfail(
581+
raises=AttributeError,
582+
reason="Missing return of type ConditionalEstimator"
583+
" in density estimator builder.",
584+
),
585+
),
586+
pytest.param(
587+
dict(density_estimator=build_estimator_missing_return),
588+
NLE,
589+
marks=pytest.mark.xfail(
590+
raises=AttributeError,
591+
reason="Missing return of type ConditionalEstimator"
592+
" in density estimator builder.",
593+
),
594+
),
595+
],
596+
)
597+
def test_trainers_with_valid_and_invalid_estimator_builders(
598+
params: Dict, trainer_class: type[NeuralInference]
599+
):
600+
"""
601+
Test trainers classes work with valid classifier builders and fail
602+
with invalid ones.
603+
604+
Args:
605+
params: Parameters passed to the trainer class.
606+
trainer_class: Trainer classes.
607+
"""
608+
609+
def simulator(theta):
610+
return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
611+
612+
num_dim = 3
613+
prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
614+
theta = prior.sample((300,))
615+
x = simulator(theta)
616+
617+
inference = trainer_class(**params)
618+
inference.append_simulations(theta, x)
619+
620+
inference.train(max_num_epochs=1)

tests/ratio_estimator_test.py

Lines changed: 1 addition & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,14 @@
33

44
from __future__ import annotations
55

6-
from typing import Dict
7-
86
import pytest
97
import torch
10-
from torch import Tensor, eye, zeros
8+
from torch import eye, zeros
119
from torch.distributions import MultivariateNormal
1210

13-
from sbi.inference import NLE, NPE, NRE
14-
from sbi.inference.trainers.base import NeuralInference
1511
from sbi.neural_nets.embedding_nets import CNNEmbedding
1612
from sbi.neural_nets.net_builders import build_linear_classifier
17-
from sbi.neural_nets.net_builders.mdn import build_mdn
1813
from sbi.neural_nets.ratio_estimators import RatioEstimator
19-
from sbi.utils.torchutils import BoxUniform
2014

2115

2216
class EmbeddingNet(torch.nn.Module):
@@ -78,106 +72,3 @@ def test_api_ratio_estimator(ratio_estimator, theta_shape, x_shape):
7872
nsamples,
7973
), f"""unnormalized_log_ratio shape is not correct. It is of shape
8074
{unnormalized_log_ratio.shape}, but should be {(nsamples,)}"""
81-
82-
83-
def build_classifier(theta, x):
84-
net = torch.nn.Linear(theta.shape[1] + x.shape[1], 1)
85-
return RatioEstimator(net=net, theta_shape=theta[0].shape, x_shape=x[0].shape)
86-
87-
88-
def build_estimator(theta, x):
89-
return build_mdn(theta, x)
90-
91-
92-
def build_estimator_missing_args():
93-
pass
94-
95-
96-
def build_estimator_missing_return(theta: Tensor, x: Tensor):
97-
pass
98-
99-
100-
@pytest.mark.parametrize(
101-
("params", "trainer_class"),
102-
[
103-
# Valid builders
104-
pytest.param(dict(classifier=build_classifier), NRE),
105-
pytest.param(dict(density_estimator=build_estimator), NPE),
106-
pytest.param(dict(density_estimator=build_estimator), NLE),
107-
# Invalid builders
108-
pytest.param(
109-
dict(classifier=build_estimator_missing_args),
110-
NRE,
111-
marks=pytest.mark.xfail(
112-
raises=TypeError,
113-
reason="Missing required parameters in classifier builder.",
114-
),
115-
),
116-
pytest.param(
117-
dict(density_estimator=build_estimator_missing_args),
118-
NPE,
119-
marks=pytest.mark.xfail(
120-
raises=TypeError,
121-
reason="Missing required parameters in density estimator builder.",
122-
),
123-
),
124-
pytest.param(
125-
dict(density_estimator=build_estimator_missing_args),
126-
NLE,
127-
marks=pytest.mark.xfail(
128-
raises=TypeError,
129-
reason="Missing required parameters in density estimator builder.",
130-
),
131-
),
132-
pytest.param(
133-
dict(classifier=build_estimator_missing_return),
134-
NRE,
135-
marks=pytest.mark.xfail(
136-
raises=AttributeError,
137-
reason="Missing return of RatioEstimator in classifier builder.",
138-
),
139-
),
140-
pytest.param(
141-
dict(density_estimator=build_estimator_missing_return),
142-
NPE,
143-
marks=pytest.mark.xfail(
144-
raises=AttributeError,
145-
reason="Missing return of type ConditionalEstimator"
146-
" in density estimator builder.",
147-
),
148-
),
149-
pytest.param(
150-
dict(density_estimator=build_estimator_missing_return),
151-
NLE,
152-
marks=pytest.mark.xfail(
153-
raises=AttributeError,
154-
reason="Missing return of type ConditionalEstimator"
155-
" in density estimator builder.",
156-
),
157-
),
158-
],
159-
)
160-
def test_trainers_with_valid_and_invalid_estimator_builders(
161-
params: Dict, trainer_class: type[NeuralInference]
162-
):
163-
"""
164-
Test trainers classes work with valid classifier builders and fail
165-
with invalid ones.
166-
167-
Args:
168-
params: Parameters passed to the trainer class.
169-
trainer_class: Trainer classes.
170-
"""
171-
172-
def simulator(theta):
173-
return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
174-
175-
num_dim = 3
176-
prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
177-
theta = prior.sample((300,))
178-
x = simulator(theta)
179-
180-
inference = trainer_class(**params)
181-
inference.append_simulations(theta, x)
182-
183-
inference.train(max_num_epochs=1)

0 commit comments

Comments
 (0)