Skip to content

Commit 8620911

Browse files
committed
Adapted tests on Masked Conditional VF Estimator Wrapper to 2-dim inputs
1 parent 3c4e566 commit 8620911

File tree

1 file changed

+48
-40
lines changed

1 file changed

+48
-40
lines changed

tests/vf_estimator_test.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytest
99
import torch
1010

11-
from sbi.inference.trainers.base import MaskedNeuralInference
1211
from sbi.neural_nets.embedding_nets import CNNEmbedding
1312
from sbi.neural_nets.net_builders import (
1413
build_flow_matching_estimator,
@@ -204,13 +203,13 @@ def _build_vector_field_estimator_and_tensors(
204203
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
205204
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
206205
@pytest.mark.parametrize("batch_dim", (1, 10))
207-
@pytest.mark.parametrize("score_net", ["simformer"])
206+
@pytest.mark.parametrize("net", ["simformer"])
208207
def test_masked_vector_field_estimator_loss_shapes(
209208
sde_type,
210209
input_sample_dim,
211210
input_event_shape,
212211
batch_dim,
213-
score_net,
212+
net,
214213
):
215214
"""Test whether `loss` of MaskedScoreEstimator follows the shape convention."""
216215
(
@@ -223,7 +222,7 @@ def test_masked_vector_field_estimator_loss_shapes(
223222
input_event_shape,
224223
batch_dim,
225224
input_sample_dim,
226-
net=score_net,
225+
net=net,
227226
)
228227

229228
losses = score_estimator.loss(
@@ -235,22 +234,22 @@ def test_masked_vector_field_estimator_loss_shapes(
235234
@pytest.mark.gpu
236235
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
237236
@pytest.mark.parametrize("device", ["cpu", "cuda"])
238-
@pytest.mark.parametrize("score_net", ["simformer"])
239-
def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
237+
@pytest.mark.parametrize("net", ["simformer"])
238+
def test_masked_vector_field_estimator_on_device(sde_type, device, net):
240239
"""Test whether MaskedScoreEstimator can be moved to the device."""
241240

242241
if sde_type == "flow":
243242
score_estimator = build_masked_flow_matching_estimator(
244243
torch.randn(100, 5, 1),
245244
torch.randn(100, 5, 1),
246-
net=score_net,
245+
net=net,
247246
)
248247
else:
249248
score_estimator = build_masked_score_matching_estimator(
250249
torch.randn(100, 5, 1),
251250
torch.randn(100, 5, 1),
252251
sde_type=sde_type,
253-
net=score_net,
252+
net=net,
254253
)
255254
score_estimator.to(device)
256255

@@ -272,13 +271,13 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
272271
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
273272
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
274273
@pytest.mark.parametrize("batch_dim", (1, 10))
275-
@pytest.mark.parametrize("score_net", ["simformer"])
274+
@pytest.mark.parametrize("net", ["simformer"])
276275
def test_masked_vector_field_estimator_forward_shapes(
277276
sde_type,
278277
input_sample_dim,
279278
input_event_shape,
280279
batch_dim,
281-
score_net,
280+
net,
282281
):
283282
"""Test whether `forward` of MaskedScoreEstimator follows the shape convention."""
284283
(
@@ -291,7 +290,7 @@ def test_masked_vector_field_estimator_forward_shapes(
291290
input_event_shape,
292291
batch_dim,
293292
input_sample_dim,
294-
net=score_net,
293+
net=net,
295294
)
296295
# Batched times
297296
times = torch.rand((batch_dim,))
@@ -344,13 +343,8 @@ def _build_masked_vector_field_estimator_and_tensors(
344343
)
345344

346345
inputs = building_inputs[:batch_dim]
347-
# Generate condition mask: latent indices are 0, observed are 1
348-
latent_idx = torch.arange(num_nodes) # All nodes are latent by default
349-
observed_idx = torch.tensor([], dtype=torch.long) # No observed nodes by default
350-
condition_masks = MaskedNeuralInference.generate_condition_mask_from_idx(
351-
latent_idx=latent_idx,
352-
observed_idx=observed_idx,
353-
)
346+
condition_masks = torch.ones(batch_dim, num_nodes)
347+
condition_masks[:, 1:] = 0 # Index 0 is latent
354348
edge_masks = torch.ones(batch_dim, num_nodes, num_nodes)
355349

356350
inputs = inputs.unsqueeze(0)
@@ -366,17 +360,17 @@ def _build_masked_vector_field_estimator_and_tensors(
366360
# *** ======== Unmasked Estimator ======== *** #
367361

368362

369-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
363+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
370364
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
371365
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
372366
@pytest.mark.parametrize("batch_dim", (1, 10))
373-
@pytest.mark.parametrize("score_net", ["simformer"])
367+
@pytest.mark.parametrize("net", ["simformer"])
374368
def test_unmasked_wrapper_vector_field_estimator_loss_shapes(
375369
sde_type,
376370
input_sample_dim,
377371
input_event_shape,
378372
batch_dim,
379-
score_net,
373+
net,
380374
):
381375
"""Test whether `loss` of MaskedConditionalVectorFieldEstimatorWrapper
382376
follows the shape convention."""
@@ -389,35 +383,44 @@ def test_unmasked_wrapper_vector_field_estimator_loss_shapes(
389383
input_event_shape,
390384
batch_dim,
391385
input_sample_dim,
392-
net=score_net,
386+
net=net,
393387
)
394388

395389
with pytest.raises(NotImplementedError):
396390
score_estimator.loss(inputs[0], condition)
397391

398392

399393
@pytest.mark.gpu
400-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
394+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
401395
@pytest.mark.parametrize("device", ["cpu", "cuda"])
402-
@pytest.mark.parametrize("score_net", ["simformer"])
403-
def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, score_net):
396+
@pytest.mark.parametrize("net", ["simformer"])
397+
def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, net):
404398
"""Test whether MaskedConditionalVectorFieldEstimatorWrapper
405399
can be moved to the device."""
406400
# Create condition and edge masks
407401
condition_mask = torch.ones(5, device=device)
408402
condition_mask[0] = 0 # Index 0 is latent
409403
edge_mask = torch.ones(5, 5, device=device)
410404

411-
score_estimator = (
412-
build_masked_score_matching_estimator(
413-
torch.randn(100, 5, 1),
414-
torch.randn(100, 5, 1),
405+
building_inputs = torch.randn(100, 5)
406+
407+
if sde_type == "flow":
408+
score_estimator = build_masked_flow_matching_estimator(
409+
building_inputs,
410+
building_inputs, # not used
411+
net=net,
412+
)
413+
else:
414+
score_estimator = build_masked_score_matching_estimator(
415+
building_inputs,
416+
building_inputs, # not used
415417
sde_type=sde_type,
416-
net=score_net,
418+
net=net,
417419
)
418-
.to(device)
419-
.build_conditional_vector_field_estimator(condition_mask, edge_mask)
420-
)
420+
421+
score_estimator = score_estimator.to(
422+
device
423+
).build_conditional_vector_field_estimator(condition_mask, edge_mask)
421424

422425
inputs = torch.randn(100, 1, device=device)
423426
condition = torch.randn(100, 4, device=device)
@@ -427,17 +430,17 @@ def test_unmasked_wrapper_vector_field_estimator_on_device(sde_type, device, sco
427430
assert str(out.device).split(":")[0] == device, "Output device mismatch."
428431

429432

430-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
433+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
431434
@pytest.mark.parametrize("input_sample_dim", (1, 2, 3))
432435
@pytest.mark.parametrize("input_event_shape", ((1,), (4,), (3, 5), (3, 1)))
433436
@pytest.mark.parametrize("batch_dim", (1, 10))
434-
@pytest.mark.parametrize("score_net", ["simformer"])
437+
@pytest.mark.parametrize("net", ["simformer"])
435438
def test_unmasked_wrapper_vector_field_estimator_forward_shapes(
436439
sde_type,
437440
input_sample_dim,
438441
input_event_shape,
439442
batch_dim,
440-
score_net,
443+
net,
441444
):
442445
"""Test whether `forward` of MaskedConditionalVectorFieldEstimatorWrapperù
443446
follow the shape convention."""
@@ -450,7 +453,7 @@ def test_unmasked_wrapper_vector_field_estimator_forward_shapes(
450453
input_event_shape,
451454
batch_dim,
452455
input_sample_dim,
453-
net=score_net,
456+
net=net,
454457
)
455458
# Batched times
456459
times = torch.rand((batch_dim,))
@@ -490,7 +493,7 @@ def _build_unmasked_vector_field_estimator_and_tensors(
490493
**kwargs,
491494
)
492495

493-
# Use the first condition and edge mask for all batches
496+
# # Use the first condition and edge mask for all batches
494497
condition_masks = condition_masks[0].clone().detach()
495498
edge_masks = edge_masks[0].clone().detach()
496499

@@ -504,8 +507,13 @@ def _build_unmasked_vector_field_estimator_and_tensors(
504507
latent_idx = (condition_masks == 0).squeeze()
505508
observed_idx = (condition_masks == 1).squeeze()
506509

507-
untangled_inputs = inputs[:, :, latent_idx, :] # (B, num_latent, F)
508-
untangled_condition = inputs[0, :, observed_idx, :] # (B, num_observed, F)
510+
# Handle inputs with different number of dimensions
511+
if len(input_event_shape) == 1:
512+
untangled_inputs = inputs[:, :, latent_idx] # (S, B, num_latent)
513+
untangled_condition = inputs[0, :, observed_idx] # (B, num_observed)
514+
else:
515+
untangled_inputs = inputs[:, :, latent_idx, :] # (S, B, num_latent, F)
516+
untangled_condition = inputs[0, :, observed_idx, :] # (B, num_observed, F)
509517

510518
return (
511519
score_estimator,

0 commit comments

Comments
 (0)