@nMaax commented Jul 16, 2025

Simformer

Important

This PR is part of Google Summer of Code 2025

Note

Before opening this PR, I experimented with the Simformer and auxiliary components in a separate branch of my sbi fork (simformer-dev). You can find it here. That branch served as my working environment for the first month and a half of GSoC, giving me the freedom to try out solutions; I opened this PR once I had a minimum viable product. All the code I finalized there has been fully incorporated into the PR you are reading.

More specifically, in that branch I mainly worked on a first version of the Simformer neural network architecture and the "masked" interface. I also attempted to introduce a Joint distribution interface, i.e., a parallel interface to the current "Posterior" approach in sbi that could generalize better to the Simformer case, since the Simformer does not operate in terms of "posterior" or "likelihood" but, more generally, in terms of arbitrary conditionals. That idea was eventually dropped in favor of a Wrapper class that adapts the more general Simformer approach to the existing sbi posterior interface (see below for more information).

Implemented the Simformer, Gloeckler et al. 2024 ICML. The Simformer aims to unify the various simulation-based inference paradigms (posterior, likelihood, or arbitrary conditional sampling) within a single framework, allowing users to sample from any conditional distribution of interest. It can even act as a data generator if one samples the unconditioned joint distribution of all variables.

*(Figure 1b from Gloeckler et al. 2024)*

The Simformer diverges from the standard sbi paradigm of data provided as theta and x; instead, it operates on a full input tensor of all variables together with two masks:

  • A condition_mask to identify which variables are latent (to be inferred by the Simformer) and which are observed (ground-truth data)
  • An edge_mask to encode relationships between variables, equivalent to an adjacency matrix of a DAG. This mask is used directly by the transformer attention blocks to mask out certain attention scores.
*(Figure: the condition and edge masks over the full input tensor)*
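
To make the two masks concrete, here is a hypothetical example for a joint over five scalar variables; the mask conventions (True = observed, attention allowed where the entry is one) are assumptions for illustration and may differ from the actual implementation:

```python
import torch

# Hypothetical joint over 5 scalar variables: (theta_1, theta_2, x_1, x_2, x_3).
full_inputs = torch.randn(5)

# condition_mask: True = observed (conditioned on), False = latent (inferred).
# Conditioning on the data and inferring the parameters yields the posterior.
condition_mask = torch.tensor([False, False, True, True, True])

# edge_mask: adjacency matrix of a DAG over the 5 variables; a zero at (i, j)
# masks out the corresponding attention score in the transformer blocks.
# An all-ones mask (or None) leaves attention unrestricted.
edge_mask = torch.ones(5, 5, dtype=torch.bool)
```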

Design of the Masked Classes

To accomplish this, it was necessary to create "parallel" versions of the current ScoreEstimator, VectorFieldEstimator, etc. that work with this "masked" paradigm.

In general, each "Masked" version sits directly below its counterpart in the same Python file, e.g. MaskedConditionalVectorFieldEstimator is placed right below ConditionalVectorFieldEstimator. The masked classes are essentially refactors of their originals in which every use of "theta and x" or "inputs and condition" is replaced with the more general "inputs, condition_mask, and edge_mask".
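
Schematically, the signature change looks as follows (method names and arguments are simplified stand-ins, not the exact sbi definitions):

```python
# Simplified sketch of the signature change, not the exact sbi definitions.
class ConditionalVectorFieldEstimator:
    def forward(self, inputs, condition, time):
        """Vector field for fixed variable roles (e.g., theta given x)."""
        ...

class MaskedConditionalVectorFieldEstimator:
    def forward(self, inputs, condition_mask, edge_mask, time):
        """Vector field over the full input tensor; condition_mask marks which
        entries are observed, edge_mask restricts attention to a DAG."""
        ...
```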

A Wrapper class has also been introduced to adapt the original Posterior API to the Simformer one. Thanks to this class, one can simply call the build_conditional() method directly on the Simformer inference object and obtain a standard Posterior object that works as usual, given some fixed condition and edge masks. The Wrapper handles all the shapes automatically and performs the auxiliary operations needed to pass data to a Simformer network and the underlying masked estimator. This is done mainly through two helper functions, assemble_full_inputs() and disassemble_full_inputs(), which convert between the $(\theta, x)$ setting and the full input tensor, and back.
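
Conceptually, the two helpers behave like the following sketch (argument names and shapes are assumptions; the actual implementation handles the batching and shape bookkeeping):

```python
import torch

def assemble_full_inputs(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Concatenate parameters and data into one full input tensor of shape
    (batch, dim_theta + dim_x)."""
    return torch.cat([theta, x], dim=-1)

def disassemble_full_inputs(full_inputs: torch.Tensor, dim_theta: int):
    """Split a full input tensor back into (theta, x)."""
    return full_inputs[..., :dim_theta], full_inputs[..., dim_theta:]
```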

At inference time, an edge_mask can be specified; otherwise it defaults to None (equivalent to an all-ones tensor, but more memory-efficient). A condition_mask, by contrast, must be passed explicitly at build_conditional() time. Alternatively, one can directly use the build_posterior() and build_likelihood() methods, which automatically generate an appropriate condition_mask based on the posterior_latent_idx and posterior_observed_idx parameters specified at init() of the Simformer.
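
A hypothetical end-to-end usage sketch follows; the import path and call signatures are assumptions modeled on the existing sbi trainers, while the parameter names are taken from this description and may not match the final API exactly:

```python
import torch
from sbi.inference import Simformer  # import path is an assumption

# Toy simulations: 2 parameters and 3 observed variables per sample.
theta, x = torch.randn(1000, 2), torch.randn(1000, 3)
x_o = torch.zeros(1, 3)

trainer = Simformer(
    posterior_latent_idx=[0, 1],       # positions of theta in the full input
    posterior_observed_idx=[2, 3, 4],  # positions of x in the full input
)
trainer.append_simulations(theta, x).train()

# build_posterior() derives the condition_mask from the indices given above.
posterior = trainer.build_posterior()
samples = posterior.sample((1000,), x=x_o)

# Any other conditional can be built by passing an explicit condition_mask,
# e.g. condition on theta_1 and x_1 and infer the rest.
conditional = trainer.build_conditional(
    condition_mask=torch.tensor([True, False, True, False, False])
)
```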

An edge_mask can also be specified at training time; if not, the default is again None. More generally, the user can pass a Callable that generates condition or edge masks, so they can choose whatever mask distribution they prefer; sets of tensors/lists or even a single fixed tensor can be passed as well. Masks are also generated just in time (JIT) for training: they are not provided at append_simulations(), but during train(), in order to save memory. Unlike at inference time, if no condition mask is specified here, a default generator is used that samples masks from a $\text{Bernoulli}(p=0.5)$.
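
For instance, a condition-mask generator matching the described default behavior could look like this sketch (the exact callable signature expected by train() is an assumption):

```python
import torch

def bernoulli_condition_mask(batch_size: int, num_variables: int) -> torch.Tensor:
    """Default-style generator: each variable is independently marked observed
    with probability 0.5, latent otherwise, freshly sampled per batch."""
    return torch.bernoulli(torch.full((batch_size, num_variables), 0.5)).bool()
```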

Note that the Simformer lets the user set any mask of their choice at both training and inference time; it is then the user's responsibility to provide coherent definitions (callables, sets, or fixed tensors) that make sense. For example, if the user passes a specific edge mask at training time, the Simformer will learn that specific DAG structure, and the user must then pass a coherent edge_mask when calling build_conditional(), build_posterior(), or build_likelihood().

Furthermore, the Simformer natively handles invalid inputs (NaNs and infs): if handle_invalid_x=True, it automatically spots invalid inputs at training time (again just in time) and flips their state in the condition mask to latent (to be inferred), besides replacing such values with small Gaussian noise for numerical stability.
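
Conceptually, the invalid-input handling boils down to something like this sketch (not the actual implementation; the noise scale and the "True = observed" mask convention are assumptions):

```python
import torch

def handle_invalid(full_inputs, condition_mask, noise_std=1e-2):
    """Flip non-finite entries to latent in the condition mask and replace
    their values with small Gaussian noise for numerical stability."""
    invalid = ~torch.isfinite(full_inputs)
    condition_mask = condition_mask & ~invalid
    full_inputs = torch.where(
        invalid, noise_std * torch.randn_like(full_inputs), full_inputs
    )
    return full_inputs, condition_mask
```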

A flow-matching equivalent of the Simformer is also provided (the description above assumed the score-based variant).

This PR also includes integration with the mini-sbibm benchmark suite, and a notebook tutorial for the Simformer (under docs/advanced_tutorials), where I showcase its use. I also tried to make the API Reference documentation as clear as possible.


Refactor of existing code

Parts of the existing code have been refactored, mainly to avoid code repetition and keep everything DRY. The most important modifications are:

  1. The SDE estimators: the definitions of mean_t, std_t, etc. were moved into standard Mixins. For example, instead of VEScoreEstimator(ConditionalScoreEstimator), there is now a VarianceExplodingSDE mixin defining mean_t, std_t, etc., and VEScoreEstimator becomes VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE). This makes it easy to also define MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE) without repeating the VE SDE pieces (see the sketch after this list).
  2. The NeuralInference interface, which has been split using a Mixin as well (BaseNeuralInference) defining the properties shared by NeuralInference and MaskedNeuralInference. This also required some minor adjustments, mainly to methods such as _resolve_prior() and _resolve_estimator(); most importantly, a new NoPrior object has been created as a temporary solution for "Keep prior optional and remove unnecessary copies of thetas from ImproperPrior" #1635.
  3. ConditionalVectorFieldEstimator and MaskedConditionalVectorFieldEstimator were simplified by moving shared code into a Mixin called BaseConditionalVectorFieldEstimator, mainly the mean_base and std_base properties and methods such as diffusion_fn().
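
As a sketch of the mixin pattern from item 1 (class bodies are simplified stand-ins; the schedule formulas are the standard VE SDE, and the sigma defaults are assumptions):

```python
import torch

class ConditionalScoreEstimator: ...        # stand-in for the sbi base class
class MaskedConditionalScoreEstimator: ...  # stand-in for the masked base class

class VarianceExplodingSDE:
    """Mixin holding the VE SDE schedule, shared by both estimator families."""
    sigma_min: float = 0.01  # default values are assumptions
    sigma_max: float = 10.0

    def mean_t(self, t: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        return x0  # the VE SDE leaves the mean untouched

    def std_t(self, t: torch.Tensor) -> torch.Tensor:
        # Geometric interpolation between sigma_min and sigma_max.
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

# Both variants reuse the schedule without duplicating it:
class VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE): ...
class MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE): ...
```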

Summary of modified files

The modified files are the following:

sbi/inference

  • sbi/inference/trainers/base.py: Added MaskedNeuralInference.
  • sbi/inference/trainers/vfpe/base_vf_inference.py: Added MaskedVectorFieldEstimatorBuilder and MaskedVectorFieldInference (subclass of MaskedNeuralInference).
  • sbi/inference/trainers/vfpe/simformer.py: New file introducing the Simformer inference class.

sbi/neural_nets

  • sbi/neural_nets/factory.py: Added support for building Simformer networks (simformer_nn).

  • sbi/neural_nets/estimators/base.py: Added MaskedConditionalEstimator and MaskedConditionalVectorFieldEstimator (subclass of MaskedConditionalEstimator).

  • sbi/neural_nets/estimators/score_estimator.py:

    • Added MaskedConditionalScoreEstimator (subclass of MaskedConditionalVectorFieldEstimator), placed directly above ConditionalScoreEstimator.
    • Added MaskedVEScoreEstimator (subclass of MaskedConditionalScoreEstimator).
  • sbi/neural_nets/net_builders/vector_field_nets.py:

    • build_vector_field_estimator updated to support simformer and masked-score.
    • Introduced MaskedSimformerBlock, MaskedDiTBlock, SimformerNet (subclass of MaskedVectorFieldNet), and build_simformer_network (defines default architecture parameters).

sbi/utils

  • sbi/utils/vector_field_utils.py: Added MaskedVectorFieldNet.

sbi/analysis

  • sbi/analysis/plots.py: Minor fix to ensure CPU conversion in ensure_numpy() (added .cpu() before .numpy()).

Unit Test

Introduced benchmarks (mini_sbibm) and tests for the Simformer and related masked objects in

  • tests/linearGaussian_vector_field_test.py
  • tests/posterior_nn_test.py
  • tests/vector_field_nets_test.py
  • tests/vf_estimator_test.py (which also includes shape tests on the Wrapper)
  • tests/bm_test.py

Regarding the linear Gaussian tests, I tried to fold the Simformer tests into existing test methods as much as possible; nonetheless, the iid tests and the SDE/ODE sampling-equivalence tests are still provided as separate dedicated tests and fixtures.

New files

  • docs/advanced_tutorials/22_simformer.ipynb
  • sbi/inference/trainers/vfpe/simformer.py: including both Score-based and Flow-matching Simformer interfaces

Thank you

Thank you sbi and Google for this opportunity. Implementing the Simformer has been very rewarding: not only did I learn something completely new in itself, but most importantly I learned how to go about it. Familiarizing myself with new concepts, writing code within code made by others, and following the guidance of mentors are the real value of this experience. Special thanks to my mentors Manuel (@manuelgloeckler) and Jan (@janfb) for accepting my proposal, and to @manuelgloeckler in particular for helping me throughout the whole journey!

@nMaax commented Aug 30, 2025

After rebasing this PR to merge into main, a bunch of collaborators' commits from the original parent branch appeared here. I tried to clean up a little by squashing commits, but that would have required resolving 285 different conflicts 😅 so I aborted the operation and had to keep everything as it is.

@nMaax commented Sep 1, 2025

Alright, as requested by Google:

If the pull request is going to have more work done after GSoC is over, make sure the last GSoC commit is noted.

I mark the commit below as the last commit of my GSoC. Nonetheless, I can still work on this further to address advice and fixes after review 👍
