@@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
36
36
else :
37
37
new_conditions = keras .ops .zeros ((bs ,) + keras .ops .shape (random_conditions )[1 :])
38
38
39
- inference_network (new_input , conditions = new_conditions )
39
+ try :
40
+ inference_network (new_input , conditions = new_conditions )
41
+ except NotImplementedError :
42
+ # network is not invertible
43
+ pass
40
44
inference_network (new_input , conditions = new_conditions , inverse = True )
41
45
42
46
43
47
@pytest .mark .parametrize ("density" , [True , False ])
44
48
def test_output_structure (density , generative_inference_network , random_samples , random_conditions ):
45
- output = generative_inference_network (random_samples , conditions = random_conditions , density = density )
49
+ try :
50
+ output = generative_inference_network (random_samples , conditions = random_conditions , density = density )
51
+ except NotImplementedError :
52
+ # network not invertible
53
+ return
46
54
47
55
if density :
48
56
assert isinstance (output , tuple )
@@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples,
57
65
58
66
59
67
def test_output_shape (generative_inference_network , random_samples , random_conditions ):
60
- forward_output , forward_log_density = generative_inference_network (
61
- random_samples , conditions = random_conditions , density = True
62
- )
68
+ try :
69
+ forward_output , forward_log_density = generative_inference_network (
70
+ random_samples , conditions = random_conditions , density = True
71
+ )
72
+ except NotImplementedError :
73
+ # network is not invertible, not forward function available
74
+ return
63
75
64
76
assert keras .ops .shape (forward_output ) == keras .ops .shape (random_samples )
65
77
assert keras .ops .shape (forward_log_density ) == (keras .ops .shape (random_samples )[0 ],)
@@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_condi
74
86
75
87
def test_cycle_consistency (generative_inference_network , random_samples , random_conditions ):
76
88
# cycle-consistency means the forward and inverse methods are inverses of each other
77
- forward_output , forward_log_density = generative_inference_network (
78
- random_samples , conditions = random_conditions , density = True
79
- )
89
+ try :
90
+ forward_output , forward_log_density = generative_inference_network (
91
+ random_samples , conditions = random_conditions , density = True
92
+ )
93
+ except NotImplementedError :
94
+ # network is not invertible, cycle consistency cannot be tested.
95
+ return
80
96
inverse_output , inverse_log_density = generative_inference_network (
81
97
forward_output , conditions = random_conditions , density = True , inverse = True
82
98
)
@@ -88,7 +104,11 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_
88
104
def test_density_numerically (generative_inference_network , random_samples , random_conditions ):
89
105
from bayesflow .utils import jacobian
90
106
91
- output , log_density = generative_inference_network (random_samples , conditions = random_conditions , density = True )
107
+ try :
108
+ output , log_density = generative_inference_network (random_samples , conditions = random_conditions , density = True )
109
+ except NotImplementedError :
110
+ # network does not support density estimation
111
+ return
92
112
93
113
def f (x ):
94
114
return generative_inference_network (x , conditions = random_conditions )
0 commit comments