Simformer #1621
After re-basing this PR to merge into main, a bunch of collaborators' commits from the original parent branch appeared here. I tried to clean up a little by squashing commits, but it would have required resolving 285 different conflicts 😅 so I aborted the operation and had to keep everything as it is.
Alright, as requested by Google: I mark this as the last commit for my GSoC. Nonetheless, I am still able to keep working on this to implement advice and fixes after review 👍
Simformer
Important: This PR is part of Google Summer of Code 2025.

Note: Before opening this PR, I initially experimented with the Simformer and auxiliary components in a separate branch (`simformer-dev`) of my sbi fork. You can find it here. That branch served as a first environment where I could experiment with solutions more freely, and I opened this PR once I had a minimum viable product. It was essentially my working environment for the first month and a half of the GSoC; nevertheless, all the code I finalized there has been fully incorporated into the PR you are reading. More specifically, in that branch I mainly worked on a first version of the Simformer neural network architecture and the "masked" interface. I also attempted to introduce a Joint distribution interface, i.e., a parallel interface to the current "Posterior" approach in sbi that could generalize better to the Simformer case, since the Simformer does not work in terms of "posterior", "likelihood", or the like, but more generally in terms of arbitrary conditionals. The latter has been dropped in favor of a Wrapper class that adapts the more general Simformer approach to the existing sbi posterior interface (see below for more information).
Implemented the Simformer (Gloeckler et al., ICML 2024). The Simformer aims to unify the various simulation-based inference paradigms (posterior, likelihood, or arbitrary conditional sampling) within a single framework, allowing users to sample from any conditional distribution of interest, potentially also acting as a novel data generator if one samples the unconditioned joint distribution of all variables.
The Simformer diverges from the standard sbi paradigm of data provided as `theta` and `x`; it instead operates on a full tensor `inputs` of data, together with two masks:

- `condition_mask`: identifies which variables are latent (to be inferred by the Simformer) and which are observed (ground data).
- `edge_mask`: identifies relationships between variables, equivalent to an adjacency matrix for a DAG. This mask is used directly by the transformer attention block to mask out certain attention scores.

Design of the Masked Classes
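To make the two masks concrete, here is a small illustrative sketch (shapes and names here are my own, not the PR's API), showing a condition mask over five variables and an edge mask applied as an additive attention bias:

```python
import torch

# Hypothetical example: a joint input of 3 theta dims and 2 x dims,
# flattened into one full "inputs" tensor of 5 variables.
inputs = torch.randn(1, 5)

# condition_mask: True = observed (conditioned on), False = latent (inferred).
# Conditioning on the two x entries and inferring the three theta entries
# corresponds to standard posterior sampling p(theta | x).
condition_mask = torch.tensor([False, False, False, True, True])

# edge_mask: adjacency matrix of the DAG over the 5 variables. An attention
# block can turn it into an additive bias: disallowed edges get -inf, so
# their attention weight becomes exactly zero after the softmax.
edge_mask = torch.ones(5, 5, dtype=torch.bool)
edge_mask[0, 4] = False  # e.g., variable 0 may not attend to variable 4

attn_bias = torch.zeros(5, 5).masked_fill(~edge_mask, float("-inf"))
scores = torch.randn(5, 5) + attn_bias   # masked attention scores
weights = torch.softmax(scores, dim=-1)  # zero weight on masked edges
```

This is only meant to illustrate how an adjacency-style mask suppresses attention between unconnected variables; the actual Simformer blocks are in `sbi/neural_nets/net_builders/vector_field_nets.py`.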
To accomplish this, it was necessary to create some "parallel" classes of the current `ScoreEstimator`, `VectorFieldEstimator`, etc. that work with this "masked" paradigm. In general, each "Masked" version of an object is provided directly below its counterpart in the same Python file; e.g., `MaskedConditionalVectorFieldEstimator` sits right below the code block of `ConditionalVectorFieldEstimator`. They essentially consist of an overall refactor of the original counterpart, where each use of "`theta` and `x`" or "`inputs` and `condition`" has been replaced with a general "`inputs`, `condition_mask`, and `edge_mask`".

A Wrapper class has also been introduced that adapts the original Posterior API to the Simformer one. Thanks to this class, one can simply call the `build_conditional()` method directly on the Simformer inference object and obtain a standard Posterior object that works as always, given some fixed condition and edge masks. The Wrapper handles all the shapes automatically and performs auxiliary operations to pass the data to a Simformer network and the underlying masked estimator; this is done mainly through two helper functions, `assemble_full_inputs()` and `disassemble_full_inputs()`, which convert between the $(\theta, x)$ setting and the full input tensor, and back.
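A rough sketch of what such a pair of helpers might look like (the function names are from the PR, but the bodies and signatures here are hypothetical simplifications):

```python
import torch

def assemble_full_inputs(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: pack (theta, x) into one full input tensor."""
    return torch.cat([theta, x], dim=-1)

def disassemble_full_inputs(inputs: torch.Tensor, theta_dim: int):
    """Hypothetical sketch: split the full input tensor back into (theta, x)."""
    return inputs[..., :theta_dim], inputs[..., theta_dim:]

theta = torch.randn(4, 3)  # batch of 4, 3 parameters
x = torch.randn(4, 2)      # batch of 4, 2 observations
full = assemble_full_inputs(theta, x)          # shape (4, 5)
theta2, x2 = disassemble_full_inputs(full, 3)  # lossless round-trip
```

The real helpers additionally handle batch/sample shape bookkeeping; the point is only that the Wrapper moves between the two representations transparently.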
An `edge_mask` can be specified at inference time; otherwise it will be `None` (equivalent to an all-ones tensor, but more memory-efficient). A `condition_mask`, by contrast, must be passed explicitly at `build_conditional()` time. Another option is to directly use the `build_posterior()` and `build_likelihood()` methods, which automatically generate an appropriate `condition_mask` based on the `posterior_latent_idx` and `posterior_observed_idx` parameters specified at `init()` of the Simformer.

An `edge_mask` can also be specified at training time; if not, the default value will still be `None`. More generally, the user can pass a Callable that generates condition or edge masks, so one can choose whatever mask distributions they prefer; sets of tensors/lists, or even a single tensor, can be passed as well. Masks are also generated just-in-time (JIT) for training: they are not provided at `append_simulation()`, but during `train()`, in order to save memory. Differently from inference time, if a condition mask is not specified here, a default generator will be used, producing masks sampled from a $\text{Bernoulli}(p=0.5)$.

Note that the Simformer potentially allows the user to set any mask of their choice both at training and inference time; it is the user's responsibility to provide coherent definitions (callables, sets, or fixed tensors) that make sense. E.g., if the user passes a specific edge mask at training time, the Simformer will learn that specific DAG structure, and it is then up to the user to pass a coherent `edge_mask` when calling `build_conditional`, `build_posterior`, or `build_likelihood`.
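As an illustration of the Callable option, a minimal mask generator mirroring the default Bernoulli(p=0.5) behavior described above might look like this (the function name and signature are hypothetical, not the PR's API):

```python
import torch

def bernoulli_condition_mask(
    batch_size: int, num_vars: int, p: float = 0.5
) -> torch.Tensor:
    """Hypothetical condition-mask generator: each variable is independently
    marked as observed (True) with probability p, latent (False) otherwise.
    Generated just-in-time per batch, so no masks need to be stored upfront."""
    return torch.bernoulli(torch.full((batch_size, num_vars), p)).bool()

# One fresh batch of masks per training step:
masks = bernoulli_condition_mask(batch_size=8, num_vars=5)
```

A user-supplied callable like this could implement any mask distribution, e.g. always conditioning on a fixed subset of variables, or resampling degenerate masks.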
Furthermore, the Simformer can natively manage invalid inputs (`nan`'s and `inf`'s): if `handle_invalid_x=True`, the Simformer will automatically spot invalid inputs at training time (still JIT) and switch their state on the condition mask to latent (to be inferred), as well as replace such values with small Gaussian noise for numerical stability.

A flow-matching equivalent of the Simformer (the description above assumed the score-based one) has also been provided.
This PR also includes integration with the `mini-sbibm` benchmark suite, and a notebook tutorial for the Simformer (under `docs/advanced_tutorials`), where I showcase its use. I also tried to make the API Reference as clear as possible in the documentation.

Refactor of existing code
Parts of the existing code have been refactored, mainly to avoid code repetition and keep everything DRY. The most important modified pieces of code are:

- The SDE-specific definitions (`mean_t`, `std_t`, etc.) have been moved into standard Mixins. E.g., instead of `VEScoreEstimator(ConditionalScoreEstimator)` one now has a `VarianceExplodingSDE` Mixin which defines `mean_t`, `std_t`, etc., and `VEScoreEstimator` becomes `VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE)`. This way, `MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE)` can also be defined easily, without repeating the VE SDE pieces.
- The `NeuralInference` interface has been split using a Mixin too (`BaseNeuralInference`), which defines the properties shared by both `NeuralInference` and `MaskedNeuralInference`. This also required some minor adjustments, mainly to methods such as `_resolve_prior()` and `_resolve_estimator()`; most importantly, a new `NoPrior` object has been created as a temporary solution for #1635.
- `ConditionalVectorFieldEstimator` and `MaskedConditionalVectorFieldEstimator` were simplified by moving shared code into a Mixin called `BaseConditionalVectorFieldEstimator`, mainly regarding the `mean_base` and `std_base` properties and methods such as `diffusion_fn()`.
Summary of modified files

The modified files should be the following:

`sbi/inference`
- `sbi/inference/trainers/base.py`: Added `MaskedNeuralInference`.
- `sbi/inference/trainers/vfpe/base_vf_inference.py`: Added `MaskedVectorFieldEstimatorBuilder` and `MaskedVectorFieldInference` (subclass of `MaskedNeuralInference`).
- `sbi/inference/trainers/vfpe/simformer.py`: New file introducing the Simformer inference class.

`sbi/neural_nets`
- `sbi/neural_nets/factory.py`: Added support for building Simformer networks (`simformer_nn`).
- `sbi/neural_nets/estimators/base.py`: Added `MaskedConditionalEstimator` and `MaskedConditionalVectorFieldEstimator` (subclass of `MaskedConditionalEstimator`).
- `sbi/neural_nets/estimators/score_estimator.py`: `MaskedConditionalScoreEstimator` (subclass of `MaskedConditionalVectorFieldEstimator`), placed directly above `ConditionalScoreEstimator`; `MaskedVEScoreEstimator` (subclass of `MaskedConditionalScoreEstimator`).
- `sbi/neural_nets/net_builders/vector_field_nets.py`: `build_vector_field_estimator` updated to support `simformer` and `masked-score`; added `MaskedSimformerBlock`, `MaskedDiTBlock`, `SimformerNet` (subclass of `MaskedVectorFieldNet`), and `build_simformer_network` (defines default architecture parameters).

`sbi/utils`
- `sbi/utils/vector_field_utils.py`: Added `MaskedVectorFieldNet`.

`sbi/analysis`
- `sbi/analysis/plots.py`: Minor fix to ensure CPU conversion in `ensure_numpy()` (added `.cpu()` before `.numpy()`).

Unit Tests
Introduced benchmarks (`mini_sbibm`) and tests for the Simformer and related masked objects in:

- `tests/linearGaussian_vector_field_test.py`
- `tests/posterior_nn_test.py`
- `tests/vector_field_nets_test.py`
- `tests/vf_estimator_test.py` (which also includes shape tests on the Wrapper)
- `tests/bm_test.py`

Regarding the linear Gaussian tests, I tried to implement the Simformer tests within existing methods as much as possible; nonetheless, the IID test and the SDE/ODE sampling-equivalence test are still provided as separate dedicated tests and fixtures.
New files

- `docs/advanced_tutorials/22_simformer.ipynb`
- `sbi/inference/trainers/vfpe/simformer.py`: includes both the score-based and flow-matching Simformer interfaces

Thank you
Thank you sbi and Google for this opportunity. Implementing the Simformer has been very rewarding: not only did I learn something completely new in itself, but most importantly I learned how to do it. Having to familiarize myself with new concepts, writing code within a codebase made by others, and following the mentors' guidance are the real value of this experience. Special thanks to my mentors Manuel (@manuelgloeckler) and Jan (@janfb) for accepting my proposal, and to @manuelgloeckler in particular for helping me throughout the whole journey!