Skip to content

Commit 4899bd1

Browse files
committed
fix iid snpe script.
1 parent d197e51 commit 4899bd1

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

sbibm/algorithms/sbi/snpe.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
import torch
66
from sbi import inference as inference
7+
from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
78
from sbi.utils.get_nn_models import posterior_nn
9+
from torch import nn
810

911
from sbibm.algorithms.sbi.utils import (
1012
wrap_posterior,
1113
wrap_prior_dist,
1214
wrap_simulator_fn,
1315
)
16+
from sbibm.tasks.ddm.task import DDM
17+
from sbibm.tasks.ddm.utils import map_x_to_two_D
1418
from sbibm.tasks.task import Task
1519

1620

@@ -30,6 +34,18 @@ def run(
3034
z_score_x: str = "independent",
3135
z_score_theta: str = "independent",
3236
max_num_epochs: Optional[int] = 2**31 - 1,
37+
trial_net_kwargs: Optional[dict] = dict(
38+
input_dim=2,
39+
output_dim=4,
40+
num_hiddens=10,
41+
num_layers=2,
42+
),
43+
perm_net_kwargs: Optional[dict] = dict(
44+
combining_operation="mean",
45+
num_layers=2,
46+
num_hiddens=40,
47+
output_dim=20,
48+
),
3349
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
3450
"""Runs (S)NPE for iid data from `sbi`
3551
Args:
@@ -84,32 +100,16 @@ def run(
84100
simulator = wrap_simulator_fn(simulator, transforms)
85101

86102
# DDM specific.
87-
from torch import nn
88-
from sbibm.tasks.ddm.utils import map_x_to_two_D
89-
from sbi.neural_nets.embedding_nets import (
90-
FCEmbedding,
91-
PermutationInvariantEmbedding,
92-
)
93-
94-
observation = map_x_to_two_D(observation)
103+
if isinstance(task, DDM):
104+
observation = map_x_to_two_D(observation)
95105
num_trials = observation.shape[0]
96106

97107
# embedding net needed?
98108
if num_trials > 1:
99-
single_trial_net = FCEmbedding(
100-
input_dim=2,
101-
output_dim=4,
102-
num_hiddens=10,
103-
num_layers=2,
104-
)
105-
106109
embedding_net = PermutationInvariantEmbedding(
107-
trial_net=single_trial_net,
108-
trial_net_output_dim=4,
109-
combining_operation="mean",
110-
num_layers=2,
111-
num_hiddens=20,
112-
output_dim=10,
110+
trial_net=FCEmbedding(**trial_net_kwargs),
111+
trial_net_output_dim=trial_net_kwargs["output_dim"],
112+
**perm_net_kwargs,
113113
)
114114
else:
115115
embedding_net = nn.Identity()
@@ -128,14 +128,20 @@ def run(
128128

129129
for _ in range(num_rounds):
130130
theta = proposal.sample((num_simulations_per_round // num_trials,))
131-
# copy theta for iid trials
132-
theta_per_trial = theta.tile(num_trials).reshape(
133-
theta.shape[0] * num_trials, -1
134-
)
135-
x = map_x_to_two_D(simulator(theta_per_trial))
136131

137-
# rearrange to have trials as separate dim
138-
x = x.reshape(num_simulations, num_trials, 2)
132+
if num_trials > 1:
133+
# copy theta for iid trials
134+
theta_per_trial = theta.tile(num_trials).reshape(
135+
theta.shape[0] * num_trials, -1
136+
)
137+
x = simulator(theta_per_trial)
138+
if isinstance(task, DDM):
139+
x = map_x_to_two_D(x)
140+
141+
# rearrange to have trials as separate dim
142+
x = x.reshape(num_simulations_per_round // num_trials, num_trials, 2)
143+
else:
144+
x = simulator(theta)
139145

140146
density_estimator = inference_method.append_simulations(
141147
theta, x, proposal=proposal

0 commit comments

Comments
 (0)