Skip to content

Commit c3e5b9f

Browse files
committed
Fix the SNPE script for iid (multi-trial) observations: gate the DDM-specific `map_x_to_two_D` remapping and the trial-tiling/reshaping behind `isinstance(task, DDM)` / `num_trials > 1` checks, and expose the embedding-net configuration via `trial_net_kwargs` / `perm_net_kwargs` parameters.
1 parent d197e51 commit c3e5b9f

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

sbibm/algorithms/sbi/snpe.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44

55
import torch
66
from sbi import inference as inference
7+
from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
78
from sbi.utils.get_nn_models import posterior_nn
9+
from torch import nn
810

911
from sbibm.algorithms.sbi.utils import (
1012
wrap_posterior,
1113
wrap_prior_dist,
1214
wrap_simulator_fn,
1315
)
16+
from sbibm.tasks.ddm.task import DDM
17+
from sbibm.tasks.ddm.utils import map_x_to_two_D
1418
from sbibm.tasks.task import Task
1519

1620

@@ -30,6 +34,13 @@ def run(
3034
z_score_x: str = "independent",
3135
z_score_theta: str = "independent",
3236
max_num_epochs: Optional[int] = 2**31 - 1,
37+
trial_net_kwargs: Optional[dict] = dict(
38+
input_dim=2,
39+
output_dim=4,
40+
num_hiddens=10,
41+
num_layers=2,
42+
),
43+
perm_net_kwargs: Optional[dict] = {},
3344
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
3445
"""Runs (S)NPE for iid data from `sbi`
3546
Args:
@@ -84,32 +95,16 @@ def run(
8495
simulator = wrap_simulator_fn(simulator, transforms)
8596

8697
# DDM specific.
87-
from torch import nn
88-
from sbibm.tasks.ddm.utils import map_x_to_two_D
89-
from sbi.neural_nets.embedding_nets import (
90-
FCEmbedding,
91-
PermutationInvariantEmbedding,
92-
)
93-
94-
observation = map_x_to_two_D(observation)
98+
if isinstance(task, DDM):
99+
observation = map_x_to_two_D(observation)
95100
num_trials = observation.shape[0]
96101

97102
# embedding net needed?
98103
if num_trials > 1:
99-
single_trial_net = FCEmbedding(
100-
input_dim=2,
101-
output_dim=4,
102-
num_hiddens=10,
103-
num_layers=2,
104-
)
105-
106104
embedding_net = PermutationInvariantEmbedding(
107-
trial_net=single_trial_net,
108-
trial_net_output_dim=4,
109-
combining_operation="mean",
110-
num_layers=2,
111-
num_hiddens=20,
112-
output_dim=10,
105+
trial_net=FCEmbedding(**trial_net_kwargs),
106+
trial_net_output_dim=trial_net_kwargs["output_dim"],
107+
**perm_net_kwargs,
113108
)
114109
else:
115110
embedding_net = nn.Identity()
@@ -128,14 +123,20 @@ def run(
128123

129124
for _ in range(num_rounds):
130125
theta = proposal.sample((num_simulations_per_round // num_trials,))
131-
# copy theta for iid trials
132-
theta_per_trial = theta.tile(num_trials).reshape(
133-
theta.shape[0] * num_trials, -1
134-
)
135-
x = map_x_to_two_D(simulator(theta_per_trial))
136126

137-
# rearrange to have trials as separate dim
138-
x = x.reshape(num_simulations, num_trials, 2)
127+
if num_trials > 1:
128+
# copy theta for iid trials
129+
theta_per_trial = theta.tile(num_trials).reshape(
130+
theta.shape[0] * num_trials, -1
131+
)
132+
x = simulator(theta_per_trial)
133+
if isinstance(task, DDM):
134+
x = map_x_to_two_D(x)
135+
136+
# rearrange to have trials as separate dim
137+
x = x.reshape(num_simulations, num_trials, 2)
138+
else:
139+
x = simulator(theta)
139140

140141
density_estimator = inference_method.append_simulations(
141142
theta, x, proposal=proposal

0 commit comments

Comments
 (0)