8
8
import pytest
9
9
import torch
10
10
11
- from sbi .inference .trainers .base import MaskedNeuralInference
12
11
from sbi .neural_nets .embedding_nets import CNNEmbedding
13
12
from sbi .neural_nets .net_builders import (
14
13
build_flow_matching_estimator ,
@@ -204,13 +203,13 @@ def _build_vector_field_estimator_and_tensors(
204
203
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
205
204
@pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,), (3 , 5 ), (3 , 1 )))
206
205
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
207
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
206
+ @pytest .mark .parametrize ("net " , ["simformer" ])
208
207
def test_masked_vector_field_estimator_loss_shapes (
209
208
sde_type ,
210
209
input_sample_dim ,
211
210
input_event_shape ,
212
211
batch_dim ,
213
- score_net ,
212
+ net ,
214
213
):
215
214
"""Test whether `loss` of MaskedScoreEstimator follows the shape convention."""
216
215
(
@@ -223,7 +222,7 @@ def test_masked_vector_field_estimator_loss_shapes(
223
222
input_event_shape ,
224
223
batch_dim ,
225
224
input_sample_dim ,
226
- net = score_net ,
225
+ net = net ,
227
226
)
228
227
229
228
losses = score_estimator .loss (
@@ -235,22 +234,22 @@ def test_masked_vector_field_estimator_loss_shapes(
235
234
@pytest .mark .gpu
236
235
@pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
237
236
@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
238
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
239
- def test_masked_vector_field_estimator_on_device (sde_type , device , score_net ):
237
+ @pytest .mark .parametrize ("net " , ["simformer" ])
238
+ def test_masked_vector_field_estimator_on_device (sde_type , device , net ):
240
239
"""Test whether MaskedScoreEstimator can be moved to the device."""
241
240
242
241
if sde_type == "flow" :
243
242
score_estimator = build_masked_flow_matching_estimator (
244
243
torch .randn (100 , 5 , 1 ),
245
244
torch .randn (100 , 5 , 1 ),
246
- net = score_net ,
245
+ net = net ,
247
246
)
248
247
else :
249
248
score_estimator = build_masked_score_matching_estimator (
250
249
torch .randn (100 , 5 , 1 ),
251
250
torch .randn (100 , 5 , 1 ),
252
251
sde_type = sde_type ,
253
- net = score_net ,
252
+ net = net ,
254
253
)
255
254
score_estimator .to (device )
256
255
@@ -272,13 +271,13 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
272
271
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
273
272
@pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,), (3 , 5 ), (3 , 1 )))
274
273
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
275
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
274
+ @pytest .mark .parametrize ("net " , ["simformer" ])
276
275
def test_masked_vector_field_estimator_forward_shapes (
277
276
sde_type ,
278
277
input_sample_dim ,
279
278
input_event_shape ,
280
279
batch_dim ,
281
- score_net ,
280
+ net ,
282
281
):
283
282
"""Test whether `forward` of MaskedScoreEstimator follows the shape convention."""
284
283
(
@@ -291,7 +290,7 @@ def test_masked_vector_field_estimator_forward_shapes(
291
290
input_event_shape ,
292
291
batch_dim ,
293
292
input_sample_dim ,
294
- net = score_net ,
293
+ net = net ,
295
294
)
296
295
# Batched times
297
296
times = torch .rand ((batch_dim ,))
@@ -344,13 +343,8 @@ def _build_masked_vector_field_estimator_and_tensors(
344
343
)
345
344
346
345
inputs = building_inputs [:batch_dim ]
347
- # Generate condition mask: latent indices are 0, observed are 1
348
- latent_idx = torch .arange (num_nodes ) # All nodes are latent by default
349
- observed_idx = torch .tensor ([], dtype = torch .long ) # No observed nodes by default
350
- condition_masks = MaskedNeuralInference .generate_condition_mask_from_idx (
351
- latent_idx = latent_idx ,
352
- observed_idx = observed_idx ,
353
- )
346
+ condition_masks = torch .ones (batch_dim , num_nodes )
347
+ condition_masks [:, 1 :] = 0 # Index 0 is latent
354
348
edge_masks = torch .ones (batch_dim , num_nodes , num_nodes )
355
349
356
350
inputs = inputs .unsqueeze (0 )
@@ -366,17 +360,17 @@ def _build_masked_vector_field_estimator_and_tensors(
366
360
# *** ======== Unmasked Estimator ======== *** #
367
361
368
362
369
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
363
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
370
364
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
371
365
@pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,), (3 , 5 ), (3 , 1 )))
372
366
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
373
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
367
+ @pytest .mark .parametrize ("net " , ["simformer" ])
374
368
def test_unmasked_wrapper_vector_field_estimator_loss_shapes (
375
369
sde_type ,
376
370
input_sample_dim ,
377
371
input_event_shape ,
378
372
batch_dim ,
379
- score_net ,
373
+ net ,
380
374
):
381
375
"""Test whether `loss` of MaskedConditionalVectorFieldEstimatorWrapper
382
376
follows the shape convention."""
@@ -389,35 +383,44 @@ def test_unmasked_wrapper_vector_field_estimator_loss_shapes(
389
383
input_event_shape ,
390
384
batch_dim ,
391
385
input_sample_dim ,
392
- net = score_net ,
386
+ net = net ,
393
387
)
394
388
395
389
with pytest .raises (NotImplementedError ):
396
390
score_estimator .loss (inputs [0 ], condition )
397
391
398
392
399
393
@pytest .mark .gpu
400
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
394
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
401
395
@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
402
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
403
- def test_unmasked_wrapper_vector_field_estimator_on_device (sde_type , device , score_net ):
396
+ @pytest .mark .parametrize ("net " , ["simformer" ])
397
+ def test_unmasked_wrapper_vector_field_estimator_on_device (sde_type , device , net ):
404
398
"""Test whether MaskedConditionalVectorFieldEstimatorWrapper
405
399
can be moved to the device."""
406
400
# Create condition and edge masks
407
401
condition_mask = torch .ones (5 , device = device )
408
402
condition_mask [0 ] = 0 # Index 0 is latent
409
403
edge_mask = torch .ones (5 , 5 , device = device )
410
404
411
- score_estimator = (
412
- build_masked_score_matching_estimator (
413
- torch .randn (100 , 5 , 1 ),
414
- torch .randn (100 , 5 , 1 ),
405
+ building_inputs = torch .randn (100 , 5 )
406
+
407
+ if sde_type == "flow" :
408
+ score_estimator = build_masked_flow_matching_estimator (
409
+ building_inputs ,
410
+ building_inputs , # not used
411
+ net = net ,
412
+ )
413
+ else :
414
+ score_estimator = build_masked_score_matching_estimator (
415
+ building_inputs ,
416
+ building_inputs , # not used
415
417
sde_type = sde_type ,
416
- net = score_net ,
418
+ net = net ,
417
419
)
418
- .to (device )
419
- .build_conditional_vector_field_estimator (condition_mask , edge_mask )
420
- )
420
+
421
+ score_estimator = score_estimator .to (
422
+ device
423
+ ).build_conditional_vector_field_estimator (condition_mask , edge_mask )
421
424
422
425
inputs = torch .randn (100 , 1 , device = device )
423
426
condition = torch .randn (100 , 4 , device = device )
@@ -427,17 +430,17 @@ def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, sco
427
430
assert str (out .device ).split (":" )[0 ] == device , "Output device mismatch."
428
431
429
432
430
- @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
433
+ @pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
431
434
@pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
432
435
@pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,), (3 , 5 ), (3 , 1 )))
433
436
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
434
- @pytest .mark .parametrize ("score_net " , ["simformer" ])
437
+ @pytest .mark .parametrize ("net " , ["simformer" ])
435
438
def test_unmasked_wrapper_vector_field_estimator_forward_shapes (
436
439
sde_type ,
437
440
input_sample_dim ,
438
441
input_event_shape ,
439
442
batch_dim ,
440
- score_net ,
443
+ net ,
441
444
):
442
445
"""Test whether `forward` of MaskedConditionalVectorFieldEstimatorWrapperù
443
446
follow the shape convention."""
@@ -450,7 +453,7 @@ def test_unmasked_wrapper_vector_field_estimator_forward_shapes(
450
453
input_event_shape ,
451
454
batch_dim ,
452
455
input_sample_dim ,
453
- net = score_net ,
456
+ net = net ,
454
457
)
455
458
# Batched times
456
459
times = torch .rand ((batch_dim ,))
@@ -490,7 +493,7 @@ def _build_unmasked_vector_field_estimator_and_tensors(
490
493
** kwargs ,
491
494
)
492
495
493
- # Use the first condition and edge mask for all batches
496
+ # # Use the first condition and edge mask for all batches
494
497
condition_masks = condition_masks [0 ].clone ().detach ()
495
498
edge_masks = edge_masks [0 ].clone ().detach ()
496
499
@@ -504,8 +507,13 @@ def _build_unmasked_vector_field_estimator_and_tensors(
504
507
latent_idx = (condition_masks == 0 ).squeeze ()
505
508
observed_idx = (condition_masks == 1 ).squeeze ()
506
509
507
- untangled_inputs = inputs [:, :, latent_idx , :] # (B, num_latent, F)
508
- untangled_condition = inputs [0 , :, observed_idx , :] # (B, num_observed, F)
510
+ # Handle inputs with different number of dimensions
511
+ if len (input_event_shape ) == 1 :
512
+ untangled_inputs = inputs [:, :, latent_idx ] # (S, B, num_latent)
513
+ untangled_condition = inputs [0 , :, observed_idx ] # (B, num_observed)
514
+ else :
515
+ untangled_inputs = inputs [:, :, latent_idx , :] # (S, B, num_latent, F)
516
+ untangled_condition = inputs [0 , :, observed_idx , :] # (B, num_observed, F)
509
517
510
518
return (
511
519
score_estimator ,
0 commit comments