Commit d31a761

Merge pull request #436 from bayesflow-org/dev
Optimal transport hot fixes, consistency models test, and fast imports
2 parents 3c83a47 + 5434666 commit d31a761

10 files changed, +99 -43 lines changed
Lines changed: 24 additions & 16 deletions

@@ -1,23 +1,31 @@
-import inspect
+import sys
+import types
 
 
 def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
     """Add all global variables to __all__"""
     if not isinstance(include_modules, (bool, list)):
         raise ValueError("include_modules must be a boolean or a list of strings")
 
-    exclude = exclude or []
-    calling_module = inspect.stack()[1]
-    local_stack = calling_module[0]
-    global_vars = local_stack.f_globals
-    all_vars = global_vars["__all__"] if "__all__" in global_vars else []
-    included_vars = []
-    for var_name in set(global_vars.keys()):
-        if inspect.ismodule(global_vars[var_name]):
-            if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
-                included_vars.append(var_name)
-            elif isinstance(include_modules, list) and var_name in include_modules:
-                included_vars.append(var_name)
-        elif var_name not in exclude and not var_name.startswith("_"):
-            included_vars.append(var_name)
-    global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
+    exclude_set = set(exclude or [])
+    contains = exclude_set.__contains__
+    mod_type = types.ModuleType
+    frame = sys._getframe(1)
+    g: dict = frame.f_globals
+    existing = set(g.get("__all__", []))
+
+    to_add = []
+    include_list = include_modules if isinstance(include_modules, list) else ()
+    inc_all = include_modules is True
+
+    for name, val in g.items():
+        if name.startswith("_") or contains(name):
+            continue
+
+        if isinstance(val, mod_type):
+            if inc_all or name in include_list:
+                to_add.append(name)
+        else:
+            to_add.append(name)
+
+    g["__all__"] = sorted(existing.union(to_add))

bayesflow/utils/serialization.py

Lines changed: 8 additions & 5 deletions

@@ -4,6 +4,7 @@
 import inspect
 import keras
 import numpy as np
+import sys
 
 # this import needs to be exactly like this to work with monkey patching
 from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
         # we marked this as a type during serialization
         obj = obj[len(_type_prefix) :]
         tp = keras.saving.get_registered_object(
-            obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
+            # TODO: can we pass module objects without overwriting numpy's dict with builtins?
+            obj,
+            custom_objects=custom_objects,
+            module_objects=np.__dict__ | builtins.__dict__,
         )
         if tp is None:
             raise ValueError(
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
 @allow_args
 def serializable(cls, package=None, name=None):
     if package is None:
-        # get the calling module's name, e.g. "bayesflow.networks.inference_network"
-        stack = inspect.stack()
-        module = inspect.getmodule(stack[1][0])
-        package = copy(module.__name__)
+        frame = sys._getframe(1)
+        g = frame.f_globals
+        package = g.get("__name__", "bayesflow")
 
     if name is None:
         name = copy(cls.__name__)
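
For context (not part of the diff): with PEP 584 dict union, keys from the right-hand operand win on collision, so swapping the operands to `np.__dict__ | builtins.__dict__` lets Python builtins take precedence for names both namespaces define. A small sketch of the behavior the TODO refers to:

import builtins

import numpy as np

# right-hand operand wins on duplicate keys (PEP 584)
merged = np.__dict__ | builtins.__dict__

assert merged["all"] is builtins.all      # builtin now shadows np.all
assert merged["ndarray"] is np.ndarray    # numpy-only names still resolve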

examples/Linear_Regression_Starter.ipynb

Lines changed: 5 additions & 2 deletions

@@ -38,6 +38,7 @@
 "outputs": [],
 "source": [
 "import numpy as np\n",
+"from pathlib import Path\n",
 "\n",
 "import keras\n",
 "import bayesflow as bf"
@@ -955,7 +956,9 @@
 "outputs": [],
 "source": [
 "# Recommended - full serialization (checkpoints folder must exist)\n",
-"workflow.approximator.save(filepath=\"checkpoints/regression.keras\")\n",
+"filepath = Path(\"checkpoints\") / \"regression.keras\"\n",
+"filepath.parent.mkdir(exist_ok=True)\n",
+"workflow.approximator.save(filepath=filepath)\n",
 "\n",
 "# Not recommended due to adapter mismatches - weights only\n",
 "# approximator.save_weights(filepath=\"checkpoints/regression.h5\")"
@@ -975,7 +978,7 @@
 "outputs": [],
 "source": [
 "# Load approximator\n",
-"approximator = keras.saving.load_model(\"checkpoints/regression.keras\")"
+"approximator = keras.saving.load_model(filepath)"
 ]
 },
 {
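
For context: `Path.mkdir(exist_ok=True)` makes the save cell idempotent across notebook re-runs. A standalone sketch of the pattern, with the same filenames the notebook uses:

from pathlib import Path

filepath = Path("checkpoints") / "regression.keras"
# create "checkpoints/" if missing; exist_ok=True keeps re-runs from raising,
# and parents=True would additionally create nested intermediate directories
filepath.parent.mkdir(exist_ok=True)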

examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb

Lines changed: 3 additions & 1 deletion

@@ -37,6 +37,7 @@
 "source": [
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
+"from pathlib import Path\n",
 "import seaborn as sns\n",
 "\n",
 "import scipy\n",
@@ -748,7 +749,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"checkpoint_path = \"checkpoints/model.keras\"\n",
+"checkpoint_path = Path(\"checkpoints\") / \"model.keras\"\n",
+"checkpoint_path.parent.mkdir(exist_ok=True)\n",
 "keras.saving.save_model(point_inference_workflow.approximator, checkpoint_path)"
 ]
 },

pyproject.toml

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,8 @@ all = [
     "jupyter",
     "jupyterlab",
     "nbconvert",
+    "ipython",
+    "ipykernel",
     "pre-commit",
     "ruff",
     "tox",
@@ -72,6 +74,8 @@ docs = [
 ]
 test = [
     "nbconvert",
+    "ipython",
+    "ipykernel",
     "pytest",
     "pytest-cov",
     "pytest-rerunfailures",

tests/test_examples/test_examples.py

Lines changed: 2 additions & 1 deletion

@@ -3,6 +3,7 @@
 from tests.utils import run_notebook
 
 
+@pytest.mark.skip(reason="requires setting up Stan")
 @pytest.mark.slow
 def test_bayesian_experimental_design(examples_path):
     run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")
@@ -30,7 +31,7 @@ def test_one_sample_ttest(examples_path):
 
 @pytest.mark.slow
 def test_sir_posterior_estimation(examples_path):
-    run_notebook(examples_path / "SIR_Posterior_estimation.ipynb")
+    run_notebook(examples_path / "SIR_Posterior_Estimation.ipynb")
 
 
 @pytest.mark.slow

tests/test_networks/conftest.py

Lines changed: 3 additions & 1 deletion

@@ -85,6 +85,7 @@ def typical_point_inference_network_subnet():
         "spline_coupling_flow",
         "flow_matching",
         "free_form_flow",
+        "consistency_model",
     ],
     scope="function",
 )
@@ -106,7 +107,8 @@ def inference_network_subnet(request):
 
 
 @pytest.fixture(
-    params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow"], scope="function"
+    params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"],
+    scope="function",
 )
 def generative_inference_network(request):
     return request.getfixturevalue(request.param)

tests/test_networks/test_inference_networks.py

Lines changed: 29 additions & 9 deletions

@@ -36,13 +36,21 @@ def test_variable_batch_size(inference_network, random_samples, random_conditions):
     else:
         new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:])
 
-    inference_network(new_input, conditions=new_conditions)
+    try:
+        inference_network(new_input, conditions=new_conditions)
+    except NotImplementedError:
+        # network is not invertible
+        pass
     inference_network(new_input, conditions=new_conditions, inverse=True)
 
 
 @pytest.mark.parametrize("density", [True, False])
 def test_output_structure(density, generative_inference_network, random_samples, random_conditions):
-    output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
+    try:
+        output = generative_inference_network(random_samples, conditions=random_conditions, density=density)
+    except NotImplementedError:
+        # network is not invertible
+        return
 
     if density:
         assert isinstance(output, tuple)
@@ -57,9 +65,13 @@ def test_output_structure(density, generative_inference_network, random_samples, random_conditions):
 
 
 def test_output_shape(generative_inference_network, random_samples, random_conditions):
-    forward_output, forward_log_density = generative_inference_network(
-        random_samples, conditions=random_conditions, density=True
-    )
+    try:
+        forward_output, forward_log_density = generative_inference_network(
+            random_samples, conditions=random_conditions, density=True
+        )
+    except NotImplementedError:
+        # network is not invertible, no forward function available
+        return
 
     assert keras.ops.shape(forward_output) == keras.ops.shape(random_samples)
     assert keras.ops.shape(forward_log_density) == (keras.ops.shape(random_samples)[0],)
@@ -74,9 +86,13 @@ def test_output_shape(generative_inference_network, random_samples, random_conditions):
 
 def test_cycle_consistency(generative_inference_network, random_samples, random_conditions):
     # cycle-consistency means the forward and inverse methods are inverses of each other
-    forward_output, forward_log_density = generative_inference_network(
-        random_samples, conditions=random_conditions, density=True
-    )
+    try:
+        forward_output, forward_log_density = generative_inference_network(
+            random_samples, conditions=random_conditions, density=True
+        )
+    except NotImplementedError:
+        # network is not invertible, cycle consistency cannot be tested
+        return
     inverse_output, inverse_log_density = generative_inference_network(
         forward_output, conditions=random_conditions, density=True, inverse=True
     )
@@ -88,7 +104,11 @@ def test_cycle_consistency(generative_inference_network, random_samples, random_conditions):
 def test_density_numerically(generative_inference_network, random_samples, random_conditions):
     from bayesflow.utils import jacobian
 
-    output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
+    try:
+        output, log_density = generative_inference_network(random_samples, conditions=random_conditions, density=True)
+    except NotImplementedError:
+        # network does not support density estimation
+        return
 
     def f(x):
         return generative_inference_network(x, conditions=random_conditions)
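
The same try/except guard now appears in four tests. A hypothetical helper (not in the diff) could express the pattern once, skipping instead of silently returning:

import pytest


def forward_or_skip(network, samples, conditions, **kwargs):
    # hypothetical helper: run the forward pass, skip the test for
    # networks (e.g. consistency models) that raise NotImplementedError
    try:
        return network(samples, conditions=conditions, **kwargs)
    except NotImplementedError as error:
        pytest.skip(f"forward pass not supported: {error}")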

tests/test_utils/test_dispatch.py

Lines changed: 5 additions & 6 deletions

@@ -3,6 +3,7 @@
 
 # Import the dispatch functions
 from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
+from tests.utils import assert_allclose
 
 # --- Tests for find_network.py ---
 
@@ -118,23 +119,21 @@ def test_find_pooling_mean():
     # Check that a keras Lambda layer is returned
     assert isinstance(pooling, keras.layers.Lambda)
     # Test that the lambda function produces a mean when applied to a sample tensor.
-    import numpy as np
 
-    sample = np.array([[1, 2], [3, 4]])
+    sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
     # Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
     result = pooling.call(sample)
-    np.testing.assert_allclose(result, sample.mean(axis=-2))
+    assert_allclose(result, keras.ops.mean(sample, axis=-2))
 
 
 @pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
 def test_find_pooling_max_min(name, func):
     pooling = find_pooling(name)
     assert isinstance(pooling, keras.layers.Lambda)
-    import numpy as np
 
-    sample = np.array([[1, 2], [3, 4]])
+    sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
     result = pooling.call(sample)
-    np.testing.assert_allclose(result, func(sample, axis=-2))
+    assert_allclose(result, func(sample, axis=-2))
 
 
 def test_find_pooling_learnable(monkeypatch):

tests/utils/jupyter.py

Lines changed: 16 additions & 2 deletions

@@ -1,11 +1,25 @@
 import nbformat
 from nbconvert.preprocessors import ExecutePreprocessor
 
+from pathlib import Path
+import shutil
+
 
 def run_notebook(path):
+    path = Path(path)
+    checkpoint_path = path.parent / "checkpoints"
+    # only clean up if the directory did not exist before the test
+    cleanup_checkpoints = not checkpoint_path.exists()
     with open(str(path)) as f:
         nb = nbformat.read(f, nbformat.NO_CONVERT)
 
-    kernel = ExecutePreprocessor(timeout=600, kernel_name="python3")
+    kernel = ExecutePreprocessor(timeout=600, kernel_name="python3", resources={"metadata": {"path": path.parent}})
+
+    try:
+        result = kernel.preprocess(nb)
+    finally:
+        if cleanup_checkpoints and checkpoint_path.exists():
+            # clean up if the directory was created by the test
+            shutil.rmtree(checkpoint_path)
 
-    return kernel.preprocess(nb)
+    return result
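
A usage sketch for the updated helper, assuming the repository layout the test suite relies on (notebooks in `examples/` at the project root):

from pathlib import Path

from tests.utils import run_notebook

# the notebook now executes with its own directory as the working directory,
# so relative paths such as "checkpoints/regression.keras" resolve next to it
examples_path = Path("examples")
run_notebook(examples_path / "Linear_Regression_Starter.ipynb")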
