Commit 91196a4

Removed deprecated simformer unit test in posterior_nn tests
1 parent 15a1db7 commit 91196a4

File tree

1 file changed: +0 -88 lines changed


tests/posterior_nn_test.py

Lines changed: 0 additions & 88 deletions
@@ -342,94 +342,6 @@ def test_batched_score_sample_with_different_x(
     )
 
 
-@pytest.mark.slow
-@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
-@pytest.mark.parametrize("sampling_method", ["sde", "ode"])
-@pytest.mark.parametrize(
-    "sample_shape",
-    (
-        (5,),  # less than num_chains
-        (4, 2),  # 2D batch
-    ),
-)
-def test_batched_score_simformer_sample_with_different_x(
-    x_o_batch_dim: bool,
-    sampling_method: str,
-    sample_shape: torch.Size,
-):
-    num_dim = 2
-    num_simulations = 100
-
-    prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
-    simulator = diagonal_linear_gaussian
-
-    inference = Simformer(prior=prior)
-
-    thetas = prior.sample((num_simulations,))
-    xs = simulator(thetas)
-    inputs = torch.stack([thetas, xs], dim=1)
-
-    inference.append_simulations(
-        inputs=inputs,
-    ).train(max_num_epochs=100)
-
-    x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim)
-
-    # Build conditional for the specific task: infer theta (node 0) given x (node 1).
-    inference_condition_mask = torch.tensor([False, True])
-
-    posterior = inference.build_conditional(
-        condition_mask=inference_condition_mask,
-        sample_with=sampling_method,  # type: ignore
-    )
-
-    samples = posterior.sample_batched(
-        sample_shape,
-        x_o,
-    )
-
-    assert (
-        samples.shape == (*sample_shape, x_o_batch_dim, num_dim)
-        if x_o_batch_dim > 0
-        else (*sample_shape, num_dim)
-    ), "Sample shape wrong"
-
-    # test only for 1 sample_shape case to avoid repeating this test.
-    if x_o_batch_dim > 1 and sample_shape == (5,):
-        assert samples.shape[1] == x_o_batch_dim, "Batch dimension wrong"
-        inference = Simformer(prior=prior)
-
-        inference.append_simulations(
-            inputs=inputs,
-        ).train(max_num_epochs=100)
-
-        inference_condition_mask = torch.tensor([False, True])
-
-        posterior = inference.build_conditional(
-            condition_mask=inference_condition_mask,
-            sample_with=sampling_method,  # type: ignore
-        )
-
-        x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)
-        # test with multiple chains to test whether correct chains are
-        # concatenated.
-        sample_shape = torch.Size([1000])  # use enough samples for accuracy comparison
-        samples = posterior.sample_batched(sample_shape, x_o)
-
-        samples_separate1 = posterior.sample(sample_shape, x_o[0])
-        samples_separate2 = posterior.sample(sample_shape, x_o[1])
-
-        # Check if means are approx. same
-        samples_m = torch.mean(samples, dim=0, dtype=torch.float32)
-        samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32)
-        samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32)
-        samples_sep_m = torch.stack([samples_separate1_m, samples_separate2_m], dim=0)
-
-        assert torch.allclose(samples_m, samples_sep_m, atol=0.2, rtol=0.2), (
-            "Batched sampling is not consistent with separate sampling."
-        )
-
-
 @pytest.mark.slow
 @pytest.mark.parametrize("density_estimator", ["mdn", "maf", "zuko_nsf"])
 def test_batched_sampling_and_logprob_accuracy(density_estimator: str):
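
For reference, the deleted test exercised the batched-sampling path of the Simformer interface: train on stacked (theta, x) pairs, build a conditional via a condition mask, then draw samples for a batch of observations in one call. Below is a condensed sketch of that workflow using only the calls that appear in the removed code; the import paths are assumptions and may differ between sbi versions.

# Condensed sketch of the workflow covered by the removed test.
# NOTE: the import locations below are assumptions; they may differ across sbi versions.
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import Simformer  # assumed import path
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian  # assumed import path

num_dim, num_simulations = 2, 100
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))

# Stack (theta, x) pairs along a node dimension, as in the deleted test.
thetas = prior.sample((num_simulations,))
xs = diagonal_linear_gaussian(thetas)
inputs = torch.stack([thetas, xs], dim=1)

inference = Simformer(prior=prior)
inference.append_simulations(inputs=inputs).train(max_num_epochs=100)

# Condition on x (node 1) to infer theta (node 0), then sample for a batch of
# two observations in a single call.
posterior = inference.build_conditional(
    condition_mask=torch.tensor([False, True]),
    sample_with="sde",
)
x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0)
samples = posterior.sample_batched(torch.Size([1000]), x_o)
print(samples.shape)  # expected: torch.Size([1000, 2, 2])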
