|
| 1 | +import keras |
| 2 | +import inspect |
| 3 | +import textwrap |
| 4 | +from functools import wraps |
| 5 | + |
| 6 | +from keras.src import dtype_policies |
| 7 | +from keras.src import tree |
| 8 | +from keras.src.backend.common.name_scope import current_path |
| 9 | +from keras.src.utils import python_utils |
| 10 | +from keras import Operation |
| 11 | +from keras.saving import get_registered_name, get_registered_object |
| 12 | + |
| 13 | +from bayesflow.utils.serialization import serialize, deserialize |
| 14 | + |
| 15 | + |
| 16 | +class BayesFlowSerializableDict: |
| 17 | + def __init__(self, **config): |
| 18 | + self.config = config |
| 19 | + |
| 20 | + def serialize(self): |
| 21 | + return serialize(self.config) |
| 22 | + |
| 23 | + |
| 24 | +class BaseLayer(keras.Layer): |
| 25 | + def __new__(cls, *args, **kwargs): |
| 26 | + """We override __new__ to saving serializable constructor arguments. |
| 27 | +
|
| 28 | + These arguments are used to auto-generate an object serialization |
| 29 | + config, which enables user-created subclasses to be serializable |
| 30 | + out of the box in most cases without forcing the user |
| 31 | + to manually implement `get_config()`. |
| 32 | + """ |
| 33 | + |
| 34 | + # Adapted from keras.Operation.__new__, to support all serializable objects, instead |
| 35 | + # of only basic types. |
| 36 | + |
| 37 | + instance = super(Operation, cls).__new__(cls) |
| 38 | + |
| 39 | + # Generate a config to be returned by default by `get_config()`. |
| 40 | + arg_names = inspect.getfullargspec(cls.__init__).args |
| 41 | + kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) |
| 42 | + |
| 43 | + # Explicitly serialize `dtype` to support auto_config |
| 44 | + dtype = kwargs.get("dtype", None) |
| 45 | + if dtype is not None and isinstance(dtype, dtype_policies.DTypePolicy): |
| 46 | + # For backward compatibility, we use a str (`name`) for |
| 47 | + # `DTypePolicy` |
| 48 | + if dtype.quantization_mode is None: |
| 49 | + kwargs["dtype"] = dtype.name |
| 50 | + # Otherwise, use `dtype_policies.serialize` |
| 51 | + else: |
| 52 | + kwargs["dtype"] = dtype_policies.serialize(dtype) |
| 53 | + |
| 54 | + # Adaptation: we allow all registered serializable objects |
| 55 | + supported_types = (str, int, float, bool, type(None)) |
| 56 | + try: |
| 57 | + flat_arg_values = tree.flatten(kwargs) |
| 58 | + auto_config = True |
| 59 | + for value in flat_arg_values: |
| 60 | + is_serializable = get_registered_object(get_registered_name(type(value))) is not None |
| 61 | + is_class = inspect.isclass(value) |
| 62 | + if not (isinstance(value, supported_types) or is_serializable or is_class): |
| 63 | + auto_config = False |
| 64 | + break |
| 65 | + except TypeError: |
| 66 | + auto_config = False |
| 67 | + try: |
| 68 | + instance._lock = False |
| 69 | + if auto_config: |
| 70 | + instance._auto_config = BayesFlowSerializableDict(**kwargs) |
| 71 | + else: |
| 72 | + instance._auto_config = None |
| 73 | + instance._lock = True |
| 74 | + except RecursionError: |
| 75 | + # Setting an instance attribute in __new__ has the potential |
| 76 | + # to trigger an infinite recursion if a subclass overrides |
| 77 | + # setattr in an unsafe way. |
| 78 | + pass |
| 79 | + |
| 80 | + ### from keras.Layer.__new__ |
| 81 | + |
| 82 | + # Wrap the user-provided `build` method in the `build_wrapper` |
| 83 | + # to add name scope support and serialization support. |
| 84 | + original_build_method = instance.build |
| 85 | + |
| 86 | + @wraps(original_build_method) |
| 87 | + def build_wrapper(*args, **kwargs): |
| 88 | + with instance._open_name_scope(): |
| 89 | + instance._path = current_path() |
| 90 | + original_build_method(*args, **kwargs) |
| 91 | + # Record build config. |
| 92 | + signature = inspect.signature(original_build_method) |
| 93 | + instance._build_shapes_dict = signature.bind(*args, **kwargs).arguments |
| 94 | + # Set built, post build actions, and lock state. |
| 95 | + instance.built = True |
| 96 | + instance._post_build() |
| 97 | + instance._lock_state() |
| 98 | + |
| 99 | + instance.build = build_wrapper |
| 100 | + |
| 101 | + # Wrap the user-provided `quantize` method in the `quantize_wrapper` |
| 102 | + # to add tracker support. |
| 103 | + original_quantize_method = instance.quantize |
| 104 | + |
| 105 | + @wraps(original_quantize_method) |
| 106 | + def quantize_wrapper(mode, **kwargs): |
| 107 | + instance._check_quantize_args(mode, instance.compute_dtype) |
| 108 | + instance._tracker.unlock() |
| 109 | + try: |
| 110 | + original_quantize_method(mode, **kwargs) |
| 111 | + except Exception: |
| 112 | + raise |
| 113 | + finally: |
| 114 | + instance._tracker.lock() |
| 115 | + |
| 116 | + instance.quantize = quantize_wrapper |
| 117 | + |
| 118 | + return instance |
| 119 | + |
| 120 | + @python_utils.default |
| 121 | + def get_config(self): |
| 122 | + """Returns the config of the object. |
| 123 | +
|
| 124 | + An object config is a Python dictionary (serializable) |
| 125 | + containing the information needed to re-instantiate it. |
| 126 | + """ |
| 127 | + |
| 128 | + # Adapted from Operations.get_config to support specifying a default configuration in |
| 129 | + # subclasses, without giving up on the automatic config functionality. |
| 130 | + config = super().get_config() |
| 131 | + if not python_utils.is_default(self.get_config): |
| 132 | + # In this case the subclass implements get_config() |
| 133 | + return config |
| 134 | + |
| 135 | + # In this case the subclass doesn't implement get_config(): |
| 136 | + # Let's see if we can autogenerate it. |
| 137 | + if getattr(self, "_auto_config", None) is not None: |
| 138 | + xtra_args = set(config.keys()) |
| 139 | + config.update(self._auto_config.config) |
| 140 | + # Remove args non explicitly supported |
| 141 | + argspec = inspect.getfullargspec(self.__init__) |
| 142 | + if argspec.varkw != "kwargs": |
| 143 | + for key in xtra_args - xtra_args.intersection(argspec.args[1:]): |
| 144 | + config.pop(key, None) |
| 145 | + return config |
| 146 | + else: |
| 147 | + raise NotImplementedError( |
| 148 | + textwrap.dedent( |
| 149 | + f""" |
| 150 | + Object {self.__class__.__name__} was created by passing |
| 151 | + non-serializable argument values in `__init__()`, |
| 152 | + and therefore the object must override `get_config()` in |
| 153 | + order to be serializable. Please implement `get_config()`. |
| 154 | +
|
| 155 | + Example: |
| 156 | +
|
| 157 | + class CustomLayer(keras.layers.Layer): |
| 158 | + def __init__(self, arg1, arg2, **kwargs): |
| 159 | + super().__init__(**kwargs) |
| 160 | + self.arg1 = arg1 |
| 161 | + self.arg2 = arg2 |
| 162 | +
|
| 163 | + def get_config(self): |
| 164 | + config = super().get_config() |
| 165 | + config.update({{ |
| 166 | + "arg1": self.arg1, |
| 167 | + "arg2": self.arg2, |
| 168 | + }}) |
| 169 | + return config""" |
| 170 | + ) |
| 171 | + ) |
| 172 | + |
| 173 | + @classmethod |
| 174 | + def from_config(cls, config): |
| 175 | + """Creates an operation from its config. |
| 176 | +
|
| 177 | + This method is the reverse of `get_config`, capable of instantiating the |
| 178 | + same operation from the config dictionary. |
| 179 | +
|
| 180 | + Note: If you override this method, you might receive a serialized dtype |
| 181 | + config, which is a `dict`. You can deserialize it as follows: |
| 182 | +
|
| 183 | + ```python |
| 184 | + if "dtype" in config and isinstance(config["dtype"], dict): |
| 185 | + policy = dtype_policies.deserialize(config["dtype"]) |
| 186 | + ``` |
| 187 | +
|
| 188 | + Args: |
| 189 | + config: A Python dictionary, typically the output of `get_config`. |
| 190 | +
|
| 191 | + Returns: |
| 192 | + An operation instance. |
| 193 | + """ |
| 194 | + # Adapted from keras.Operation.from_config to use our deserialize function |
| 195 | + # Explicitly deserialize dtype config if needed. This enables users to |
| 196 | + # directly interact with the instance of `DTypePolicy`. |
| 197 | + if "dtype" in config and isinstance(config["dtype"], dict): |
| 198 | + config = config.copy() |
| 199 | + policy = dtype_policies.deserialize(config["dtype"]) |
| 200 | + if not isinstance(policy, dtype_policies.DTypePolicyMap) and policy.quantization_mode is None: |
| 201 | + # For backward compatibility, we use a str (`name`) for |
| 202 | + # `DTypePolicy` |
| 203 | + policy = policy.name |
| 204 | + config["dtype"] = policy |
| 205 | + try: |
| 206 | + return cls(**deserialize(config)) |
| 207 | + except Exception as e: |
| 208 | + raise TypeError( |
| 209 | + f"Error when deserializing class '{cls.__name__}' using config={config}.\n\nException encountered: {e}" |
| 210 | + ) |
0 commit comments