Skip to content

Commit 023db06

Browse files
committed
Adapt Keras' auto-config mechanism for inference/summary networks
This change makes the auto-config mechanism more capable for our purposes by allowing any serializable value in the auto-config, not only the basic types. We still have to check whether this brings any footguns/downsides, or whether it is fine for our setting. It also replaces Keras' serialization functions with our custom serialization functions.
1 parent 8d296e9 commit 023db06

File tree

13 files changed

+255
-45
lines changed

13 files changed

+255
-45
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,6 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
140140

141141
def get_config(self):
142142
base_config = super().get_config()
143-
# base distribution is fixed and passed in constructor
144-
base_config.pop("base_distribution")
145143

146144
config = {
147145
"subnet": self.subnet,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .base_layer import BaseLayer
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import keras
2+
import inspect
3+
import textwrap
4+
from functools import wraps
5+
6+
from keras.src import dtype_policies
7+
from keras.src import tree
8+
from keras.src.backend.common.name_scope import current_path
9+
from keras.src.utils import python_utils
10+
from keras import Operation
11+
from keras.saving import get_registered_name, get_registered_object
12+
13+
from bayesflow.utils.serialization import serialize, deserialize
14+
15+
16+
class BayesFlowSerializableDict:
    """Container for the keyword arguments recorded for an auto-generated config.

    The keyword arguments are stored unchanged in ``config``; ``serialize()``
    passes them through the ``serialize`` function imported from
    ``bayesflow.utils.serialization``.
    """

    def __init__(self, **entries):
        # Keep the raw (unserialized) keyword arguments as given.
        self.config = entries

    def serialize(self):
        """Return the stored ``config`` after processing it with ``serialize``."""
        serialized = serialize(self.config)
        return serialized
22+
23+
24+
class BaseLayer(keras.Layer):
    """Keras layer base class with an extended auto-config mechanism.

    The constructor arguments are recorded in `__new__`. If every (flattened)
    argument value is a basic type, a class, or an instance of a type that is
    registered with Keras' serialization registry, the arguments are stored and
    used by the default `get_config()`. `from_config()` uses BayesFlow's
    `deserialize` to restore the arguments.
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance and save serializable constructor arguments.

        These arguments are used to auto-generate an object serialization
        config, which enables user-created subclasses to be serializable
        out of the box in most cases without forcing the user
        to manually implement `get_config()`.
        """

        # Adapted from keras.Operation.__new__, to support all serializable objects, instead
        # of only basic types.

        # Start the `__new__` lookup *after* `Operation` in the MRO; the relevant logic
        # is re-implemented (and adapted) below.
        instance = super(Operation, cls).__new__(cls)

        # Generate a config to be returned by default by `get_config()`.
        # Positional arguments are mapped to their parameter names (skipping `self`).
        arg_names = inspect.getfullargspec(cls.__init__).args
        kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))

        # Explicitly serialize `dtype` to support auto_config
        dtype = kwargs.get("dtype", None)
        if isinstance(dtype, dtype_policies.DTypePolicy):
            # For backward compatibility, we use a str (`name`) for
            # `DTypePolicy`
            if dtype.quantization_mode is None:
                kwargs["dtype"] = dtype.name
            # Otherwise, use `dtype_policies.serialize`
            else:
                kwargs["dtype"] = dtype_policies.serialize(dtype)

        # Adaptation: we allow all registered serializable objects
        supported_types = (str, int, float, bool, type(None))
        try:
            flat_arg_values = tree.flatten(kwargs)
            auto_config = True
            for value in flat_arg_values:
                # Basic types and classes are always accepted; only consult the
                # registry for everything else.
                if isinstance(value, supported_types) or inspect.isclass(value):
                    continue
                is_serializable = get_registered_object(get_registered_name(type(value))) is not None
                if not is_serializable:
                    auto_config = False
                    break
        except TypeError:
            auto_config = False
        try:
            instance._lock = False
            if auto_config:
                instance._auto_config = BayesFlowSerializableDict(**kwargs)
            else:
                instance._auto_config = None
            instance._lock = True
        except RecursionError:
            # Setting an instance attribute in __new__ has the potential
            # to trigger an infinite recursion if a subclass overrides
            # setattr in an unsafe way.
            pass

        ### from keras.Layer.__new__

        # Wrap the user-provided `build` method in the `build_wrapper`
        # to add name scope support and serialization support.
        original_build_method = instance.build

        @wraps(original_build_method)
        def build_wrapper(*args, **kwargs):
            with instance._open_name_scope():
                instance._path = current_path()
                original_build_method(*args, **kwargs)
            # Record build config.
            signature = inspect.signature(original_build_method)
            instance._build_shapes_dict = signature.bind(*args, **kwargs).arguments
            # Set built, post build actions, and lock state.
            instance.built = True
            instance._post_build()
            instance._lock_state()

        instance.build = build_wrapper

        # Wrap the user-provided `quantize` method in the `quantize_wrapper`
        # to add tracker support.
        original_quantize_method = instance.quantize

        @wraps(original_quantize_method)
        def quantize_wrapper(mode, **kwargs):
            instance._check_quantize_args(mode, instance.compute_dtype)
            instance._tracker.unlock()
            try:
                original_quantize_method(mode, **kwargs)
            finally:
                # Always re-lock the tracker, whether or not quantization failed.
                instance._tracker.lock()

        instance.quantize = quantize_wrapper

        return instance

    @python_utils.default
    def get_config(self):
        """Returns the config of the object.

        An object config is a Python dictionary (serializable)
        containing the information needed to re-instantiate it.

        Raises:
            NotImplementedError: If the subclass does not override
                `get_config()` and no auto-config was recorded in `__new__`.
        """

        # Adapted from Operations.get_config to support specifying a default configuration in
        # subclasses, without giving up on the automatic config functionality.
        config = super().get_config()
        if not python_utils.is_default(self.get_config):
            # In this case the subclass implements get_config()
            return config

        # In this case the subclass doesn't implement get_config():
        # Let's see if we can autogenerate it.
        if getattr(self, "_auto_config", None) is not None:
            xtra_args = set(config.keys())
            config.update(self._auto_config.config)
            # Remove args not explicitly supported
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
                    config.pop(key, None)
            return config
        else:
            raise NotImplementedError(
                textwrap.dedent(
                    f"""
        Object {self.__class__.__name__} was created by passing
        non-serializable argument values in `__init__()`,
        and therefore the object must override `get_config()` in
        order to be serializable. Please implement `get_config()`.

        Example:

        class CustomLayer(keras.layers.Layer):
            def __init__(self, arg1, arg2, **kwargs):
                super().__init__(**kwargs)
                self.arg1 = arg1
                self.arg2 = arg2

            def get_config(self):
                config = super().get_config()
                config.update({{
                    "arg1": self.arg1,
                    "arg2": self.arg2,
                }})
                return config"""
                )
            )

    @classmethod
    def from_config(cls, config):
        """Creates a layer from its config.

        This method is the reverse of `get_config`, capable of instantiating the
        same layer from the config dictionary.

        Note: If you override this method, you might receive a serialized dtype
        config, which is a `dict`. You can deserialize it as follows:

        ```python
        if "dtype" in config and isinstance(config["dtype"], dict):
            policy = dtype_policies.deserialize(config["dtype"])
        ```

        Args:
            config: A Python dictionary, typically the output of `get_config`.

        Returns:
            A layer instance.

        Raises:
            TypeError: If deserializing the config or calling the constructor
                fails; the original exception is attached as the cause.
        """
        # Adapted from keras.Operation.from_config to use our deserialize function
        # Explicitly deserialize dtype config if needed. This enables users to
        # directly interact with the instance of `DTypePolicy`.
        if "dtype" in config and isinstance(config["dtype"], dict):
            # Copy so the caller's dictionary is not modified.
            config = config.copy()
            policy = dtype_policies.deserialize(config["dtype"])
            if not isinstance(policy, dtype_policies.DTypePolicyMap) and policy.quantization_mode is None:
                # For backward compatibility, we use a str (`name`) for
                # `DTypePolicy`
                policy = policy.name
            config["dtype"] = policy
        try:
            return cls(**deserialize(config))
        except Exception as e:
            raise TypeError(
                f"Error when deserializing class '{cls.__name__}' using config={config}.\n\nException encountered: {e}"
            ) from e

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ def from_config(cls, config, custom_objects=None):
109109

110110
def get_config(self):
111111
base_config = super().get_config()
112-
# base distribution is fixed and passed in constructor
113-
base_config.pop("base_distribution")
114112

115113
config = {
116114
"total_steps": self.total_steps,

bayesflow/networks/fusion_network/fusion_network.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Mapping
22
from ..summary_network import SummaryNetwork
3-
from bayesflow.utils.serialization import deserialize, serializable, serialize
3+
from bayesflow.utils.serialization import serializable, serialize
44
from bayesflow.types import Tensor, Shape
55
import keras
66
from keras import ops
@@ -116,8 +116,3 @@ def get_config(self) -> dict:
116116
"head": self.head,
117117
}
118118
return base_config | serialize(config)
119-
120-
@classmethod
121-
def from_config(cls, config: dict, custom_objects=None):
122-
config = deserialize(config, custom_objects=custom_objects)
123-
return cls(**config)

bayesflow/networks/inference_network.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from bayesflow.types import Shape, Tensor
55
from bayesflow.utils import layer_kwargs, find_distribution
66
from bayesflow.utils.decorators import allow_batch_size
7-
from bayesflow.utils.serialization import deserialize, serializable, serialize
7+
from bayesflow.utils.serialization import serializable
8+
from .base_layer import BaseLayer
89

910

1011
@serializable("bayesflow.networks")
11-
class InferenceNetwork(keras.Layer):
12+
class InferenceNetwork(BaseLayer):
1213
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs):
1314
self.custom_metrics = metrics
1415
super().__init__(**layer_kwargs(kwargs))
@@ -76,13 +77,3 @@ def compute_metrics(
7677
metrics[metric.name] = metric(samples, x)
7778

7879
return metrics
79-
80-
def get_config(self):
81-
base_config = super().get_config()
82-
base_config = layer_kwargs(base_config)
83-
config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
84-
return base_config | serialize(config)
85-
86-
@classmethod
87-
def from_config(cls, config, custom_objects=None):
88-
return cls(**deserialize(config, custom_objects=custom_objects))

bayesflow/networks/summary_network.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,20 @@
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs, find_distribution
77
from bayesflow.utils.decorators import sanitize_input_shape
8-
from bayesflow.utils.serialization import deserialize, serializable, serialize
8+
from bayesflow.utils.serialization import serializable
9+
from .base_layer import BaseLayer
910

1011

1112
@serializable("bayesflow.networks")
12-
class SummaryNetwork(keras.Layer):
13+
class SummaryNetwork(BaseLayer):
1314
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs):
1415
self.custom_metrics = metrics
1516
super().__init__(**layer_kwargs(kwargs))
1617
self.base_distribution = find_distribution(base_distribution)
1718

1819
@sanitize_input_shape
1920
def build(self, input_shape):
21+
print("SN build", self, input_shape)
2022
x = keras.ops.zeros(input_shape)
2123
z = self.call(x)
2224

@@ -53,13 +55,3 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[
5355
metrics[metric.name] = metric(outputs, samples)
5456

5557
return metrics
56-
57-
def get_config(self):
58-
base_config = super().get_config()
59-
base_config = layer_kwargs(base_config)
60-
config = {"base_distribution": self.base_distribution, "metrics": self.custom_metrics}
61-
return base_config | serialize(config)
62-
63-
@classmethod
64-
def from_config(cls, config, custom_objects=None):
65-
return cls(**deserialize(config, custom_objects=custom_objects))

tests/test_adapters/test_adapters.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import keras
55

6-
from bayesflow.utils.serialization import deserialize, serialize
6+
from bayesflow.utils.serialization import deserialize, serialize, normalize_config
77

88
import bayesflow as bf
99

@@ -29,7 +29,7 @@ def test_serialize_deserialize(adapter, random_data):
2929
deserialized = deserialize(serialized)
3030
reserialized = serialize(deserialized)
3131

32-
assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
32+
assert normalize_config(serialized) == normalize_config(reserialized)
3333

3434
random_data["foo"] = random_data["x1"]
3535
deserialized_processed = deserialized(random_data)
@@ -122,7 +122,6 @@ def test_simple_transforms(random_data):
122122

123123
def test_custom_transform():
124124
# test that transform raises errors in all relevant cases
125-
import keras
126125
from bayesflow.adapters.transforms import SerializableCustomTransform
127126
from copy import deepcopy
128127

@@ -335,7 +334,7 @@ def test_nnpe(random_data):
335334
deserialized = deserialize(serialized)
336335
reserialized = serialize(deserialized)
337336

338-
assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
337+
assert normalize_config(serialized) == normalize_config(reserialized)
339338

340339
# check that only x1 is changed
341340
assert "x1" in result_training
@@ -365,7 +364,7 @@ def test_nnpe(random_data):
365364
serialized_auto = serialize(ad_auto)
366365
deserialized_auto = deserialize(serialized_auto)
367366
reserialized_auto = serialize(deserialized_auto)
368-
assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(serialize(reserialized_auto))
367+
assert normalize_config(serialized_auto) == normalize_config(serialize(reserialized_auto))
369368

370369
# Test dimensionwise versus global noise application (per_dimension=True vs per_dimension=False)
371370
# Create data with second dimension having higher variance

tests/test_networks/test_fusion_network/test_fusion_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
import keras
44

5-
from tests.utils import assert_layers_equal, allclose
5+
from tests.utils import assert_layers_equal, allclose, normalize_config
66

77

88
@pytest.mark.parametrize("automatic", [True, False])
@@ -57,7 +57,7 @@ def test_serialize_deserialize(fusion_network, multimodal_data):
5757
deserialized = deserialize(serialized)
5858
reserialized = serialize(deserialized)
5959

60-
assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
60+
assert normalize_config(serialized) == normalize_config(reserialized)
6161

6262

6363
def test_save_and_load(tmp_path, fusion_network, multimodal_data):

tests/test_networks/test_inference_networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from bayesflow.utils.serialization import serialize, deserialize
66

7-
from tests.utils import assert_allclose, assert_layers_equal
7+
from tests.utils import assert_allclose, assert_layers_equal, normalize_config
88

99

1010
def test_build(inference_network, random_samples, random_conditions):
@@ -137,7 +137,7 @@ def test_serialize_deserialize(inference_network, random_samples, random_conditi
137137
deserialized = deserialize(serialized)
138138
reserialized = serialize(deserialized)
139139

140-
assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)
140+
assert normalize_config(serialized) == normalize_config(reserialized)
141141

142142

143143
def test_save_and_load(tmp_path, inference_network, random_samples, random_conditions):

0 commit comments

Comments
 (0)