Commit 0cdd16a

Authored by stefanradev93, eodole, LarsKue, vpratz, and han-ol
Release 2.0.4 (#510)
* Subset arrays (#411)
  * made initial backend functions for adapter subsetting; still need to make the squeeze function and link it to the front end
  * added subsample functionality; to do would be adding them to testing procedures
  * made the take function and ran the linter
  * changed name of subsampling function
  * changed documentation to be consistent with external notation rather than internal shorthand
  * small formatting change to documentation
  * changed subsample to have sample size and axis in the constructor
  * moved transforms in adapter.py so they're in alphabetical order like the other transforms
  * changed random_subsample to a MapTransform rather than a FilterTransform
  * updated documentation with new naming convention
  * added arguments of take to the constructor
  * added feature to specify a percentage of the data to subsample rather than only integer input
  * changed subsample in adapter.py to allow float as an input for the sample size
  * renamed subsample_array and associated classes/functions to RandomSubsample and random_subsample, respectively
  * included TypeError to force users to only subsample one dataset at a time
  * ran linter and reran formatter
  * clean up random subsample transform and docs
  * clean up take transform and docs
  * nitpick clean-up
  * skip shape check for subsampled adapter transform inverse
  * fix serialization of new transforms
  * skip randomly subsampled key in serialization consistency check

  Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] docs: start of user guide - draft intro, gen models
* [no ci] add draft for data processing section
* [no ci] user guide: add stub on summary/inference networks
* [no ci] user guide: add stub on additional topics
* [no ci] add early stage disclaimer to user guide
* pin dependencies in docs, fixes snowballstemmer error
* fix: correct check for "no accepted samples" in rejection_sample (closes #466)
* Stabilize MultivariateNormalScore by constraining initialization in PositiveDefinite link (#469)
  * Refactor fill_triangular_matrix
  * stable positive definite link, fix for #468
  * Minor changes to docstring
  * Remove self.built=True that prevented registering layer norm in build()
  * np -> keras.ops
* Augmentation (#470)
  * Remove old rounds data set, add documentation and augmentation options to data sets
  * Enable augmentation of parts of the data or the whole data
  * Improve doc
  * Enable augmentations in workflow
  * Fix silly type check and improve readability of for loop
  * Bring back num_batches
* Fixed log det jac computation of standardize transform: y = (x - mu) / sigma, log p(y) = log p(x) - log(sigma)
* Fix fill_triangular_matrix: the two lines were switched, leading to performance degradation
* Deal with inference_network.log_prob returning a dict (as PointInferenceNetwork does)
* Add diffusion model implementation (#408). This commit contains the following changes (see PR #408 for discussion):
  - DiffusionModel following the formalism in Kingma et al. (2023) [1]
  - Stochastic sampler to solve SDEs
  - Tests for the diffusion model

  [1] https://arxiv.org/abs/2303.00848

  Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
  Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
  Co-authored-by: LarsKue <lars@kuehmichel.de>

* [no ci] networks docstrings: summary/inference network indicator (#462) - from the table in the `bayesflow.networks` module overview, one cannot tell which network belongs to which group. This commit adds short labels to indicate inference networks (IN) and summary networks (SN)
* `ModelComparisonSimulator`: handle different outputs from individual simulators (#452). Adds the option to drop, fill, or error when different keys are encountered in the outputs of different simulators. Fixes #441. (Co-authored-by: Valentin Pratz <git@valentinpratz.de>)
* Add classes and transforms to simplify multimodal training (#473)
  * Add class `MultimodalSummaryNetwork` to combine multiple summary networks, each for one modality
  * Add transforms `Group` and `Ungroup` to gather the multimodal inputs in one variable (usually "summary_variables")
  * Add tests for new behavior
  * [no ci] add tutorial notebook for multimodal data
  * [no ci] add missing training argument
  * rename MultimodalSummaryNetwork to FusionNetwork
  * [no ci] clarify that the network implements late fusion
* allow dispatch of summary/inference network from type
* add tests for find_network
* Add squeeze transform: a very basic transform, just the inverse of expand_dims
* [no ci] fix examples in ExpandDims docstring
* squeeze: adapt example, add comment for changing batch dims
* Permit Python version 3.12 (#474) after a successful CI run: https://github.com/bayesflow-org/bayesflow/actions/runs/14988542031
* Change order in readme and reference new book [skip ci]
* make docs optional dependencies compatible with python 3.10
* Add a custom `Sequential` network to avoid issues with building and serialization in keras (#493)
  * add custom sequential to fix #491
  * revert using Sequential in classifier_two_sample_test.py
  * Add docstring to custom Sequential
  * fix copilot docstring
  * remove mlp override methods

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add NNPE adapter class (#488)
  * Add NNPE adapter
  * Add NNPE adapter tests
  * Only apply NNPE during training
  * Integrate stage differentiation into tests
  * Improve test coverage
  * Fix inverse and add to tests
  * Adjust class name and add docstring to forward method
  * Enable compatibility with #486 by adjusting scales automatically
  * Add dimensionwise noise application
  * Update exception handling
  * Fix tests
* Align diffusion model with other inference networks and remove deprecation warnings (#489)
  * Align dm implementation with other networks
  * Remove deprecation warning for using subnet_kwargs
  * Fix tests
  * Remove redundant training arg in get_alpha_sigma and some redundant comments
  * Fix configs creation - do not get base config due to fixed call of super().__init__()
  * Remove redundant training arg from tests
  * Fix dispatch tests for dms
  * Improve docs and mark option for x prediction in literal
  * Fix start/stop time
  * minor cleanup of refactoring

  Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* add replace nan adapter (#459)
  * add replace nan adapter
  * improved naming
  * _mask as additional key
  * update test
  * improve
  * fix serializable
  * changed name to return_mask
  * add mask naming
* [no ci] docs: add basic likelihood estimation example. Fixes #476. This is the barebones version showing the technical steps to do likelihood estimation; adding more background and motivation would be nice.
* make metrics serializable. It seems that metrics do not store their state; I'm not sure yet if this is intended behavior.
* Remove layer norm; add epsilon to std dev for stability of pos def link (this breaks serialization of point estimation with MultivariateNormalScore)
* add end-to-end test for fusion network
* fix: ensure that build is called in FusionNetwork
* Correctly track train / validation losses (#485)
  * correctly track train / validation losses
  * remove mmd from two moons test
  * reenable metrics in continuous approximator, add trackers
  * readd custom metrics to two_moons test
  * take batch size into account when aggregating metrics
  * Add docs to backend approximator interfaces
  * Add small doc improvements
  * Fix typehints in docs

  Co-authored-by: Valentin Pratz <git@valentinpratz.de>
  Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>

* Add shuffle parameter to datasets: adds the option to disable data shuffling

  Co-authored-by: Lars <lars@kuehmichel.de>
  Co-authored-by: Valentin Pratz <git@valentinpratz.de>

* fix: correct vjp/jvp calls in FreeFormFlow. The signature changed, making it necessary to set return_output=True
* test: add basic compute_metrics test for inference networks
* [no ci] extend point approximator tests: remove skip for MVN, add test for log-prob
* [no ci] skip unstable MVN sample test again
* update README with more specific install instructions
* fix FreeFormFlow: remove superfluous index from signature change
* [no ci] FreeFormFlow MLP defaults: set dropout to 0
* Better pairplots (#505)
  * Hacky fix for pairplots
  * Ensure that target sits in front of other elements
  * Ensure consistent spacing between plot and legends + cleanup
  * Update docs
  * Fix the propagation of `legend_fontsize`
  * Minor fix to comply with code style
* [no ci] Formatting: escaped space only in raw strings
* [no ci] fix typo in error message, model comparison approximator
* [no ci] fix: size_of could not handle basic int/float. Passing in basic types would lead to infinite recursion; checks for other types than int and float might be necessary as well.
* add tests for model comparison approximator
* Generalize sample shape to arbitrary N-D arrays
* [WIP] Move standardization into approximators and make adapter stateless (#486)
  * Add standardization to continuous approximator and test
  * Fix init bugs, adapt notebooks
  * Add training flag to build_from_data
  * Fix inference conditions check
  * Fix tests
  * Remove unnecessary init calls
  * Add deprecation warning
  * Refactor compute metrics and add standardization to model comp
  * Fix standardization in cont approx
  * Fix sample keys -> condition keys
  * amazing keras fix
  * moving_mean and moving_std still not loading [WIP]
  * remove hacky approximator serialization test
  * fix building of models in tests
  * Fix standardization
  * Add standardization to model comp and let it use inheritance
  * make assert_models/layers_equal more thorough
  * [no ci] use map_shape_structure to convert shapes to arrays; this automatically takes care of nested structures
  * Extend Standardization to support nested inputs (#501): by using `keras.tree.flatten` and `keras.tree.pack_sequence_as`, we can support arbitrarily nested structures. A `flatten_shape` function is introduced, analogous to `map_shape_structure`, for use in the build function.
  * keep tree utils in submodule
  * Streamline call
  * Fix typehint
  * Update moments before transform and update test
  * Update notebooks
  * Refactor and simplify due to standardize
  * Add comment for fetching the dict's first item, deprecate logits arg and fix typehint
  * add missing import in test
  * Refactor preparation of data for networks and new point_appr.log_prob: ContinuousApproximator._prepare_data unifies all preparation in sample, log_prob, and estimate for both ContinuousApproximator and PointApproximator; PointApproximator now overrides log_prob
  * Add class attributes to inform proper standardization
  * Implement stable moving mean and std (see the sketch after this message)
  * Adapt and fix tests
  * minor adaptations to moving average (update time, init): we should put the update before the standardization, to use the maximum amount of information available. We can then also initialize the moving M^2 with zero, as it will be filled immediately. The special case of M^2 = 0 is not problematic, as no variance automatically indicates that all entries are equal, and we can set them to zero (see my comment). I added another test case to cover that case, and added a test for the standard deviation to the existing test.
  * increase tolerance of allclose tests
  * [no ci] set trainable to False explicitly in ModelComparisonApproximator
  * point estimate of covariance compatible with standardization
  * properly set values to zero if std is zero; cases for inf and -inf were missing
  * fix sample post-processing in point approximator
  * activate tests for multivariate normal score
  * [no ci] undo prev commit: MVN test still not stable, was hidden by std of 0
  * specify explicit build functions for approximators
  * set std for untrained standardization layer to one; an untrained layer thereby does not modify the input
  * [no ci] reformulate zero std case
  * approximator builds: add guards against building networks twice
  * [no ci] add comparison with loaded approx to workflow test
  * Cleanup and address building standardization layers when None specified
  * Cleanup and address building standardization layers when None specified 2
  * Add default case for std transform and add transformation to doc
  * adapt handling of the special case M^2 = 0
  * [no ci] minor fix in concatenate_valid_shapes
  * [no ci] extend test suite for approximators
  * fixes for standardize=None case
  * skip unstable MVN score case
  * Better transformation types
  * Add test for both_sides_scale inverse standardization
  * Add test for left_side_scale inverse standardization
  * Remove flaky test failing due to sampling error
  * Fix input dtypes in inverse standardization transformation_type tests
  * Use concatenate_valid in _sample
  * Replace PositiveDefinite link with CholeskyFactor: this finally makes the MVN score sampling test stable for the jax backend, for which the keras.ops.cholesky operation is numerically unstable. The score's sample method avoids calling keras.ops.cholesky to resolve the issue; instead, the estimation head returns the Cholesky factor directly rather than the covariance matrix (as it used to be).
  * Reintroduce test sampling with MVN score
  * Address TODOs and adapt docstrings and workflow
  * Adapt notebooks
  * Fix in model comparison
  * Update readme and add point estimation nb

  Co-authored-by: LarsKue <lars@kuehmichel.de>
  Co-authored-by: Valentin Pratz <git@valentinpratz.de>
  Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
  Co-authored-by: han-ol <g@hans.olischlaeger.com>
  Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>

* Replace deprecation with FutureWarning
* Adjust filename for LV
* Fix types for subnets
* [no ci] minor fixes to RandomSubsample transform
* [no ci] remove subnet deprecation in cont-time CM
* Remove empty file [no ci]
* Revert layer type for coupling flow [skip ci]
* remove failing import due to removed find_noise_schedule.py [no ci]
* Add utility function for batched simulations (#511): the implementation is a simple wrapper leveraging the batching capabilities of `rejection_sample`
* Restore PositiveDefinite link with deprecation warning
* skip cycle consistency test for diffusion models: the test is unstable for untrained diffusion models, as the network's output is not sufficiently smooth for the step size we use; also remove the diffusion_model marker
* Implement changes to NNPE adapter for #510 (#514)
  * Move docstring to comment
  * Always cast to _resolve_scale
  * Fix typo
* [no ci] remove unnecessary serializable decorator on rmse
* fix type hint in squeeze [no ci]
* reintroduce comment in jax approximator [no ci]
* remove unnecessary getattr calls [no ci]
* Rename local variable transformation_type
* fix error type in diffusion model [no ci]
* remove non-functional per_training_step from plots.loss
* Update doc [skip ci]
* rename approximator.summaries to summarize with deprecation
* address remaining comments

Co-authored-by: Leona Odole <88601208+eodole@users.noreply.github.com>
Co-authored-by: LarsKue <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Hans Olischläger <106988117+han-ol@users.noreply.github.com>
Co-authored-by: han-ol <g@hans.olischlaeger.com>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: arrjon <jonas.arruda@uni-bonn.de>
Co-authored-by: Jonas Arruda <69197639+arrjon@users.noreply.github.com>
Co-authored-by: Simon Kucharsky <kucharssim@gmail.com>
Co-authored-by: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Lasse Elsemüller <60779710+elseml@users.noreply.github.com>
Co-authored-by: Jerry Huang <57327805+jerrymhuang@users.noreply.github.com>
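The "Implement stable moving mean and std" item above refers to the running-moments update inside the new standardization layer. As an illustration only, here is a minimal sketch of the standard Welford/Chan-style batch update that such a layer typically uses; the class name `MovingMoments` is made up and this is not the actual BayesFlow implementation:

```python
import numpy as np

class MovingMoments:
    """Sketch of a numerically stable moving mean/std via a Welford/Chan-style
    batch update. Illustrative only, not the BayesFlow layer itself."""

    def __init__(self, dim: int):
        self.count = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # running sum of squared deviations; zero init is safe

    def update(self, batch: np.ndarray):
        # Merge batch moments into the running moments (Chan et al. parallel update).
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean = self.mean + delta * n / total
        self.m2 = self.m2 + batch_m2 + delta**2 * self.count * n / total
        self.count = total

    @property
    def std(self) -> np.ndarray:
        if self.count == 0:
            # Untrained layer: a std of one leaves the input unchanged.
            return np.ones_like(self.mean)
        # m2 == 0 means all observed entries were identical (zero variance).
        return np.sqrt(self.m2 / self.count)
```

As the commit notes, the update runs before standardization, so the first batch already yields valid moments and the M^2 = 0 edge case only occurs when all observed entries are equal.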
1 parent: f17817a, commit: 0cdd16a

File tree

145 files changed: +9,686 / -2,939 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -39,3 +39,6 @@ docs/
 
 # MacOS
 .DS_Store
+
+# Rproj
+.Rproj.user
```

README.md

Lines changed: 42 additions & 37 deletions

````diff
@@ -49,45 +49,12 @@ neural networks for parameter estimation, model comparison, and model validation
 when working with intractable simulators whose behavior as a whole is too
 complex to be described analytically.
 
-## Getting Started
-
-Using the high-level interface is easy, as demonstrated by the minimal working example below:
-
-```python
-import bayesflow as bf
-
-workflow = bf.BasicWorkflow(
-    inference_network=bf.networks.CouplingFlow(),
-    summary_network=bf.networks.TimeSeriesNetwork(),
-    inference_variables=["parameters"],
-    summary_variables=["observables"],
-    simulator=bf.simulators.SIR()
-)
-
-history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
-
-diagnostics = workflow.plot_default_diagnostics(test_data=300)
-```
-
-For an in-depth exposition, check out our walkthrough notebooks below.
-
-1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
-2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
-3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
-4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
-5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
-6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
-7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
-8. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
-
-More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
-
 ## Install
 
-You can install the latest stable version from PyPI using:
+We currently support Python 3.10 to 3.12. You can install the latest stable version from PyPI using:
 
 ```bash
-pip install bayesflow
+pip install "bayesflow>=2.0"
 ```
 
 If you want the latest features, you can install from source:
@@ -132,9 +99,47 @@ export KERAS_BACKEND=jax
 
 This way, you also don't have to manually set the backend every time you are starting Python to use BayesFlow.
 
-**Caution:** Some development environments (e.g., VSCode or PyCharm) can silently overwrite environment variables. If you have set your backend as an environment variable and you still get keras-related import errors when loading BayesFlow, these IDE shenanigans might be the culprit. Try setting the keras backend in your Python script via `import os; os.environ["KERAS_BACKEND"] = "<YOUR-BACKEND>"`.
+## Getting Started
+
+Using the high-level interface is easy, as demonstrated by the minimal working example below:
+
+```python
+import bayesflow as bf
+
+workflow = bf.BasicWorkflow(
+    inference_network=bf.networks.CouplingFlow(),
+    summary_network=bf.networks.TimeSeriesNetwork(),
+    inference_variables=["parameters"],
+    summary_variables=["observables"],
+    simulator=bf.simulators.SIR()
+)
+
+history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
+
+diagnostics = workflow.plot_default_diagnostics(test_data=300)
+```
+
+For an in-depth exposition, check out our expanding list of resources below.
+
+### Books
+
+Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bayesmodels.com/) by Lee & Wagenmakers (2013) in [BayesFlow](https://kucharssim.github.io/bayesflow-cognitive-modeling-book/).
+
+### Tutorial notebooks
+
+1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
+2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
+3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
+4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation.ipynb)
+5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
+6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
+7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
+8. [Likelihood estimation](examples/Likelihood_Estimation.ipynb)
+9. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
+
+More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.
 
-### From Source
+## Contributing
 
 If you want to contribute to BayesFlow, we recommend installing it from source, see [CONTRIBUTING.md](CONTRIBUTING.md) for more details.
 
````

bayesflow/adapters/adapter.py

Lines changed: 198 additions & 1 deletion

```diff
@@ -14,17 +14,24 @@
     Drop,
     ExpandDims,
     FilterTransform,
+    Group,
     Keep,
     Log,
     MapTransform,
+    NNPE,
     NumpyTransform,
     OneHot,
     Rename,
     SerializableCustomTransform,
+    Squeeze,
     Sqrt,
     Standardize,
     ToArray,
     Transform,
+    Ungroup,
+    RandomSubsample,
+    Take,
+    NanToNum,
 )
 from .transforms.filter_transform import Predicate
 
@@ -598,6 +605,52 @@ def expand_dims(self, keys: str | Sequence[str], *, axis: int | tuple):
         self.transforms.append(transform)
         return self
 
+    def group(self, keys: Sequence[str], into: str, *, prefix: str = ""):
+        """Append a :py:class:`~transforms.Group` transform to the adapter.
+
+        Groups the given variables as a dictionary in the key `into`. As most transforms do
+        not support nested structures, this should usually be the last transform in the adapter.
+
+        Parameters
+        ----------
+        keys : Sequence of str
+            The names of the variables to group together.
+        into : str
+            The name of the variable to store the grouped variables in.
+        prefix : str, optional
+            An optional common prefix of the variable names before grouping, which will be removed after grouping.
+
+        Raises
+        ------
+        ValueError
+            If a prefix is specified, but a provided key does not start with the prefix.
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = Group(keys=keys, into=into, prefix=prefix)
+        self.transforms.append(transform)
+        return self
+
+    def ungroup(self, key: str, *, prefix: str = ""):
+        """Append an :py:class:`~transforms.Ungroup` transform to the adapter.
+
+        Ungroups the variables in `key` from a dictionary into individual entries. Most transforms do
+        not support nested structures, so this can be used to flatten a nested structure.
+        The nesting can be re-established after the transforms using the :py:meth:`group` method.
+
+        Parameters
+        ----------
+        key : str
+            The name of the variable to ungroup. The corresponding variable has to be a dictionary.
+        prefix : str, optional
+            An optional common prefix that will be added to the ungrouped variable names. This can be necessary
+            to avoid duplicate names.
+        """
+        transform = Ungroup(key=key, prefix=prefix)
+        self.transforms.append(transform)
+        return self
+
     def keep(self, keys: str | Sequence[str]):
         """Append a :py:class:`~transforms.Keep` transform to the adapter.
 
@@ -648,6 +701,43 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
         self.transforms.append(transform)
         return self
 
+    def nnpe(
+        self,
+        keys: str | Sequence[str],
+        *,
+        spike_scale: float | None = None,
+        slab_scale: float | None = None,
+        per_dimension: bool = True,
+        seed: int | None = None,
+    ):
+        """Append an :py:class:`~transforms.NNPE` transform to the adapter.
+
+        Parameters
+        ----------
+        keys : str or Sequence of str
+            The names of the variables to transform.
+        spike_scale : float or np.ndarray or None, default=None
+            The scale of the spike (Normal) distribution. Automatically determined if None.
+        slab_scale : float or np.ndarray or None, default=None
+            The scale of the slab (Cauchy) distribution. Automatically determined if None.
+        per_dimension : bool, default=True
+            If true, noise is applied per dimension of the last axis of the input data.
+            If false, noise is applied globally.
+        seed : int or None
+            The seed for the random number generator. If None, a random seed is used.
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = MapTransform(
+            {
+                key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
+                for key in keys
+            }
+        )
+        self.transforms.append(transform)
+        return self
+
     def one_hot(self, keys: str | Sequence[str], num_classes: int):
         """Append a :py:class:`~transforms.OneHot` transform to the adapter.
 
@@ -665,6 +755,28 @@ def one_hot(self, keys: str | Sequence[str], num_classes: int):
         self.transforms.append(transform)
         return self
 
+    def random_subsample(self, key: str, *, sample_size: int | float, axis: int = -1):
+        """
+        Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.
+
+        Parameters
+        ----------
+        key : str
+            The name of the variable to subsample.
+        sample_size : int or float
+            The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
+        axis : int, optional
+            Which axis to draw samples over. The last axis is used by default.
+        """
+
+        if not isinstance(key, str):
+            raise TypeError("Can only subsample one batch entry at a time.")
+
+        transform = MapTransform({key: RandomSubsample(sample_size=sample_size, axis=axis)})
+
+        self.transforms.append(transform)
+        return self
+
     def rename(self, from_key: str, to_key: str):
         """Append a :py:class:`~transforms.Rename` transform to the adapter.
 
@@ -708,6 +820,24 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq
 
         return self
 
+    def squeeze(self, keys: str | Sequence[str], *, axis: int | Sequence[int]):
+        """Append a :py:class:`~transforms.Squeeze` transform to the adapter.
+
+        Parameters
+        ----------
+        keys : str or Sequence of str
+            The names of the variables to squeeze.
+        axis : int or Sequence of int
+            The axis to squeeze. As the number of batch dimensions might change, we advise using negative
+            numbers (i.e., indexing from the end instead of the start).
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        transform = MapTransform({key: Squeeze(axis=axis) for key in keys})
+        self.transforms.append(transform)
+        return self
+
     def sqrt(self, keys: str | Sequence[str]):
         """Append an :py:class:`~transforms.Sqrt` transform to the adapter.
 
@@ -741,7 +871,7 @@ def standardize(
             Names of variables to include in the transform.
         exclude : str or Sequence of str, optional
             Names of variables to exclude from the transform.
-        **kwargs : dict
+        **kwargs :
             Additional keyword arguments passed to the transform.
         """
         transform = FilterTransform(
@@ -754,6 +884,42 @@
         self.transforms.append(transform)
         return self
 
+    def take(
+        self,
+        include: str | Sequence[str] = None,
+        *,
+        indices: Sequence[int],
+        axis: int = -1,
+        predicate: Predicate = None,
+        exclude: str | Sequence[str] = None,
+    ):
+        """
+        Append a :py:class:`~transforms.Take` transform to the adapter.
+
+        Parameters
+        ----------
+        include : str or Sequence of str, optional
+            Names of variables to include in the transform.
+        indices : Sequence of int
+            Which indices to take from the data.
+        axis : int, optional
+            Which axis to take from. The last axis is used by default.
+        predicate : Predicate, optional
+            Function that indicates which variables should be transformed.
+        exclude : str or Sequence of str, optional
+            Names of variables to exclude from the transform.
+        """
+        transform = FilterTransform(
+            transform_constructor=Take,
+            predicate=predicate,
+            include=include,
+            exclude=exclude,
+            indices=indices,
+            axis=axis,
+        )
+        self.transforms.append(transform)
+        return self
+
     def to_array(
         self,
         include: str | Sequence[str] = None,
@@ -791,3 +957,34 @@ def to_dict(self):
         transform = ToDict()
         self.transforms.append(transform)
         return self
+
+    def nan_to_num(
+        self,
+        keys: str | Sequence[str],
+        default_value: float = 0.0,
+        return_mask: bool = False,
+        mask_prefix: str = "mask",
+    ):
+        """
+        Append a :py:class:`~bf.adapters.transforms.NanToNum` transform to the adapter.
+
+        Parameters
+        ----------
+        keys : str or Sequence of str
+            The names of the variables to clean / mask.
+        default_value : float
+            Value to substitute wherever data is NaN. Defaults to 0.0.
+        return_mask : bool
+            If True, encode a binary missingness mask alongside the data. Defaults to False.
+        mask_prefix : str
+            Prefix for the mask key in the output dictionary. Defaults to 'mask'. If the mask key already exists,
+            a ValueError is raised to avoid overwriting existing masks.
+        """
+        if isinstance(keys, str):
+            keys = [keys]
+
+        for key in keys:
+            self.transforms.append(
+                NanToNum(key=key, default_value=default_value, return_mask=return_mask, mask_prefix=mask_prefix)
+            )
+        return self
```

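Taken together, the new methods chain like the existing adapter transforms. Below is a minimal usage sketch, assuming the top-level `bf.Adapter` entry point and made-up variable names (`observables`, `parameters`, `image`); it is an illustration, not code from this commit:

```python
import bayesflow as bf

adapter = (
    bf.Adapter()
    # Impute NaNs with 0.0 and emit a binary missingness mask alongside the data.
    .nan_to_num("observables", default_value=0.0, return_mask=True)
    # Randomly keep 50% of the entries along the last axis (one key at a time).
    .random_subsample("observables", sample_size=0.5, axis=-1)
    # Keep only parameter dimensions 0 and 2.
    .take("parameters", indices=[0, 2], axis=-1)
    # Add spike-and-slab (NNPE) noise per dimension; per the docs above,
    # scales are determined automatically and noise is applied during training only.
    .nnpe("observables")
    # Group transforms should usually come last, since most transforms
    # do not support nested structures.
    .group(["observables", "image"], into="summary_variables")
)
```
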
bayesflow/adapters/transforms/__init__.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -8,21 +8,28 @@
 from .elementwise_transform import ElementwiseTransform
 from .expand_dims import ExpandDims
 from .filter_transform import FilterTransform
+from .group import Group
 from .keep import Keep
 from .log import Log
 from .map_transform import MapTransform
+from .nnpe import NNPE
 from .numpy_transform import NumpyTransform
 from .one_hot import OneHot
 from .rename import Rename
 from .scale import Scale
 from .serializable_custom_transform import SerializableCustomTransform
 from .shift import Shift
 from .split import Split
+from .squeeze import Squeeze
 from .sqrt import Sqrt
 from .standardize import Standardize
 from .to_array import ToArray
 from .to_dict import ToDict
 from .transform import Transform
+from .random_subsample import RandomSubsample
+from .take import Take
+from .ungroup import Ungroup
+from .nan_to_num import NanToNum
 
 from ...utils._docs import _add_imports_to_all
 
```

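With these exports in place, the new transform classes are importable directly from the submodule; a quick smoke test of the import surface (a sketch, assuming bayesflow is installed at this commit):

```python
# New transform classes exported by this commit:
from bayesflow.adapters.transforms import (
    Group,
    NNPE,
    NanToNum,
    RandomSubsample,
    Squeeze,
    Take,
    Ungroup,
)
```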