8
8
import pytest
9
9
import torch
10
10
11
+ from sbi .inference .trainers .base import MaskedNeuralInference
11
12
from sbi .neural_nets .embedding_nets import CNNEmbedding
12
13
from sbi .neural_nets .net_builders import (
13
14
build_flow_matching_estimator ,
@@ -200,8 +201,8 @@ def _build_vector_field_estimator_and_tensors(
200
201
201
202
202
203
@pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
203
- @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
204
- @pytest .mark .parametrize ("input_event_shape" , ((3 , 5 ), (3 , 1 )))
204
+ @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
205
+ @pytest .mark .parametrize ("input_event_shape" , ((1 ,), ( 4 ,), ( 3 , 5 ), (3 , 1 )))
205
206
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
206
207
@pytest .mark .parametrize ("score_net" , ["simformer" ])
207
208
def test_masked_vector_field_estimator_loss_shapes (
@@ -268,8 +269,8 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
268
269
269
270
270
271
@pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" , "flow" ])
271
- @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
272
- @pytest .mark .parametrize ("input_event_shape" , ((5 , 1 ), (5 , 4 )))
272
+ @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
273
+ @pytest .mark .parametrize ("input_event_shape" , ((1 ,), ( 4 , ), (3 , 5 ), ( 3 , 1 )))
273
274
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
274
275
@pytest .mark .parametrize ("score_net" , ["simformer" ])
275
276
def test_masked_vector_field_estimator_forward_shapes (
@@ -324,8 +325,9 @@ def _build_masked_vector_field_estimator_and_tensors(
324
325
Helper function for all tests that deal with shapes of masked score estimators.
325
326
"""
326
327
327
- num_nodes , num_features = input_event_shape
328
- building_inputs = torch .randn ((batch_dim , num_nodes , num_features ))
328
+ num_nodes = input_event_shape [0 ]
329
+
330
+ building_inputs = torch .randn ((batch_dim , * input_event_shape ))
329
331
330
332
if sde_type == "flow" :
331
333
score_estimator = build_masked_flow_matching_estimator (
@@ -342,9 +344,13 @@ def _build_masked_vector_field_estimator_and_tensors(
342
344
)
343
345
344
346
inputs = building_inputs [:batch_dim ]
345
- condition_masks = torch .bernoulli (torch .rand (batch_dim , num_nodes ))
346
- condition_masks [:, 0 ] = 0 # Force at least one variable to be latent
347
- condition_masks [:, 1 ] = 1 # Force at least one variable to be observed
347
+ # Generate condition mask: latent indices are 0, observed are 1
348
+ latent_idx = torch .arange (num_nodes ) # All nodes are latent by default
349
+ observed_idx = torch .tensor ([], dtype = torch .long ) # No observed nodes by default
350
+ condition_masks = MaskedNeuralInference .generate_condition_mask_from_idx (
351
+ latent_idx = latent_idx ,
352
+ observed_idx = observed_idx ,
353
+ )
348
354
edge_masks = torch .ones (batch_dim , num_nodes , num_nodes )
349
355
350
356
inputs = inputs .unsqueeze (0 )
@@ -361,8 +367,8 @@ def _build_masked_vector_field_estimator_and_tensors(
361
367
362
368
363
369
@pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
364
- @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
365
- @pytest .mark .parametrize ("input_event_shape" , ((3 , 5 ), (3 , 1 )))
370
+ @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
371
+ @pytest .mark .parametrize ("input_event_shape" , ((1 ,), ( 4 ,), ( 3 , 5 ), (3 , 1 )))
366
372
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
367
373
@pytest .mark .parametrize ("score_net" , ["simformer" ])
368
374
def test_unmasked_wrapper_vector_field_estimator_loss_shapes (
@@ -422,8 +428,8 @@ def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, sco
422
428
423
429
424
430
@pytest .mark .parametrize ("sde_type" , ["ve" , "vp" , "subvp" ])
425
- @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 ))
426
- @pytest .mark .parametrize ("input_event_shape" , ((3 , 5 ), (3 , 1 )))
431
+ @pytest .mark .parametrize ("input_sample_dim" , (1 , 2 , 3 ))
432
+ @pytest .mark .parametrize ("input_event_shape" , ((1 ,), ( 4 ,), ( 3 , 5 ), (3 , 1 )))
427
433
@pytest .mark .parametrize ("batch_dim" , (1 , 10 ))
428
434
@pytest .mark .parametrize ("score_net" , ["simformer" ])
429
435
def test_unmasked_wrapper_vector_field_estimator_forward_shapes (
0 commit comments