
Commit 822ad89

Authored by LarsKue, vpratz, stefanradev93, and marvinschmitt
v2.0.2 (#447)
* [no ci] notebook tests: increase timeout, fix platform/backend dependent code. Torch is very slow, so I had to increase the timeout accordingly.
* Enable use of summary networks with functional API again (#434)
  * summary networks: add tests for using functional API
  * fix build functions for use with functional API
* [no ci] docs: add GitHub and Discourse links, reorder navbar
* [no ci] docs: acknowledge scikit-learn website
* [no ci] docs: capitalize navigation headings
* More tests (#437)
  * fix docs of coupling flow
  * add additional tests
* Automatically run slow tests when main is involved (#438). In addition, this PR limits the slow tests to Windows and Python 3.10. The choices are somewhat arbitrary; my thought was to test the setup not covered as much through use by the devs.
* Update dispatch
* Update dispatching distributions
* Improve workflow tests with multiple summary nets / approximators
* Fix zombie find_distribution import
* Add readme entry [no ci]
* Update README: NumFOCUS affiliation, awesome-abi list (#445)
* fix is_symbolic_tensor
* remove multiple batch sizes, remove multiple python version tests, remove update-workflows branch from workflow style tests, add __init__ and conftest to test_point_approximators (#443)
* implement compile_from_config and get_compile_config (#442)
  * add optimizer build to compile_from_config
* Fix Optimal Transport for Compiled Contexts (#446)
  * remove the is_symbolic_tensor check, because this would otherwise skip the whole function for compiled contexts
  * skip pyabc test
  * fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse
  * update dispatch tests for more coverage
* Update issue templates (#448)
* Hotfix Version 2.0.1 (#431)
  * fix optimal transport config (#429)
  * run linter
  * [skip-ci] bump version to 2.0.1
* Update issue templates
* Robustify kwargs passing for inference networks, add class variables
* fix convergence method to debug for non-log sinkhorn
* Bump optimal transport default to False
* use logging.info for backend selection instead of logging.debug
* fix model comparison approximator
* improve docs and type hints
* improve One-Sample T-Test Notebook:
  * use torch as default backend
  * reduce range of N so users of jax won't be stuck with a slow notebook
  * use BayesFlow built-in MLP instead of keras.Sequential solution
  * general code cleanup
* remove backend print
* [skip ci] turn all single-quoted strings into double-quoted strings
* turn all single-quoted strings into double-quoted strings (amend to trigger workflow)

Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
Co-authored-by: Marvin Schmitt <35921281+marvinschmitt@users.noreply.github.com>
1 parent d31a761 commit 822ad89


59 files changed: +1061 −579 lines

.github/ISSUE_TEMPLATE/bug_report.md

+36 −0

@@ -0,0 +1,36 @@
+---
+name: Bug report
+about: Create a bug report to help us improve BayesFlow
+title: "[BUG]"
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Minimal steps to reproduce the behavior:
+1. Import '...'
+2. Create network '....'
+3. Call '....'
+4. See error
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Traceback**
+If you encounter an error, please provide a complete traceback to help explain your problem.
+
+**Environment**
+ - OS: [e.g. Ubuntu]
+ - Python Version: [e.g. 3.11]
+ - Backend: [e.g. jax, tensorflow, pytorch]
+ - BayesFlow Version: [e.g. 2.0.2]
+
+**Additional context**
+Add any other context about the problem here.
+
+**Minimality**
+- [ ] I verify that my example is minimal, does not rely on third-party packages, and is most likely an issue in BayesFlow.
.github/ISSUE_TEMPLATE/ (file name hidden in the commit view)

+20 −0

@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest a new feature to be implemented in BayesFlow
+title: "[FEATURE]"
+labels: feature
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.

.github/workflows/style.yaml

+0 −2

@@ -6,12 +6,10 @@ on:
     branches:
       - main
       - dev
-      - update-workflows
   push:
     branches:
       - main
       - dev
-      - update-workflows

 jobs:
   check-code-style:

.github/workflows/tests.yaml

+6 −3

@@ -24,7 +24,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest, windows-latest]
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.10"] # we usually only need to test the oldest python version
         backend: ["jax", "tensorflow", "torch"]

     runs-on: ${{ matrix.os }}
@@ -73,8 +73,11 @@ jobs:
           pytest -x -m "not slow"

       - name: Run Slow Tests
-        # run all slow tests only on manual trigger
-        if: github.event_name == 'workflow_dispatch'
+        # Run slow tests on manual trigger and pushes/PRs to main.
+        # Limit to one OS and Python version to save compute.
+        # Multiline if statements are weird, https://github.com/orgs/community/discussions/25641,
+        # but feel free to convert it.
+        if: ${{ ((github.event_name == 'workflow_dispatch') || (github.event_name == 'push' && github.ref_name == 'main') || (github.event_name == 'pull_request' && github.base_ref == 'main')) && ((matrix.os == 'windows-latest') && (matrix.python-version == '3.10')) }}
        run: |
          pytest -m "slow"
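The combined `if:` expression above is dense; purely as an illustration (not part of the repository), the same gating logic can be sketched as a Python predicate, which makes the two conjuncts explicit: a trigger condition (manual dispatch, or push/PR involving main) and a matrix restriction (Windows with Python 3.10).

```python
def should_run_slow(event_name: str, ref_name: str = "", base_ref: str = "",
                    os_name: str = "", python_version: str = "") -> bool:
    """Mirror of the workflow's `if:` expression: slow tests run on manual
    dispatch or on pushes/PRs targeting main, limited to one OS/Python combo."""
    triggered = (
        event_name == "workflow_dispatch"
        or (event_name == "push" and ref_name == "main")
        or (event_name == "pull_request" and base_ref == "main")
    )
    # The matrix restriction applies to all triggers, including manual dispatch.
    limited = os_name == "windows-latest" and python_version == "3.10"
    return triggered and limited
```

Note that the restriction clause also gates `workflow_dispatch` runs, so even a manual trigger only executes the slow tests on the Windows / Python 3.10 matrix entry.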

README.md

+4 −1

@@ -3,6 +3,7 @@
 ![Codecov](https://img.shields.io/codecov/c/github/bayesflow-org/bayesflow?style=for-the-badge&link=https%3A%2F%2Fapp.codecov.io%2Fgh%2Fbayesflow-org%2Fbayesflow%2Ftree%2Fmain)
 [![DOI](https://img.shields.io/badge/DOI-10.21105%2Fjoss.05702-blue?style=for-the-badge)](https://doi.org/10.21105/joss.05702)
 ![PyPI - License](https://img.shields.io/pypi/l/bayesflow?style=for-the-badge)
+![NumFOCUS Affiliated Project](https://img.shields.io/badge/NumFOCUS-Affiliated%20Project-orange?style=for-the-badge)

 BayesFlow is a Python library for simulation-based **Amortized Bayesian Inference** with neural networks.
 It provides users and researchers with:
@@ -225,8 +226,10 @@ You can find and install the old Bayesflow version via the `stable-legacy` branc

 ## Awesome Amortized Inference

-If you are interested in a curated list of resources, including reviews, software, papers, and other resources related to amortized inference, feel free to explore our [community-driven list](https://github.com/bayesflow-org/awesome-amortized-inference).
+If you are interested in a curated list of resources, including reviews, software, papers, and other resources related to amortized inference, feel free to explore our [community-driven list](https://github.com/bayesflow-org/awesome-amortized-inference). If you'd like a paper (by yourself or someone else) featured, please add it to the list with a pull request, an issue, or a message to the maintainers.

 ## Acknowledgments

 This project is currently managed by researchers from Rensselaer Polytechnic Institute, TU Dortmund University, and Heidelberg University. It is partially funded by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) Projects 528702768 and 508399956. The project is further supported by Germany's Excellence Strategy -- EXC-2075 - 390740016 (Stuttgart Cluster of Excellence SimTech) and EXC-2181 - 390900948 (Heidelberg Cluster of Excellence STRUCTURES), the collaborative research cluster TRR 391 – 520388526, as well as the Informatics for Life initiative funded by the Klaus Tschira Foundation.
+
+BayesFlow is a [NumFOCUS Affiliated Project](https://numfocus.org/sponsored-projects/affiliated-projects).

bayesflow/__init__.py

+1 −1

@@ -33,7 +33,7 @@ def setup():

     from bayesflow.utils import logging

-    logging.debug(f"Using backend {keras.backend.backend()!r}")
+    logging.info(f"Using backend {keras.backend.backend()!r}")

     if keras.backend.backend() == "torch":
         import torch

bayesflow/approximators/approximator.py

+1 −1

@@ -23,7 +23,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
        raise NotImplementedError

    def build_from_data(self, data: Mapping[str, any]) -> None:
-        self.compute_metrics(**data, stage="training")
+        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
        self.built = True

    @classmethod
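The change above routes the data dictionary through `filter_kwargs` so that extra keys no longer crash `compute_metrics`. `filter_kwargs` is a BayesFlow utility; the following is a hypothetical minimal reimplementation of the idea (not the library's actual code) using `inspect.signature` to drop entries the callable does not accept:

```python
import inspect
from typing import Callable, Mapping


def filter_kwargs(kwargs: Mapping, f: Callable) -> dict:
    """Keep only entries of `kwargs` that match a named parameter of `f`.

    This prevents a TypeError when a data dictionary carries extra keys
    that the target method does not accept.
    """
    params = inspect.signature(f).parameters
    # If `f` accepts **kwargs, everything passes through unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


# Hypothetical stand-in for an approximator's compute_metrics method.
def compute_metrics(inference_variables, inference_conditions=None, stage="training"):
    return {"n_inputs": 2 if inference_conditions is not None else 1}


data = {"inference_variables": [1.0], "inference_conditions": [2.0], "extra": [3.0]}
# Without filtering, the unexpected "extra" key would raise a TypeError.
metrics = compute_metrics(**filter_kwargs(data, compute_metrics), stage="training")
```

The real utility may differ in details (e.g. handling of positional-only parameters), but the filtering pattern is the same as in the diff above.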

bayesflow/approximators/continuous_approximator.py

+33 −3

@@ -32,6 +32,8 @@ class ContinuousApproximator(Approximator):
         Additional arguments passed to the :py:class:`bayesflow.approximators.Approximator` class.
     """

+    SAMPLE_KEYS = ["summary_variables", "inference_conditions"]
+
     def __init__(
         self,
         *,
@@ -51,6 +53,7 @@ def build_adapter(
         inference_variables: Sequence[str],
         inference_conditions: Sequence[str] = None,
         summary_variables: Sequence[str] = None,
+        standardize: bool = True,
         sample_weight: str = None,
     ) -> Adapter:
         """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
@@ -63,9 +66,12 @@ def build_adapter(
             Names of the inference conditions in the data
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
+        standardize : bool, optional
+            Decide whether to standardize all variables, default is True
         sample_weight : str, optional
             Name of the sample weights
         """
+
         adapter = Adapter()
         adapter.to_array()
         adapter.convert_dtype("float64", "float32")
@@ -82,7 +88,9 @@ def build_adapter(
             adapter = adapter.rename(sample_weight, "sample_weight")

         adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"])
-        adapter.standardize(exclude="sample_weight")
+
+        if standardize:
+            adapter.standardize(exclude="sample_weight")

         return adapter

@@ -104,6 +112,12 @@ def compile(

         return super().compile(*args, **kwargs)

+    def compile_from_config(self, config):
+        self.compile(**deserialize(config))
+        if hasattr(self, "optimizer") and self.built:
+            # Create optimizer variables.
+            self.optimizer.build(self.trainable_variables)
+
     def compute_metrics(
         self,
         inference_variables: Tensor,
@@ -213,6 +227,16 @@ def get_config(self):

         return base_config | serialize(config)

+    def get_compile_config(self):
+        base_config = super().get_compile_config() or {}
+
+        config = {
+            "inference_metrics": self.inference_network._metrics,
+            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
+        }
+
+        return base_config | serialize(config)
+
     def estimate(
         self,
         conditions: Mapping[str, np.ndarray],
@@ -318,12 +342,18 @@ def sample(
         dict[str, np.ndarray]
             Dictionary containing generated samples with the same keys as `conditions`.
         """
+
+        # Apply adapter transforms to raw simulated / real quantities
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
-        conditions.pop("inference_variables", None)
+
+        # Ensure only keys relevant for sampling are present in the conditions dictionary
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS}
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
         conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
         conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
+
+        # Back-transform quantities and samples
         conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)

         if split:
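The new `SAMPLE_KEYS` filter replaces the single `conditions.pop("inference_variables", None)` call: instead of removing one known key, it whitelists the keys that sampling can actually consume, so any other adapter output (including `inference_variables` and `sample_weight`) is dropped. A standalone sketch of the effect, with hypothetical example values:

```python
# Class variable introduced in this commit.
SAMPLE_KEYS = ["summary_variables", "inference_conditions"]


def keep_sample_keys(conditions: dict) -> dict:
    """Whitelist-filter: keep only the keys relevant for sampling."""
    return {k: v for k, v in conditions.items() if k in SAMPLE_KEYS}


# Hypothetical adapter output at inference time.
adapted = {
    "inference_variables": [0.1],   # estimated by the networks, must be dropped
    "inference_conditions": [1.0],
    "summary_variables": [2.0],
    "sample_weight": [1.0],         # training-only key, now also dropped
}
filtered = keep_sample_keys(adapted)
```

Compared to the old `pop`, the whitelist is robust against future adapter outputs: unknown keys can never leak into `_sample` and trigger unexpected keyword-argument errors.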

bayesflow/approximators/model_comparison_approximator.py

+25 −3

@@ -30,11 +30,13 @@ class ModelComparisonApproximator(Approximator):
         The network backbone (e.g, an MLP) that is used for model classification.
         The input of the classifier network is created by concatenating `classifier_variables`
         and (optional) output of the summary_network.
-    summary_network: bg.networks.SummaryNetwork, optional
+    summary_network: bf.networks.SummaryNetwork, optional
         The summary network used for data summarization (default is None).
         The input of the summary network is `summary_variables`.
     """

+    SAMPLE_KEYS = ["summary_variables", "classifier_conditions"]
+
     def __init__(
         self,
         *,
@@ -118,6 +120,12 @@ def compile(

         return super().compile(*args, **kwargs)

+    def compile_from_config(self, config):
+        self.compile(**deserialize(config))
+        if hasattr(self, "optimizer") and self.built:
+            # Create optimizer variables.
+            self.optimizer.build(self.trainable_variables)
+
     def compute_metrics(
         self,
         *,
@@ -262,6 +270,16 @@ def get_config(self):

         return base_config | serialize(config)

+    def get_compile_config(self):
+        base_config = super().get_compile_config() or {}
+
+        config = {
+            "classifier_metrics": self.classifier_network._metrics,
+            "summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
+        }
+
+        return base_config | serialize(config)
+
     def predict(
         self,
         *,
@@ -288,9 +306,13 @@ def predict(
         np.ndarray
             Predicted posterior model probabilities given `conditions`.
         """
+
+        # Apply adapter transforms to raw simulated / real quantities
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        # at inference time, model_indices are predicted by the networks and thus ignored in conditions
-        conditions.pop("model_indices", None)
+
+        # Ensure only keys relevant for sampling are present in the conditions dictionary
+        conditions = {k: v for k, v in conditions.items() if k in ModelComparisonApproximator.SAMPLE_KEYS}
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

         output = self._predict(**conditions, **kwargs)
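The `get_compile_config` / `compile_from_config` pair added to both approximators are the Keras 3 hooks used when saving and reloading a compiled model: the first serializes compile-time state (here, the metric lists), the second restores it at load time. The extra `optimizer.build(...)` call recreates the optimizer's slot variables so that checkpointed optimizer weights can be restored into them. A plain-Python sketch of the round trip, using mock classes rather than BayesFlow's or Keras's actual implementations:

```python
class MockOptimizer:
    """Stand-in optimizer: build() creates one slot variable per weight."""

    def __init__(self, name):
        self.name = name
        self.variables = []

    def build(self, trainable_variables):
        self.variables = [f"slot_{v}" for v in trainable_variables]


class MockApproximator:
    """Minimal stand-in showing the save/reload contract of the two hooks."""

    def __init__(self):
        self.built = True
        self.trainable_variables = ["w1", "w2"]
        self.optimizer = None
        self.metrics = None

    def compile(self, *, optimizer="adam", metrics=None):
        self.optimizer = MockOptimizer(optimizer)
        self.metrics = metrics or []

    def get_compile_config(self):
        # Serialized into the saved model at save time.
        return {"optimizer": self.optimizer.name, "metrics": list(self.metrics)}

    def compile_from_config(self, config):
        # Called at load time, before optimizer weights are restored.
        self.compile(**config)
        if hasattr(self, "optimizer") and self.built:
            self.optimizer.build(self.trainable_variables)  # create slot variables


saved = MockApproximator()
saved.compile(optimizer="adam", metrics=["mse"])
config = saved.get_compile_config()

loaded = MockApproximator()
loaded.compile_from_config(config)
```

Without the explicit `build`, the freshly loaded optimizer would have no variables yet, and restoring its saved state would fail or silently start from scratch.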

bayesflow/approximators/point_approximator.py

+3 −1

@@ -156,8 +156,10 @@

     def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
         """Adapts and converts the conditions to tensors."""
+
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        conditions.pop("inference_variables", None)
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS}
+
         return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

     def _apply_inverse_adapter_to_estimates(

bayesflow/distributions/__init__.py

+0 −2

@@ -9,8 +9,6 @@
 from .diagonal_student_t import DiagonalStudentT
 from .mixture import Mixture

-from .find_distribution import find_distribution
-
 from ..utils._docs import _add_imports_to_all

 _add_imports_to_all(include_modules=[])

0 commit comments