
Commit ef60b91

update ddm inference test.
1 parent 1900d61 commit ef60b91

File tree

1 file changed (+119, -10 lines)

tests/tasks/ddm/test_ddm_task.py

Lines changed: 119 additions & 10 deletions
@@ -1,10 +1,21 @@
 import sbibm
-
-from sbibm.algorithms.sbi.snre import run as run_snre
-from sbibm.algorithms.sbi.snpe import run as run_snpe
+import torch
 
 from sbibm.metrics.c2st import c2st
 
+from sbi.inference import SNRE, SNPE, MNLE
+
+from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
+from sbi.utils import posterior_nn
+
+
+mcmc_parameters = dict(
+    num_chains=50,
+    thin=10,
+    warmup_steps=50,
+    init_strategy="proposal",
+)
+
 
 def test_loading_ddm_task():
     sbibm.get_task("ddm")
@@ -17,22 +28,120 @@ def test_simulation_ddm_task():
     simulator(prior(1))
 
 
+def map_x_to_two_D(x):
+    x = x.squeeze()
+    x_2d = torch.zeros(x.shape[0], 2)
+    x_2d[:, 0] = x.abs()
+    x_2d[x >= 0, 1] = 1
+
+    return x_2d
+
+
 def test_inference_with_nre():
     task = sbibm.get_task("ddm")
-    num_observation = 1
+    num_observation = 101
+    num_simulations = 10000
+    num_samples = 1000
+    x_o = map_x_to_two_D(task.get_observation(num_observation))
+
+    prior = task.get_prior_dist()
+    simulator = task.get_simulator()
+
+    theta = prior.sample((num_simulations,))
+    x = map_x_to_two_D(simulator(theta))
+
+    trainer = SNRE(prior)
+    trainer.append_simulations(theta, x).train()
+    posterior = trainer.build_posterior(
+        mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_parameters
+    )
+    samples = posterior.sample((num_samples,), x=x_o)
+
+    reference_samples = task.get_reference_posterior_samples(num_observation)[
+        :num_samples
+    ]
+    score = c2st(reference_samples, samples)
+    print(score)
+    assert score <= 0.6, f"score={score} must be below 0.6"
+
+
+def test_inference_with_mnle():
+    task = sbibm.get_task("ddm")
+    num_observation = 101
+    num_simulations = 10000
+    num_samples = 1000
+    x_o = map_x_to_two_D(task.get_observation(num_observation))
+
+    prior = task.get_prior_dist()
+    simulator = task.get_simulator()
+
+    theta = prior.sample((num_simulations,))
+    x = map_x_to_two_D(simulator(theta))
+
+    trainer = MNLE(prior)
+    trainer.append_simulations(theta, x).train()
+    posterior = trainer.build_posterior(
+        mcmc_method="slice_np_vectorized", mcmc_parameters=mcmc_parameters
+    )
+    samples = posterior.sample((num_samples,), x=x_o)
+
+    reference_samples = task.get_reference_posterior_samples(num_observation)[
+        :num_samples
+    ]
+    score = c2st(reference_samples, samples)
+    print(score)
+    assert score <= 0.6, f"score={score} must be below 0.6"
+
+
+def test_inference_with_npe():
+    task = sbibm.get_task("ddm")
+    num_observation = 101
     num_simulations = 10000
     num_samples = 1000
+    x_o = map_x_to_two_D(task.get_observation(num_observation))
+    num_trials = x_o.shape[0]
+
+    prior = task.get_prior_dist()
+    simulator = task.get_simulator()
+
+    theta = prior.sample((num_simulations,))
+
+    theta = prior.sample((num_simulations,))
+    # copy theta for iid trials
+    theta_per_trial = theta.tile(num_trials).reshape(num_simulations * num_trials, -1)
+    x = map_x_to_two_D(simulator(theta_per_trial))
 
-    samples, num_simulations, _ = run_snre(
-        task,
-        num_samples=num_samples,
-        num_simulations=num_simulations,
-        num_observation=num_observation,
-        num_rounds=1,
+    # rearrange to have trials as separate dim
+    x = x.reshape(num_simulations, num_trials, 2)
+
+    single_trial_net = FCEmbedding(
+        input_dim=2,
+        output_dim=4,
+        num_hiddens=10,
+        num_layers=2,
+    )
+
+    embedding_net = PermutationInvariantEmbedding(
+        trial_net=single_trial_net,
+        trial_net_output_dim=4,
+        combining_operation="mean",
+        num_layers=2,
+        num_hiddens=20,
+        output_dim=10,
     )
 
+    de_provider = posterior_nn(
+        model="mdn", num_components=4, embedding_net=embedding_net
+    )
+
+    trainer = SNPE(prior, density_estimator=de_provider).append_simulations(theta, x)
+    trainer.train()
+    posterior = trainer.build_posterior()
+    samples = posterior.sample((num_samples,), x=x_o)
+
     reference_samples = task.get_reference_posterior_samples(num_observation)[
         :num_samples
     ]
     score = c2st(reference_samples, samples)
+    print(score)
     assert score <= 0.6, f"score={score} must be below 0.6"
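All three tests use the map_x_to_two_D helper added above: it recodes each 1D observation into two columns, the absolute value in column 0 and a binary indicator in column 1 that is 1 where the raw value is non-negative. Assuming, as is common for DDM simulators, that the raw output is a signed reaction time whose sign encodes the choice, a minimal standalone sketch of the encoding is:

    import torch

    def map_x_to_two_D(x):
        # Recode signed 1D observations as (magnitude, binary indicator) pairs.
        x = x.squeeze()
        x_2d = torch.zeros(x.shape[0], 2)
        x_2d[:, 0] = x.abs()     # column 0: absolute value, e.g. the reaction time
        x_2d[x >= 0, 1] = 1      # column 1: 1 where the sign is non-negative, e.g. the choice
        return x_2d

    x = torch.tensor([-0.8, 1.2, 0.3])
    print(map_x_to_two_D(x))
    # tensor([[0.8000, 0.0000],
    #         [1.2000, 1.0000],
    #         [0.3000, 1.0000]])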

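In test_inference_with_npe, each sampled parameter vector is repeated once per observed trial before simulating, and the simulated 2D observations are then reshaped so the iid trials form their own dimension. A shape-only sketch of that layout, with made-up sizes and a stand-in parameter dimension, is:

    import torch

    num_simulations, num_trials = 4, 3       # illustrative sizes only
    theta = torch.randn(num_simulations, 2)  # stand-in parameters, dimension chosen for the sketch

    # Repeat every theta once per trial, as done for theta_per_trial in the test.
    theta_per_trial = theta.tile(num_trials).reshape(num_simulations * num_trials, -1)
    print(theta_per_trial.shape)             # torch.Size([12, 2])

    # One 2D observation per (theta, trial) pair, then group the trials per simulation.
    x_flat = torch.randn(num_simulations * num_trials, 2)
    x = x_flat.reshape(num_simulations, num_trials, 2)
    print(x.shape)                           # torch.Size([4, 3, 2])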
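The NPE test summarizes the set of trials with sbi's PermutationInvariantEmbedding using combining_operation="mean", i.e. the per-trial features are averaged so the learned summary does not depend on the order of the iid trials. As a conceptual illustration of why mean pooling gives that invariance (a toy module in plain torch, not sbi's implementation, with arbitrary layer sizes):

    import torch
    import torch.nn as nn

    # Toy sketch: embed each trial with a small trial net, average the per-trial
    # embeddings (so trial order cannot matter), and map the pooled vector to the summary.
    class ToyPermutationInvariantEmbedding(nn.Module):
        def __init__(self, trial_dim=2, trial_emb_dim=4, out_dim=10):
            super().__init__()
            self.trial_net = nn.Sequential(
                nn.Linear(trial_dim, 10), nn.ReLU(), nn.Linear(10, trial_emb_dim)
            )
            self.head = nn.Sequential(
                nn.Linear(trial_emb_dim, 20), nn.ReLU(), nn.Linear(20, out_dim)
            )

        def forward(self, x):                    # x: (batch, num_trials, trial_dim)
            per_trial = self.trial_net(x)        # (batch, num_trials, trial_emb_dim)
            pooled = per_trial.mean(dim=1)       # mean over trials -> permutation invariant
            return self.head(pooled)             # (batch, out_dim)

    net = ToyPermutationInvariantEmbedding()
    x = torch.randn(5, 7, 2)                     # 5 simulations, 7 iid trials, 2 features each
    shuffled = x[:, torch.randperm(7)]           # reorder the trials
    assert torch.allclose(net(x), net(shuffled), atol=1e-5)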