Skip to content

Commit 5434666

Browse files
authored
add exception to tests for non-invertible networks (#425)
1 parent 1e63803 commit 5434666

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

tests/test_networks/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def typical_point_inference_network_subnet():
8585
"spline_coupling_flow",
8686
"flow_matching",
8787
"free_form_flow",
88+
"consistency_model",
8889
],
8990
scope="function",
9091
)
@@ -106,7 +107,8 @@ def inference_network_subnet(request):
106107

107108

108109
@pytest.fixture(
109-
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function"
110+
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"],
111+
scope="function",
110112
)
111113
def generative_inference_network(request):
112114
return request.getfixturevalue(request.param)

tests/test_networks/test_inference_networks.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
3636
else:
3737
new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:])
3838

39-
inference_network(new_input, conditions=new_conditions)
39+
try:
40+
inference_network(new_input, conditions=new_conditions)
41+
except NotImplementedError:
42+
# network is not invertible
43+
pass
4044
inference_network(new_input, conditions=new_conditions, inverse=True)
4145

4246

4347
@pytest.mark.parametrize("density", [True, False])
4448
def test_output_structure(density, generative_inference_network, random_samples, random_conditions):
45-
output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
49+
try:
50+
output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
51+
except NotImplementedError:
52+
# network not invertible
53+
return
4654

4755
if density:
4856
assert isinstance(output, tuple)
@@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples,
5765

5866

5967
def test_output_shape(generative_inference_network, random_samples, random_conditions):
60-
forward_output, forward_log_density = generative_inference_network(
61-
random_samples, conditions=random_conditions, density=True
62-
)
68+
try:
69+
forward_output, forward_log_density = generative_inference_network(
70+
random_samples, conditions=random_conditions, density=True
71+
)
72+
except NotImplementedError:
73+
# network is not invertible, not forward function available
74+
return
6375

6476
assert keras.ops.shape(forward_output) == keras.ops.shape(random_samples)
6577
assert keras.ops.shape(forward_log_density) == (keras.ops.shape(random_samples)[0],)
@@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_condi
7486

7587
def test_cycle_consistency(generative_inference_network, random_samples, random_conditions):
7688
# cycle-consistency means the forward and inverse methods are inverses of each other
77-
forward_output, forward_log_density = generative_inference_network(
78-
random_samples, conditions=random_conditions, density=True
79-
)
89+
try:
90+
forward_output, forward_log_density = generative_inference_network(
91+
random_samples, conditions=random_conditions, density=True
92+
)
93+
except NotImplementedError:
94+
# network is not invertible, cycle consistency cannot be tested.
95+
return
8096
inverse_output, inverse_log_density = generative_inference_network(
8197
forward_output, conditions=random_conditions, density=True, inverse=True
8298
)
@@ -88,7 +104,11 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_
88104
def test_density_numerically(generative_inference_network, random_samples, random_conditions):
89105
from bayesflow.utils import jacobian
90106

91-
output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
107+
try:
108+
output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
109+
except NotImplementedError:
110+
# network does not support density estimation
111+
return
92112

93113
def f(x):
94114
return generative_inference_network(x, conditions=random_conditions)

0 commit comments

Comments
 (0)