11
11
from sbi .neural_nets .embedding_nets import CNNEmbedding
12
12
from sbi .neural_nets .net_builders import (
13
13
build_flow_matching_estimator ,
14
+ build_masked_flow_matching_estimator ,
14
15
build_masked_score_matching_estimator ,
15
16
build_score_matching_estimator ,
16
17
)
@@ -198,7 +199,7 @@ def _build_vector_field_estimator_and_tensors(
198
199
# *** ======== Masked Estimator ======== *** #
199
200
200
201
201
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
202
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
202
203
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
203
204
@pytest .mark .parametrize ("input_event_shape" , ((3 , 5 ), (3 , 1 )))
204
205
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
@@ -231,17 +232,25 @@ def test_masked_vector_field_estimator_loss_shapes(
231
232
232
233
233
234
@pytest .mark .gpu
234
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
235
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
235
236
@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
236
237
@pytest .mark .parametrize ("score_net" , ["simformer" ])
237
238
def test_masked_vector_field_estimator_on_device (sde_type , device , score_net ):
238
239
"""Test whether MaskedScoreEstimator can be moved to the device."""
239
- score_estimator = build_masked_score_matching_estimator (
240
- torch .randn (100 , 5 , 1 ),
241
- torch .randn (100 , 5 , 1 ),
242
- sde_type = sde_type ,
243
- net = score_net ,
244
- )
240
+
241
+ if sde_type == "flow" :
242
+ score_estimator = build_masked_flow_matching_estimator (
243
+ torch .randn (100 , 5 , 1 ),
244
+ torch .randn (100 , 5 , 1 ),
245
+ net = score_net ,
246
+ )
247
+ else :
248
+ score_estimator = build_masked_score_matching_estimator (
249
+ torch .randn (100 , 5 , 1 ),
250
+ torch .randn (100 , 5 , 1 ),
251
+ sde_type = sde_type ,
252
+ net = score_net ,
253
+ )
245
254
score_estimator .to (device )
246
255
247
256
# Test forward
@@ -258,7 +267,7 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
258
267
assert str (loss .device ).split (":" )[0 ] == device , "Loss device mismatch."
259
268
260
269
261
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
270
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
262
271
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
263
272
@pytest .mark .parametrize ("input_event_shape" , ((5 , 1 ), (5 , 4 )))
264
273
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
@@ -318,12 +327,19 @@ def _build_masked_vector_field_estimator_and_tensors(
318
327
num_nodes , num_features = input_event_shape
319
328
building_inputs = torch .randn ((batch_dim , num_nodes , num_features ))
320
329
321
- score_estimator = build_masked_score_matching_estimator (
322
- building_inputs ,
323
- building_inputs , # not used
324
- sde_type = sde_type ,
325
- ** kwargs ,
326
- )
330
+ if sde_type == "flow" :
331
+ score_estimator = build_masked_flow_matching_estimator (
332
+ building_inputs ,
333
+ building_inputs , # not used
334
+ ** kwargs ,
335
+ )
336
+ else :
337
+ score_estimator = build_masked_score_matching_estimator (
338
+ building_inputs ,
339
+ building_inputs , # not used
340
+ sde_type = sde_type ,
341
+ ** kwargs ,
342
+ )
327
343
328
344
inputs = building_inputs [:batch_dim ]
329
345
condition_masks = torch .bernoulli (torch .rand (batch_dim , num_nodes ))
0 commit comments