@@ -342,94 +342,6 @@ def test_batched_score_sample_with_different_x(
342
342
)
343
343
344
344
345
- @pytest .mark .slow
346
- @pytest .mark .parametrize ("x_o_batch_dim" , (0 , 1 , 2 ))
347
- @pytest .mark .parametrize ("sampling_method" , ["sde" , "ode" ])
348
- @pytest .mark .parametrize (
349
- "sample_shape" ,
350
- (
351
- (5 ,), # less than num_chains
352
- (4 , 2 ), # 2D batch
353
- ),
354
- )
355
- def test_batched_score_simformer_sample_with_different_x (
356
- x_o_batch_dim : bool ,
357
- sampling_method : str ,
358
- sample_shape : torch .Size ,
359
- ):
360
- num_dim = 2
361
- num_simulations = 100
362
-
363
- prior = MultivariateNormal (loc = zeros (num_dim ), covariance_matrix = eye (num_dim ))
364
- simulator = diagonal_linear_gaussian
365
-
366
- inference = Simformer (prior = prior )
367
-
368
- thetas = prior .sample ((num_simulations ,))
369
- xs = simulator (thetas )
370
- inputs = torch .stack ([thetas , xs ], dim = 1 )
371
-
372
- inference .append_simulations (
373
- inputs = inputs ,
374
- ).train (max_num_epochs = 100 )
375
-
376
- x_o = ones (num_dim ) if x_o_batch_dim == 0 else ones (x_o_batch_dim , num_dim )
377
-
378
- # Build conditional for the specific task: infer theta (node 0) given x (node 1).
379
- inference_condition_mask = torch .tensor ([False , True ])
380
-
381
- posterior = inference .build_conditional (
382
- condition_mask = inference_condition_mask ,
383
- sample_with = sampling_method , # type: ignore
384
- )
385
-
386
- samples = posterior .sample_batched (
387
- sample_shape ,
388
- x_o ,
389
- )
390
-
391
- assert (
392
- samples .shape == (* sample_shape , x_o_batch_dim , num_dim )
393
- if x_o_batch_dim > 0
394
- else (* sample_shape , num_dim )
395
- ), "Sample shape wrong"
396
-
397
- # test only for 1 sample_shape case to avoid repeating this test.
398
- if x_o_batch_dim > 1 and sample_shape == (5 ,):
399
- assert samples .shape [1 ] == x_o_batch_dim , "Batch dimension wrong"
400
- inference = Simformer (prior = prior )
401
-
402
- inference .append_simulations (
403
- inputs = inputs ,
404
- ).train (max_num_epochs = 100 )
405
-
406
- inference_condition_mask = torch .tensor ([False , True ])
407
-
408
- posterior = inference .build_conditional (
409
- condition_mask = inference_condition_mask ,
410
- sample_with = sampling_method , # type: ignore
411
- )
412
-
413
- x_o = torch .stack ([0.5 * ones (num_dim ), - 0.5 * ones (num_dim )], dim = 0 )
414
- # test with multiple chains to test whether correct chains are
415
- # concatenated.
416
- sample_shape = torch .Size ([1000 ]) # use enough samples for accuracy comparison
417
- samples = posterior .sample_batched (sample_shape , x_o )
418
-
419
- samples_separate1 = posterior .sample (sample_shape , x_o [0 ])
420
- samples_separate2 = posterior .sample (sample_shape , x_o [1 ])
421
-
422
- # Check if means are approx. same
423
- samples_m = torch .mean (samples , dim = 0 , dtype = torch .float32 )
424
- samples_separate1_m = torch .mean (samples_separate1 , dim = 0 , dtype = torch .float32 )
425
- samples_separate2_m = torch .mean (samples_separate2 , dim = 0 , dtype = torch .float32 )
426
- samples_sep_m = torch .stack ([samples_separate1_m , samples_separate2_m ], dim = 0 )
427
-
428
- assert torch .allclose (samples_m , samples_sep_m , atol = 0.2 , rtol = 0.2 ), (
429
- "Batched sampling is not consistent with separate sampling."
430
- )
431
-
432
-
433
345
@pytest .mark .slow
434
346
@pytest .mark .parametrize ("density_estimator" , ["mdn" , "maf" , "zuko_nsf" ])
435
347
def test_batched_sampling_and_logprob_accuracy (density_estimator : str ):
0 commit comments