|
3 | 3 |
|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
6 |
| -from typing import Callable, Tuple |
| 6 | +from typing import Callable, Dict, Tuple |
7 | 7 |
|
8 | 8 | import pytest
|
9 | 9 | import torch
|
10 |
| -from torch import eye, zeros |
| 10 | +from torch import Tensor, eye, zeros |
11 | 11 | from torch.distributions import HalfNormal, MultivariateNormal
|
12 | 12 |
|
| 13 | +from sbi.inference import NLE, NPE, NRE |
| 14 | +from sbi.inference.trainers.base import NeuralInference |
13 | 15 | from sbi.neural_nets.embedding_nets import CNNEmbedding
|
14 | 16 | from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
|
15 | 17 | from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
|
|
36 | 38 | build_zuko_unaf,
|
37 | 39 | )
|
38 | 40 | from sbi.neural_nets.net_builders.flow import build_zuko_flow
|
| 41 | +from sbi.neural_nets.ratio_estimators import RatioEstimator |
39 | 42 | from sbi.utils.torchutils import BoxUniform
|
40 | 43 |
|
41 | 44 | # List of all density estimator builders for testing.
|
@@ -512,3 +515,106 @@ def test_build_zuko_flow_missing_x_dist_raises_error(which_nf):
|
512 | 515 | z_score_y="transform_to_unconstrained",
|
513 | 516 | x_dist=None, # No distribution provided
|
514 | 517 | )
|
| 518 | + |
| 519 | + |
| 520 | +def build_classifier(theta, x): |
| 521 | + net = torch.nn.Linear(theta.shape[1] + x.shape[1], 1) |
| 522 | + return RatioEstimator(net=net, theta_shape=theta[0].shape, x_shape=x[0].shape) |
| 523 | + |
| 524 | + |
| 525 | +def build_estimator(theta, x): |
| 526 | + return build_mdn(theta, x) |
| 527 | + |
| 528 | + |
| 529 | +def build_estimator_missing_args(): |
| 530 | + pass |
| 531 | + |
| 532 | + |
| 533 | +def build_estimator_missing_return(theta: Tensor, x: Tensor): |
| 534 | + pass |
| 535 | + |
| 536 | + |
| 537 | +@pytest.mark.parametrize( |
| 538 | + ("params", "trainer_class"), |
| 539 | + [ |
| 540 | + # Valid builders |
| 541 | + pytest.param(dict(classifier=build_classifier), NRE), |
| 542 | + pytest.param(dict(density_estimator=build_estimator), NPE), |
| 543 | + pytest.param(dict(density_estimator=build_estimator), NLE), |
| 544 | + # Invalid builders |
| 545 | + pytest.param( |
| 546 | + dict(classifier=build_estimator_missing_args), |
| 547 | + NRE, |
| 548 | + marks=pytest.mark.xfail( |
| 549 | + raises=TypeError, |
| 550 | + reason="Missing required parameters in classifier builder.", |
| 551 | + ), |
| 552 | + ), |
| 553 | + pytest.param( |
| 554 | + dict(density_estimator=build_estimator_missing_args), |
| 555 | + NPE, |
| 556 | + marks=pytest.mark.xfail( |
| 557 | + raises=TypeError, |
| 558 | + reason="Missing required parameters in density estimator builder.", |
| 559 | + ), |
| 560 | + ), |
| 561 | + pytest.param( |
| 562 | + dict(density_estimator=build_estimator_missing_args), |
| 563 | + NLE, |
| 564 | + marks=pytest.mark.xfail( |
| 565 | + raises=TypeError, |
| 566 | + reason="Missing required parameters in density estimator builder.", |
| 567 | + ), |
| 568 | + ), |
| 569 | + pytest.param( |
| 570 | + dict(classifier=build_estimator_missing_return), |
| 571 | + NRE, |
| 572 | + marks=pytest.mark.xfail( |
| 573 | + raises=AttributeError, |
| 574 | + reason="Missing return of RatioEstimator in classifier builder.", |
| 575 | + ), |
| 576 | + ), |
| 577 | + pytest.param( |
| 578 | + dict(density_estimator=build_estimator_missing_return), |
| 579 | + NPE, |
| 580 | + marks=pytest.mark.xfail( |
| 581 | + raises=AttributeError, |
| 582 | + reason="Missing return of type ConditionalEstimator" |
| 583 | + " in density estimator builder.", |
| 584 | + ), |
| 585 | + ), |
| 586 | + pytest.param( |
| 587 | + dict(density_estimator=build_estimator_missing_return), |
| 588 | + NLE, |
| 589 | + marks=pytest.mark.xfail( |
| 590 | + raises=AttributeError, |
| 591 | + reason="Missing return of type ConditionalEstimator" |
| 592 | + " in density estimator builder.", |
| 593 | + ), |
| 594 | + ), |
| 595 | + ], |
| 596 | +) |
| 597 | +def test_trainers_with_valid_and_invalid_estimator_builders( |
| 598 | + params: Dict, trainer_class: type[NeuralInference] |
| 599 | +): |
| 600 | + """ |
| 601 | + Test trainers classes work with valid classifier builders and fail |
| 602 | + with invalid ones. |
| 603 | +
|
| 604 | + Args: |
| 605 | + params: Parameters passed to the trainer class. |
| 606 | + trainer_class: Trainer classes. |
| 607 | + """ |
| 608 | + |
| 609 | + def simulator(theta): |
| 610 | + return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1 |
| 611 | + |
| 612 | + num_dim = 3 |
| 613 | + prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim)) |
| 614 | + theta = prior.sample((300,)) |
| 615 | + x = simulator(theta) |
| 616 | + |
| 617 | + inference = trainer_class(**params) |
| 618 | + inference.append_simulations(theta, x) |
| 619 | + |
| 620 | + inference.train(max_num_epochs=1) |
0 commit comments