4
4
5
5
import torch
6
6
from sbi import inference as inference
7
+ from sbi .neural_nets .embedding_nets import FCEmbedding , PermutationInvariantEmbedding
7
8
from sbi .utils .get_nn_models import posterior_nn
9
+ from torch import nn
8
10
9
11
from sbibm .algorithms .sbi .utils import (
10
12
wrap_posterior ,
11
13
wrap_prior_dist ,
12
14
wrap_simulator_fn ,
13
15
)
16
+ from sbibm .tasks .ddm .task import DDM
17
+ from sbibm .tasks .ddm .utils import map_x_to_two_D
14
18
from sbibm .tasks .task import Task
15
19
16
20
@@ -30,6 +34,18 @@ def run(
30
34
z_score_x : str = "independent" ,
31
35
z_score_theta : str = "independent" ,
32
36
max_num_epochs : Optional [int ] = 2 ** 31 - 1 ,
37
+ trial_net_kwargs : Optional [dict ] = dict (
38
+ input_dim = 2 ,
39
+ output_dim = 4 ,
40
+ num_hiddens = 10 ,
41
+ num_layers = 2 ,
42
+ ),
43
+ perm_net_kwargs : Optional [dict ] = dict (
44
+ combining_operation = "mean" ,
45
+ num_layers = 2 ,
46
+ num_hiddens = 40 ,
47
+ output_dim = 20 ,
48
+ ),
33
49
) -> Tuple [torch .Tensor , int , Optional [torch .Tensor ]]:
34
50
"""Runs (S)NPE for iid data from `sbi`
35
51
Args:
@@ -84,32 +100,16 @@ def run(
84
100
simulator = wrap_simulator_fn (simulator , transforms )
85
101
86
102
# DDM specific.
87
- from torch import nn
88
- from sbibm .tasks .ddm .utils import map_x_to_two_D
89
- from sbi .neural_nets .embedding_nets import (
90
- FCEmbedding ,
91
- PermutationInvariantEmbedding ,
92
- )
93
-
94
- observation = map_x_to_two_D (observation )
103
+ if isinstance (task , DDM ):
104
+ observation = map_x_to_two_D (observation )
95
105
num_trials = observation .shape [0 ]
96
106
97
107
# embedding net needed?
98
108
if num_trials > 1 :
99
- single_trial_net = FCEmbedding (
100
- input_dim = 2 ,
101
- output_dim = 4 ,
102
- num_hiddens = 10 ,
103
- num_layers = 2 ,
104
- )
105
-
106
109
embedding_net = PermutationInvariantEmbedding (
107
- trial_net = single_trial_net ,
108
- trial_net_output_dim = 4 ,
109
- combining_operation = "mean" ,
110
- num_layers = 2 ,
111
- num_hiddens = 20 ,
112
- output_dim = 10 ,
110
+ trial_net = FCEmbedding (** trial_net_kwargs ),
111
+ trial_net_output_dim = trial_net_kwargs ["output_dim" ],
112
+ ** perm_net_kwargs ,
113
113
)
114
114
else :
115
115
embedding_net = nn .Identity ()
@@ -128,14 +128,20 @@ def run(
128
128
129
129
for _ in range (num_rounds ):
130
130
theta = proposal .sample ((num_simulations_per_round // num_trials ,))
131
- # copy theta for iid trials
132
- theta_per_trial = theta .tile (num_trials ).reshape (
133
- theta .shape [0 ] * num_trials , - 1
134
- )
135
- x = map_x_to_two_D (simulator (theta_per_trial ))
136
131
137
- # rearrange to have trials as separate dim
138
- x = x .reshape (num_simulations , num_trials , 2 )
132
+ if num_trials > 1 :
133
+ # copy theta for iid trials
134
+ theta_per_trial = theta .tile (num_trials ).reshape (
135
+ theta .shape [0 ] * num_trials , - 1
136
+ )
137
+ x = simulator (theta_per_trial )
138
+ if isinstance (task , DDM ):
139
+ x = map_x_to_two_D (x )
140
+
141
+ # rearrange to have trials as separate dim
142
+ x = x .reshape (num_simulations_per_round // num_trials , num_trials , 2 )
143
+ else :
144
+ x = simulator (theta )
139
145
140
146
density_estimator = inference_method .append_simulations (
141
147
theta , x , proposal = proposal
0 commit comments