Skip to content

Commit 22f43d9

Browse files
committed
Added input shape 2-dim tests on vf estimator
1 parent aadda6f commit 22f43d9

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

tests/vf_estimator_test.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
import torch
1010

11+
from sbi.inference.trainers.base import MaskedNeuralInference
1112
from sbi.neural_nets.embedding_nets import CNNEmbedding
1213
from sbi.neural_nets.net_builders import (
1314
build_flow_matching_estimator,
@@ -200,8 +201,8 @@ def _build_vector_field_estimator_and_tensors(
200201

201202

202203
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
203-
@pytest.mark.parametrize("input_sample_dim", (1, 2))
204-
@pytest.mark.parametrize("input_event_shape", ((3, 5), (3, 1)))
204+
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
205+
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
205206
@pytest.mark.parametrize("batch_dim", (1, 10))
206207
@pytest.mark.parametrize("score_net", ["simformer"])
207208
def test_masked_vector_field_estimator_loss_shapes(
@@ -268,8 +269,8 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
268269

269270

270271
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
271-
@pytest.mark.parametrize("input_sample_dim", (1, 2))
272-
@pytest.mark.parametrize("input_event_shape", ((5, 1), (5, 4)))
272+
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
273+
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
273274
@pytest.mark.parametrize("batch_dim", (1, 10))
274275
@pytest.mark.parametrize("score_net", ["simformer"])
275276
def test_masked_vector_field_estimator_forward_shapes(
@@ -324,8 +325,9 @@ def _build_masked_vector_field_estimator_and_tensors(
324325
Helper function for all tests that deal with shapes of masked score estimators.
325326
"""
326327

327-
num_nodes, num_features = input_event_shape
328-
building_inputs = torch.randn((batch_dim, num_nodes, num_features))
328+
num_nodes = input_event_shape[0]
329+
330+
building_inputs = torch.randn((batch_dim, *input_event_shape))
329331

330332
if sde_type == "flow":
331333
score_estimator = build_masked_flow_matching_estimator(
@@ -342,9 +344,13 @@ def _build_masked_vector_field_estimator_and_tensors(
342344
)
343345

344346
inputs = building_inputs[:batch_dim]
345-
condition_masks = torch.bernoulli(torch.rand(batch_dim, num_nodes))
346-
condition_masks[:, 0] = 0 # Force at least one variable to be latent
347-
condition_masks[:, 1] = 1 # Force at least one variable to be observed
347+
# Generate condition mask: latent indices are 0, observed are 1
348+
latent_idx = torch.arange(num_nodes) # All nodes are latent by default
349+
observed_idx = torch.tensor([], dtype=torch.long) # No observed nodes by default
350+
condition_masks = MaskedNeuralInference.generate_condition_mask_from_idx(
351+
latent_idx=latent_idx,
352+
observed_idx=observed_idx,
353+
)
348354
edge_masks = torch.ones(batch_dim, num_nodes, num_nodes)
349355

350356
inputs = inputs.unsqueeze(0)
@@ -361,8 +367,8 @@ def _build_masked_vector_field_estimator_and_tensors(
361367

362368

363369
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
364-
@pytest.mark.parametrize("input_sample_dim", (1, 2))
365-
@pytest.mark.parametrize("input_event_shape", ((3, 5), (3, 1)))
370+
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
371+
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
366372
@pytest.mark.parametrize("batch_dim", (1, 10))
367373
@pytest.mark.parametrize("score_net", ["simformer"])
368374
def test_unmasked_wrapper_vector_field_estimator_loss_shapes(
@@ -422,8 +428,8 @@ def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, sco
422428

423429

424430
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
425-
@pytest.mark.parametrize("input_sample_dim", (1, 2))
426-
@pytest.mark.parametrize("input_event_shape", ((3, 5), (3, 1)))
431+
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
432+
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
427433
@pytest.mark.parametrize("batch_dim", (1, 10))
428434
@pytest.mark.parametrize("score_net", ["simformer"])
429435
def test_unmasked_wrapper_vector_field_estimator_forward_shapes(

0 commit comments

Comments
 (0)