4
4
5
5
import torch
6
6
from sbi import inference as inference
7
+ from sbi .neural_nets .embedding_nets import FCEmbedding , PermutationInvariantEmbedding
7
8
from sbi .utils .get_nn_models import posterior_nn
9
+ from torch import nn
8
10
9
11
from sbibm .algorithms .sbi .utils import (
10
12
wrap_posterior ,
11
13
wrap_prior_dist ,
12
14
wrap_simulator_fn ,
13
15
)
16
+ from sbibm .tasks .ddm .task import DDM
17
+ from sbibm .tasks .ddm .utils import map_x_to_two_D
14
18
from sbibm .tasks .task import Task
15
19
16
20
@@ -30,6 +34,13 @@ def run(
30
34
z_score_x : str = "independent" ,
31
35
z_score_theta : str = "independent" ,
32
36
max_num_epochs : Optional [int ] = 2 ** 31 - 1 ,
37
+ trial_net_kwargs : Optional [dict ] = dict (
38
+ input_dim = 2 ,
39
+ output_dim = 4 ,
40
+ num_hiddens = 10 ,
41
+ num_layers = 2 ,
42
+ ),
43
+ perm_net_kwargs : Optional [dict ] = {},
33
44
) -> Tuple [torch .Tensor , int , Optional [torch .Tensor ]]:
34
45
"""Runs (S)NPE for iid data from `sbi`
35
46
Args:
@@ -84,32 +95,16 @@ def run(
84
95
simulator = wrap_simulator_fn (simulator , transforms )
85
96
86
97
# DDM specific.
87
- from torch import nn
88
- from sbibm .tasks .ddm .utils import map_x_to_two_D
89
- from sbi .neural_nets .embedding_nets import (
90
- FCEmbedding ,
91
- PermutationInvariantEmbedding ,
92
- )
93
-
94
- observation = map_x_to_two_D (observation )
98
+ if isinstance (task , DDM ):
99
+ observation = map_x_to_two_D (observation )
95
100
num_trials = observation .shape [0 ]
96
101
97
102
# embedding net needed?
98
103
if num_trials > 1 :
99
- single_trial_net = FCEmbedding (
100
- input_dim = 2 ,
101
- output_dim = 4 ,
102
- num_hiddens = 10 ,
103
- num_layers = 2 ,
104
- )
105
-
106
104
embedding_net = PermutationInvariantEmbedding (
107
- trial_net = single_trial_net ,
108
- trial_net_output_dim = 4 ,
109
- combining_operation = "mean" ,
110
- num_layers = 2 ,
111
- num_hiddens = 20 ,
112
- output_dim = 10 ,
105
+ trial_net = FCEmbedding (** trial_net_kwargs ),
106
+ trial_net_output_dim = trial_net_kwargs ["output_dim" ],
107
+ ** perm_net_kwargs ,
113
108
)
114
109
else :
115
110
embedding_net = nn .Identity ()
@@ -128,14 +123,20 @@ def run(
128
123
129
124
for _ in range (num_rounds ):
130
125
theta = proposal .sample ((num_simulations_per_round // num_trials ,))
131
- # copy theta for iid trials
132
- theta_per_trial = theta .tile (num_trials ).reshape (
133
- theta .shape [0 ] * num_trials , - 1
134
- )
135
- x = map_x_to_two_D (simulator (theta_per_trial ))
136
126
137
- # rearrange to have trials as separate dim
138
- x = x .reshape (num_simulations , num_trials , 2 )
127
+ if num_trials > 1 :
128
+ # copy theta for iid trials
129
+ theta_per_trial = theta .tile (num_trials ).reshape (
130
+ theta .shape [0 ] * num_trials , - 1
131
+ )
132
+ x = simulator (theta_per_trial )
133
+ if isinstance (task , DDM ):
134
+ x = map_x_to_two_D (x )
135
+
136
+ # rearrange to have trials as separate dim
137
+ x = x .reshape (num_simulations , num_trials , 2 )
138
+ else :
139
+ x = simulator (theta )
139
140
140
141
density_estimator = inference_method .append_simulations (
141
142
theta , x , proposal = proposal
0 commit comments