diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index fa84a9b4f..458a2b136 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -484,6 +484,8 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1): """ if isinstance(keys, str): transform = Rename(keys, to_key=into) + elif len(keys) == 1: + transform = Rename(keys[0], to_key=into) else: transform = Concatenate(keys, into=into, axis=axis) self.transforms.append(transform) diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py index 8cd21b4cc..d9159487e 100644 --- a/bayesflow/adapters/transforms/convert_dtype.py +++ b/bayesflow/adapters/transforms/convert_dtype.py @@ -1,4 +1,5 @@ import numpy as np +from keras.tree import map_structure from bayesflow.utils.serialization import serializable, serialize @@ -31,8 +32,8 @@ def get_config(self) -> dict: } return serialize(config) - def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: - return data.astype(self.to_dtype, copy=False) + def forward(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict: + return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data) - def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: - return data.astype(self.from_dtype, copy=False) + def inverse(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict: + return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data) diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index fe1b82f2d..d6dba2aa9 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -2,6 +2,7 @@ import numpy as np +from bayesflow.utils.tree import map_dict from bayesflow.utils.serialization import serializable, serialize from .elementwise_transform import ElementwiseTransform @@ -34,12 +35,20 @@ def get_config(self) -> dict: return serialize({"original_type": self.original_type}) def forward(self, data: any, **kwargs) -> np.ndarray: + if isinstance(data, dict): + # no invertiblity for dict, do not store original type + return map_dict(np.asarray, data) + if self.original_type is None: self.original_type = type(data) return np.asarray(data) - def inverse(self, data: np.ndarray, **kwargs) -> any: + def inverse(self, data: np.ndarray | dict, **kwargs) -> any: + if isinstance(data, dict): + # no invertibility for dict to keep complexity low + return data + if self.original_type is None: raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.") diff --git a/bayesflow/utils/tree.py b/bayesflow/utils/tree.py index ca8bc433c..a19d2e68e 100644 --- a/bayesflow/utils/tree.py +++ b/bayesflow/utils/tree.py @@ -1,4 +1,5 @@ import optree +from typing import Callable def flatten_shape(structure): @@ -12,3 +13,31 @@ def is_shape_tuple(x): namespace="keras", ) return leaves + + +def map_dict(func: Callable, dictionary: dict) -> dict: + """Applies a function to all leaves of a (possibly nested) dictionary. + + Parameters + ---------- + func : Callable + The function to apply to the leaves. + dictionary : dict + The input dictionary. + + Returns + ------- + dict + A dictionary with the outputs of `func` as leaves. + """ + + def is_not_dict(x): + return not isinstance(x, dict) + + return optree.tree_map( + func, + dictionary, + is_leaf=is_not_dict, + none_is_leaf=True, + namespace="keras", + ) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 3193309ae..feccc6d77 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -13,7 +13,9 @@ def serializable_fn(x): return ( Adapter() + .group(["p1", "p2"], into="ps", prefix="p") .to_array() + .ungroup("ps", prefix="p") .as_set(["s1", "s2"]) .broadcast("t1", to="t2") .as_time_series(["t1", "t2"]) @@ -37,8 +39,6 @@ def serializable_fn(x): .rename("o1", "o2") .random_subsample("s3", sample_size=33, axis=0) .take("s3", indices=np.arange(0, 32), axis=0) - .group(["p1", "p2"], into="ps", prefix="p") - .ungroup("ps", prefix="p") ) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 23721a938..095058eea 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -393,3 +393,16 @@ def test_nnpe(random_data): # Both should assign noise to high-variance dimension assert std_dim[1] > 0 assert std_glob[1] > 0 + + +def test_single_concatenate_to_rename(): + # test that single-element concatenate is converted to rename + from bayesflow import Adapter + from bayesflow.adapters.transforms import Rename, Concatenate + + ad = Adapter().concatenate("a", into="b") + assert isinstance(ad[0], Rename) + ad = Adapter().concatenate(["a"], into="b") + assert isinstance(ad[0], Rename) + ad = Adapter().concatenate(["a", "b"], into="c") + assert isinstance(ad[0], Concatenate) diff --git a/tests/test_utils/test_tree.py b/tests/test_utils/test_tree.py new file mode 100644 index 000000000..0c3dc81f2 --- /dev/null +++ b/tests/test_utils/test_tree.py @@ -0,0 +1,16 @@ +def test_map_dict(): + from bayesflow.utils.tree import map_dict + + input = { + "a": { + "x": [0, 1, 2], + }, + "b": [0, 1], + "c": "foo", + } + output = map_dict(len, input) + for key, value in output.items(): + if key == "a": + assert value["x"] == len(input["a"]["x"]) + continue + assert value == len(input[key]) diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py index 84b3fdafb..126c39cea 100644 --- a/tests/test_workflows/conftest.py +++ b/tests/test_workflows/conftest.py @@ -81,13 +81,6 @@ def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Ten x = mean[:, None] + noise - return dict(mean=mean, a=x, b=x) + return dict(mean=mean, observables=dict(a=x, b=x)) return FusionSimulator() - - -@pytest.fixture -def fusion_adapter(): - from bayesflow import Adapter - - return Adapter.create_default(["mean"]).group(["a", "b"], "summary_variables") diff --git a/tests/test_workflows/test_basic_workflow.py b/tests/test_workflows/test_basic_workflow.py index 50f0bd879..f6b75e711 100644 --- a/tests/test_workflows/test_basic_workflow.py +++ b/tests/test_workflows/test_basic_workflow.py @@ -36,14 +36,13 @@ def test_basic_workflow(tmp_path, inference_network, summary_network): assert samples["parameters"].shape == (5, 3, 2) -def test_basic_workflow_fusion( - tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator, fusion_adapter -): +def test_basic_workflow_fusion(tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator): workflow = bf.BasicWorkflow( - adapter=fusion_adapter, inference_network=fusion_inference_network, summary_network=fusion_summary_network, simulator=fusion_simulator, + inference_variables=["mean"], + summary_variables=["observables"], checkpoint_filepath=str(tmp_path), )