Skip to content

Commit aadda6f

Browse files
committed
Added and adapted test on masked flow vf estimator
1 parent b08008e commit aadda6f

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

tests/vf_estimator_test.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sbi.neural_nets.embedding_nets import CNNEmbedding
1212
from sbi.neural_nets.net_builders import (
1313
build_flow_matching_estimator,
14+
build_masked_flow_matching_estimator,
1415
build_masked_score_matching_estimator,
1516
build_score_matching_estimator,
1617
)
@@ -198,7 +199,7 @@ def _build_vector_field_estimator_and_tensors(
198199
# *** ======== Masked Estimator ======== *** #
199200

200201

201-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
202+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
202203
@pytest.mark.parametrize("input_sample_dim", (1, 2))
203204
@pytest.mark.parametrize("input_event_shape", ((3, 5), (3, 1)))
204205
@pytest.mark.parametrize("batch_dim", (1, 10))
@@ -231,17 +232,25 @@ def test_masked_vector_field_estimator_loss_shapes(
231232

232233

233234
@pytest.mark.gpu
234-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
235+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
235236
@pytest.mark.parametrize("device", ["cpu", "cuda"])
236237
@pytest.mark.parametrize("score_net", ["simformer"])
237238
def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
238239
"""Test whether MaskedScoreEstimator can be moved to the device."""
239-
score_estimator = build_masked_score_matching_estimator(
240-
torch.randn(100, 5, 1),
241-
torch.randn(100, 5, 1),
242-
sde_type=sde_type,
243-
net=score_net,
244-
)
240+
241+
if sde_type == "flow":
242+
score_estimator = build_masked_flow_matching_estimator(
243+
torch.randn(100, 5, 1),
244+
torch.randn(100, 5, 1),
245+
net=score_net,
246+
)
247+
else:
248+
score_estimator = build_masked_score_matching_estimator(
249+
torch.randn(100, 5, 1),
250+
torch.randn(100, 5, 1),
251+
sde_type=sde_type,
252+
net=score_net,
253+
)
245254
score_estimator.to(device)
246255

247256
# Test forward
@@ -258,7 +267,7 @@ def test_masked_vector_field_estimator_on_device(sde_type, device, score_net):
258267
assert str(loss.device).split(":")[0] == device, "Loss device mismatch."
259268

260269

261-
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp"])
270+
@pytest.mark.parametrize("sde_type", ["ve", "vp", "subvp", "flow"])
262271
@pytest.mark.parametrize("input_sample_dim", (1, 2))
263272
@pytest.mark.parametrize("input_event_shape", ((5, 1), (5, 4)))
264273
@pytest.mark.parametrize("batch_dim", (1, 10))
@@ -318,12 +327,19 @@ def _build_masked_vector_field_estimator_and_tensors(
318327
num_nodes, num_features = input_event_shape
319328
building_inputs = torch.randn((batch_dim, num_nodes, num_features))
320329

321-
score_estimator = build_masked_score_matching_estimator(
322-
building_inputs,
323-
building_inputs, # not used
324-
sde_type=sde_type,
325-
**kwargs,
326-
)
330+
if sde_type == "flow":
331+
score_estimator = build_masked_flow_matching_estimator(
332+
building_inputs,
333+
building_inputs, # not used
334+
**kwargs,
335+
)
336+
else:
337+
score_estimator = build_masked_score_matching_estimator(
338+
building_inputs,
339+
building_inputs, # not used
340+
sde_type=sde_type,
341+
**kwargs,
342+
)
327343

328344
inputs = building_inputs[:batch_dim]
329345
condition_masks = torch.bernoulli(torch.rand(batch_dim, num_nodes))

0 commit comments

Comments
 (0)