Commit 29416f6

manuelgloeckler, jaivardhankapoor, and janfb authored
feat: improving flow and score matching API and nets (#1544)
* Shift VectorFieldNet type from estimators/flowmatching_estimator.py to utils/vector_field_utils.py
* refactor: update imports and enhance ConditionalScoreEstimator
  - changed net type in ConditionalScoreEstimator and related classes to VectorFieldNet
  - added embedding_net to ConditionalScoreEstimator for condition embedding
* Update z-score parameters in flowmatching and posterior score neural networks, keeping others as independent
* refactor: update vector field neural network architecture
  - replaced flowmatcher and score estimator imports with vector field equivalents
  - introduced new vector field neural network builders for MLP and transformer architectures
  - enhanced flowmatching_nn and posterior_score_nn functions to support new model types
  - added custom Euler integration method in FlowMatchingEstimator for improved sampling
  - updated z-score handling and embedding net integration in estimator classes
  - Remaining bug: extra num_samples dim in Zuko sampling function that we need to fix
* update: fix handling of shapes during sampling for NPSE and FMPE
  - increased hidden_features from 50 to 64 in posterior_score_nn to make it divisible by num_heads
  - added vector_field_fn to FlowMatchingEstimator for better dimension handling during sampling
  - implemented reshaping logic in ConditionalScoreEstimator to manage sample dimensions during batch processing
* refactor: improve docstring formatting and remove commented-out code
* refactor: integrate embedding net in flowmatching estimator
  - replaced direct condition input with embedded condition in log_prob, sample, and sample_and_log_prob methods
  - removed commented-out code and print statements for cleaner implementation
  - updated score estimator import to use posterior_score_nn in tests/score_samplers_test.py
* refactor: update flowmatching estimator and vector field network initialization
  - removed noise_scale parameter from FlowMatchingEstimator
  - adjusted vector field calculation in FlowMatchingEstimator for improved accuracy
  - modified last-layer initialization in AdaMLPBlock and DiTBlock to scale weights instead of zeroing
  - streamlined MLP block processing in GlobalEmbeddingMLP for clarity
* Fix some problems
* To run CI, comment out broken tests
* Updates: internal nets should be shared, but estimator builders should be separate (as they have different preconditioners)
* Unify shape handling in score and flow. Add tests for consistency of score and flow estimators
* Formatting to get CI going (failing tests expected)
* Some small fixes and refactorings
* Fix ruff things
* Fix score sampler tests with new net builder API
* Fix flow estimator bugs
* Bug hunting + fixing
* Rearrange trainers + fix tests to not use "special" hyperparameters (i.e., use defaults)
* Fix ruff
* Fix failing tests
* Fix validation loss check
* Fix for new FMPE args
* Consistent naming for FMPE and NPSE
* Bug fix for Neural ODE sampling with ScoreEstimators
* Fewer num_sims for vf estimators in tests
* Test changes
* Bug fix: epochs were bounded by default. Add a better convergence check
* Remove print
* New MLP which performs better
* Allow setting num_sims in minisbibm for eval
* Add arg for num sims, cache results by default
* Some refactorings
* Fix formatting
* Formatting, make defaults more uniform
* Make factories more SBI-like
* Some estimator tests added
* Formatting, fix kwargs errors, more tests
* Fix tests and init transformer last layer as zero
* Remove test jupyter notebook
* Formatting, refactoring tests
* Fix pyright
* Remove what is expected to fail
* Minor fixes
* Small docstring update
* Backward compatibility warnings for some unused kwargs
* Typing with vector field net
* Simplify score estimator
* Updates
* Fix transformer with cross attention
* Add error msg for unsupported shapes
* Better tests
* Refactored tests
* Revert weird reshapings in score estimator. Remove code duplication in embedding net handling
* Fix formatting issues
* Fix inconsistencies
* Fix pyright
* Fix embedding_net not passed
* Fix embedding net bug
* Remove redundant "num_blocks"
* Add some degree of backward compatibility in the user interface
* Fix failing test on new convergence check
* Add transformer to bm
* Must be okay that the files already exist in bm
* Fix merge bug. Add deprecation warnings for score estimator keyword argument in NPSE
* Fix transformers (no positional embedding, among others)
* Refactorings and tunings
* Deprecation warnings and small refactorings
* Backwards compatibility
* Move score_estimator tests to vf_estimator_tests, run doc notebook once
* Remove random weird comment
* Remove tolerance special cases
* Consistent naming
* Faster convergence for slightly worse performance
* Backward compatibility for imports of NPSE and FMPE
* Docstring update
* Improve docstrings
* Backward compatibility
* Use new keywords
* Format
* Add missing headers
* Update sbi/inference/trainers/vfpe/base_vf_inference.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/fmpe/__init__.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/npse/__init__.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/vfpe/fmpe.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/vfpe/fmpe.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/vfpe/npse.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/inference/trainers/vfpe/npse.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/__init__.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/estimators/flowmatching_estimator.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/estimators/score_estimator.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/factory.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/factory.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/net_builders/vector_field_nets.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/net_builders/vector_field_nets.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update sbi/neural_nets/factory.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update tests/bm_test.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update tests/bm_test.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update tests/bm_test.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Update tests/bm_test.py (Co-authored-by: Jan <janfb@users.noreply.github.com>)
* Add nugget as keyword argument to train
* Improve converged docstring
* Better typing and docstrings and so on
* Docstring
* Add context
* Extended docstring
* Move protocol
* Formatting fix
* Revert "move protocol" (this reverts commit 7eeaa7d)
* Fix formatting
* Remove deprecated
* Fix typing
* Positional argument for default builder model name
* Formatting
* Unify nets test
* Fix builder
* Update notebooks
* Formatting and some text updates
* Formatting
* Fix deprecation warning on default args
* Unnecessary
* Remove unnecessary notes
* Refactor check for deprecation warning
* Fix mcmc params passing in test
* Fix mnle_test
* Add missing import
* Notebooks rerun without warnings and with stripped notebook outputs

---------

Co-authored-by: Jaivardhan Kapoor <jaivardhan.kapoor@gmail.com>
Co-authored-by: Jan <janfb@users.noreply.github.com>
Co-authored-by: Jan <jan.boelts@mailbox.org>
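For orientation, a minimal usage sketch of the reworked interface summarized above and detailed in the diffs below. The toy prior, simulator, and sample sizes are illustrative assumptions, not part of this commit:

    import torch
    from sbi.inference import FMPE  # NPSE works analogously
    from sbi.utils import BoxUniform

    prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
    theta = prior.sample((1000,))
    x = theta + 0.1 * torch.randn_like(theta)  # toy simulator

    # `vf_estimator` replaces the deprecated `density_estimator` keyword and
    # accepts "mlp", "ada_mlp", "transformer", or "transformer_cross_attn",
    # or a callable implementing the `VectorFieldEstimatorBuilder` protocol.
    trainer = FMPE(prior=prior, vf_estimator="mlp")
    trainer.append_simulations(theta, x).train()
    posterior = trainer.build_posterior()
    samples = posterior.sample((100,), x=x[0])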
1 parent 7064286 commit 29416f6

30 files changed: +2680 / -1182 lines changed

docs/advanced_tutorials/19_flowmatching_and_scorematching.ipynb

Lines changed: 285 additions & 15 deletions
Large diffs are not rendered by default.

docs/advanced_tutorials/20_score_based_methods_new_features.ipynb

Lines changed: 98 additions & 23 deletions
Large diffs are not rendered by default.

sbi/inference/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -7,12 +7,11 @@
     check_if_proposal_has_default_x,
     infer,
 )
-from sbi.inference.trainers.fmpe import FMPE
 from sbi.inference.trainers.marginal import MarginalTrainer
 from sbi.inference.trainers.nle import MNLE, NLE_A
 from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C  # noqa: F401
-from sbi.inference.trainers.npse import NPSE
 from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C  # noqa: F401
+from sbi.inference.trainers.vfpe import FMPE, NPSE
 
 SNL = SNLE = SNLE_A = NLE = NLE_A
 _nle_family = ["NLE"]
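The practical upshot for imports, sketched below: the top-level names are unchanged, while the canonical trainer location moves to the new vfpe subpackage (per this diff and the backward-compatibility items in the commit message):

    from sbi.inference import FMPE, NPSE                 # unchanged public API
    from sbi.inference.trainers.vfpe import FMPE, NPSE   # new canonical location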

sbi/inference/trainers/base.py

Lines changed: 12 additions & 6 deletions
@@ -673,13 +673,19 @@ def _raise_deprecation_warning(
         """
 
         deprecated_params = deprecated_params.copy()
-
-        is_default_mcmc_method = kwargs.get("mcmc_method") == "slice_np_vectorized"
-        is_default_vi_method = kwargs.get("vi_method") == "rKL"
-
-        if not is_default_mcmc_method:
+        default_mcmc_method = "slice_np_vectorized"
+        default_vi_method = "rKL"
+
+        # Check if deprecated parameters are used
+        if (
+            kwargs.get("mcmc_method") == default_mcmc_method
+            or kwargs.get("mcmc_method") is None
+        ):
             deprecated_params.append("mcmc_method")
-        if not is_default_vi_method:
+        if (
+            kwargs.get("vi_method") == default_vi_method
+            or kwargs.get("vi_method") is None
+        ):
             deprecated_params.append("vi_method")
 
         if deprecated_params:

sbi/inference/trainers/npse/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.
sbi/inference/trainers/fmpe/__init__.py renamed to sbi/inference/trainers/vfpe/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
-
-from sbi.inference.trainers.fmpe.fmpe import FMPE
+from sbi.inference.trainers.vfpe.fmpe import FMPE
+from sbi.inference.trainers.vfpe.npse import NPSE
sbi/inference/trainers/npse/vector_field_inference.py renamed to sbi/inference/trainers/vfpe/base_vf_inference.py

Lines changed: 93 additions & 8 deletions
@@ -4,7 +4,7 @@
 import time
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Callable, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Literal, Optional, Protocol, Tuple, Union
 
 import torch
 from torch import Tensor, ones

@@ -60,7 +60,10 @@ class VectorFieldTrainer(NeuralInference, ABC):
     def __init__(
         self,
         prior: Optional[Distribution] = None,
-        vector_field_estimator_builder: Union[str, VectorFieldEstimatorBuilder] = "mlp",
+        vector_field_estimator_builder: Union[
+            Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+            VectorFieldEstimatorBuilder,
+        ] = "mlp",
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
         summary_writer: Optional[SummaryWriter] = None,

@@ -106,15 +109,19 @@ def __init__(
         check_estimator_arg(vector_field_estimator_builder)
         if isinstance(vector_field_estimator_builder, str):
             self._build_neural_net = self._build_default_nn_fn(
-                vector_field_estimator_builder=vector_field_estimator_builder, **kwargs
+                model=vector_field_estimator_builder, **kwargs
             )
         else:
             self._build_neural_net = vector_field_estimator_builder
 
         self._proposal_roundwise = []
 
     @abstractmethod
-    def _build_default_nn_fn(self, **kwargs) -> VectorFieldEstimatorBuilder:
+    def _build_default_nn_fn(
+        self,
+        model: Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+        **kwargs,
+    ) -> VectorFieldEstimatorBuilder:
         pass
 
     def append_simulations(

@@ -209,12 +216,13 @@ def train(
         training_batch_size: int = 200,
         learning_rate: float = 5e-4,
         validation_fraction: float = 0.1,
-        stop_after_epochs: int = 50,
-        max_num_epochs: int = 500,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
         clip_max_norm: Optional[float] = 5.0,
         calibration_kernel: Optional[Callable] = None,
         ema_loss_decay: float = 0.1,
-        validation_times: Union[Tensor, int] = 20,
+        validation_times: Union[Tensor, int] = 10,
+        validation_times_nugget: float = 0.05,
         resume_training: bool = False,
         force_first_round_loss: bool = False,
         discard_prior_samples: bool = False,

@@ -253,6 +261,9 @@ def train(
                 training and validation losses.
             validation_times: Diffusion times at which to evaluate the validation loss
                 to reduce variance of validation loss.
+            validation_times_nugget: As diffusion and flow matching losses have high
+                variance near the time-interval boundaries, a small nugget shrinks the
+                evaluation range: the default 0.05 gives t_min + 0.05 and t_max - 0.05.
             resume_training: Can be used in case training time is limited, e.g. on a
                 cluster. If `True`, the split between train and validation set, the
                 optimizer, the number of epochs, and the best validation log-prob will

@@ -341,7 +352,9 @@ def default_calibration_kernel(x):
 
         if isinstance(validation_times, int):
             validation_times = torch.linspace(
-                self._neural_net.t_min, self._neural_net.t_max, validation_times
+                self._neural_net.t_min + validation_times_nugget,
+                self._neural_net.t_max - validation_times_nugget,
+                validation_times,
             )
         assert isinstance(
             validation_times, Tensor

@@ -431,6 +444,8 @@ def default_calibration_kernel(x):
                 times_batch, *([1] * (masks_batch.ndim - 1))
             )
 
+            # This will repeat the validation times for each batch in the
+            # validation set.
            validation_times_rep = validation_times.repeat_interleave(
                 val_batch_size, dim=0
             )

@@ -486,6 +501,76 @@ def default_calibration_kernel(x):
 
         return deepcopy(self._neural_net)
 
+    def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
+        """Return whether training converged and save the best model state so far.
+
+        Diffusion and flow matching objectives are inherently more stochastic than
+        maximum-likelihood objectives (e.g. for NPE) because they additionally add
+        "noise" by construction. We hence use a statistical approach to detect
+        convergence by tracking the standard deviation of validation losses.
+        Training is considered converged when the current loss is significantly
+        worse than the best loss for a sustained period (more than 2 standard
+        deviations above the best).
+
+        NOTE: The standard deviation of the `validation_loss` is computed in a
+        running fashion over the most recent 2 * stop_after_epochs loss values.
+
+        Args:
+            epoch: Current epoch in training.
+            stop_after_epochs: How many fruitless epochs to let pass before stopping.
+
+        Returns:
+            Whether the training has stopped improving, i.e. has converged.
+        """
+        converged = False
+
+        assert self._neural_net is not None
+        neural_net = self._neural_net
+
+        # Initialize tracking variables if they do not exist yet.
+        if not hasattr(self, '_best_val_loss'):
+            self._best_val_loss = float('inf')
+            self._epochs_since_last_improvement = 0
+            self._best_model_state_dict = None
+
+        # Check if we have a new best loss.
+        if self._val_loss < self._best_val_loss:
+            self._best_val_loss = self._val_loss
+            self._epochs_since_last_improvement = 0
+            self._best_model_state_dict = deepcopy(neural_net.state_dict())
+        else:
+            # Only start the statistical analysis once we have enough data.
+            if len(self._summary["validation_loss"]) >= stop_after_epochs:
+                # Calculate running statistics of recent losses.
+                recent_losses = torch.tensor(
+                    self._summary["validation_loss"][-stop_after_epochs * 2 :]
+                )
+                loss_std = recent_losses.std().item()
+
+                # Calculate how many standard deviations the current loss is
+                # from the best.
+                diff_to_best_normalized = (
+                    self._val_loss - self._best_val_loss
+                ) / loss_std
+                # Consider it "no improvement" if the current loss is
+                # significantly worse than the best loss (more than 2 std
+                # deviations above the best). This accounts for natural
+                # fluctuations while remaining sensitive to real degradation.
+                if diff_to_best_normalized > 2.0:
+                    self._epochs_since_last_improvement += 1
+                else:
+                    # Reset the counter if the loss is within acceptable range.
+                    self._epochs_since_last_improvement = 0
+            else:
+                return False
+
+        # If no validation improvement over many epochs, stop training.
+        if self._epochs_since_last_improvement > stop_after_epochs - 1:
+            if self._best_model_state_dict is not None:
+                neural_net.load_state_dict(self._best_model_state_dict)
+            converged = True
+
+        return converged
+
     def _get_potential_function(
         self, prior: Distribution, estimator: ConditionalVectorFieldEstimator
     ) -> Tuple[VectorFieldBasedPotential, TorchTransform]:
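To see the new convergence rule in action, the following self-contained sketch replays the same criterion on a fabricated loss curve. It mirrors the `_converged` logic above but is not an sbi import, and the synthetic loss series is purely illustrative:

    import torch

    # Synthetic validation losses: fast improvement, then a slowly degrading plateau.
    torch.manual_seed(0)
    losses = [
        1.0 - 0.05 * e if e < 18 else 0.1 + 0.005 * (e - 18) + 0.01 * torch.randn(1).item()
        for e in range(120)
    ]

    stop_after_epochs = 20
    best, since_improvement = float("inf"), 0
    for epoch, val_loss in enumerate(losses):
        if val_loss < best:
            best, since_improvement = val_loss, 0
        elif epoch + 1 >= stop_after_epochs:
            # Std over the most recent 2 * stop_after_epochs losses, as above.
            recent = torch.tensor(losses[max(0, epoch - 2 * stop_after_epochs + 1) : epoch + 1])
            # "No improvement" only if > 2 std above the best loss so far.
            if (val_loss - best) / recent.std().item() > 2.0:
                since_improvement += 1
            else:
                since_improvement = 0
        if since_improvement > stop_after_epochs - 1:
            print(f"Converged at epoch {epoch}.")
            break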

sbi/inference/trainers/fmpe/fmpe.py renamed to sbi/inference/trainers/vfpe/fmpe.py

Lines changed: 32 additions & 14 deletions
@@ -2,6 +2,7 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 
+import warnings
 from typing import Any, Dict, Literal, Optional, Union
 
 from torch.distributions import Distribution

@@ -10,12 +11,12 @@
 from sbi import utils as utils
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
 from sbi.inference.posteriors.posterior_parameters import VectorFieldPosteriorParameters
-from sbi.inference.trainers.npse.vector_field_inference import (
+from sbi.inference.trainers.vfpe.base_vf_inference import (
     VectorFieldEstimatorBuilder,
     VectorFieldTrainer,
 )
-from sbi.neural_nets import flowmatching_nn
-from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+from sbi.neural_nets.factory import posterior_flow_nn
 
 
 class FMPE(VectorFieldTrainer):

@@ -24,7 +25,11 @@ class FMPE(VectorFieldTrainer):
     def __init__(
         self,
         prior: Optional[Distribution],
-        density_estimator: Union[str, VectorFieldEstimatorBuilder] = "mlp",
+        vf_estimator: Union[
+            Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+            VectorFieldEstimatorBuilder,
+        ] = "mlp",
+        density_estimator: Optional[VectorFieldEstimatorBuilder] = None,
         device: str = "cpu",
         logging_level: Union[int, str] = "WARNING",
         summary_writer: Optional[SummaryWriter] = None,

@@ -35,11 +40,14 @@ def __init__(
 
         Args:
             prior: Prior distribution.
-            density_estimator: Neural network architecture used to learn the
-                vector field estimator. Can be a string (e.g. 'mlp' or 'ada_mlp') or a
-                callable that implements the `VectorFieldEstimatorBuilder` protocol
-                with `__call__` that receives `theta` and `x` and returns a
+            vf_estimator: Neural network architecture used to learn the
+                vector field estimator. Can be a string (e.g. 'mlp', 'ada_mlp',
+                'transformer' or 'transformer_cross_attn') or a callable that
+                implements the `VectorFieldEstimatorBuilder` protocol with
+                `__call__` that receives `theta` and `x` and returns a
                 `ConditionalVectorFieldEstimator`.
+            density_estimator: Deprecated. Use `vf_estimator` instead. When passed,
+                a FutureWarning is raised and its value is used as `vf_estimator`.
             device: Device to use for training.
             logging_level: Logging level.
             summary_writer: Summary writer for tensorboard.

@@ -48,17 +56,24 @@ def __init__(
             `density_estimator` is a string.
         """
 
+        if density_estimator is not None:
+            warnings.warn(
+                "`density_estimator` is deprecated and will be removed in a future "
+                "release. Use `vf_estimator` instead.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            vf_estimator = density_estimator
+
         super().__init__(
             prior=prior,
             device=device,
             logging_level=logging_level,
             summary_writer=summary_writer,
             show_progress_bars=show_progress_bars,
-            vector_field_estimator_builder=density_estimator,
+            vector_field_estimator_builder=vf_estimator,
             **kwargs,
         )
-        # density_estimator name is kept since it is public API, but it is
-        # actually misleading since it is a builder for an estimator.
 
     def build_posterior(
         self,

@@ -106,6 +121,9 @@ def build_posterior(
             posterior_parameters=posterior_parameters,
         )
 
-    def _build_default_nn_fn(self, **kwargs) -> VectorFieldEstimatorBuilder:
-        model = kwargs.pop("vector_field_estimator_builder", "mlp")
-        return flowmatching_nn(model=model, **kwargs)
+    def _build_default_nn_fn(
+        self,
+        model: Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
+        **kwargs,
+    ) -> VectorFieldEstimatorBuilder:
+        return posterior_flow_nn(model=model, **kwargs)
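A short sketch of the deprecation path from the user's side. The warning category and the forwarding of `density_estimator` to `vf_estimator` follow the diff above; the prior setup is an illustrative assumption:

    import warnings
    import torch
    from sbi.inference import FMPE
    from sbi.utils import BoxUniform

    prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))

    # New keyword:
    trainer = FMPE(prior, vf_estimator="transformer")

    # Old keyword still works, but raises a FutureWarning and its value is
    # forwarded to `vf_estimator`:
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        trainer = FMPE(prior, density_estimator="mlp")
    assert any(issubclass(w.category, FutureWarning) for w in caught)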
