
[WIP] Add support for graphical simulators #487


Draft: daniel-habermann wants to merge 1 commit into base: dev


daniel-habermann (Contributor) commented May 23, 2025

WIP pull request to add support for graphical simulators. I'm going to update this message and tag some people once it has reached a state where it makes sense to read further. All discussion and feedback welcome!

Summary and Motivation

This PR introduces initial support for graphical simulators. The main idea is to represent a complex simulation program as a directed acyclic graph (DAG), where nodes represent sets of parameters and edges denote conditional dependencies.
Such a structure is a natural representation for many Bayesian models because the joint distribution of parameters $p(\theta)$ can often be expressed in the form of some factorization along a DAG $p(\theta_1, \dots, \theta_N) = \prod_{i=1}^{N} p(\theta_i | \text{Parents}(\theta_i))$.

The benefit of making these dependency structures explicit is that the converse is also true: By stating the conditional dependencies, a corresponding DAG also encodes conditional independencies implied by the distribution, which we can then use to automatically build efficient network architectures, for example for multilevel models.

Current Implementation

Consider a standard two-level hierarchical model:

$$ \begin{aligned} \tau,\omega &\sim \text{Normal}^{+}(0, 1)\\ \lambda_j &\sim \text{Normal}(0, \tau)\\ x_{ij} &\sim \text{Normal}(\lambda_j, \omega)\\ \end{aligned} $$
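Written out along the DAG, the joint density of this model factorizes as shown below. This is only a worked instance of the general factorization above; the symbols $J$ (number of groups) and $N_j$ (number of observations in group $j$) are notation introduced here for illustration.

$$ p(\tau, \omega, \lambda_{1:J}, x) = p(\tau)\, p(\omega) \prod_{j=1}^{J} \left[ p(\lambda_j \mid \tau) \prod_{i=1}^{N_j} p(x_{ij} \mid \lambda_j, \omega) \right] $$

One conditional independence that can be read off this factorization is that $\tau$ depends on the data only through the group-level parameters:

$$ p(\tau \mid \lambda_{1:J}, \omega, x) = p(\tau \mid \lambda_{1:J}) $$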

Such a model can be represented by a diagram in which dashed boxes denote that the enclosed parameters are exchangeable. Currently, such a model would be implemented like this:

from bayesflow.experimental.graphical_simulator import GraphicalSimulator
import numpy as np

def sample_tau():
    tau = np.abs(np.random.normal())
    return dict(tau=tau)

def sample_omega():
    omega = np.abs(np.random.normal())
    return dict(omega=omega)

def sample_lambda_j(tau):
    lambda_j = np.random.normal(loc=0, scale=tau)
    return dict(lambda_j=lambda_j)

def sample_x_ij(lambda_j, omega):
    x_ij = np.random.normal(loc=lambda_j, scale=omega)
    return dict(x_ij=x_ij)

simulator = GraphicalSimulator()

simulator.add_node("tau", sampling_function=sample_tau, sample_size=lambda: 1)
simulator.add_node("omega", sampling_function=sample_omega, sample_size=lambda: 1)
simulator.add_node("lambda_j", sampling_function=sample_lambda_j, sample_size=lambda: np.random.randint(5, 10))
simulator.add_node("x_ij", sampling_function=sample_x_ij, sample_size=lambda: np.random.randint(1, 10))

simulator.add_edge("tau", "lambda_j")
simulator.add_edge("lambda_j", "x_ij")
simulator.add_edge("omega", "x_ij")
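To make the role of the nodes, edges and sample_size callables more concrete, here is a minimal toy sketch of ancestral sampling over such a graph. It is not the GraphicalSimulator implementation from this PR: the function ancestral_sample, the nodes/edges containers, and the rule "one call per combination of parent draws, repeated sample_size times" are all assumptions made for illustration. It uses the standard-library random module instead of NumPy so that it runs without dependencies.

```python
import itertools
import random
from graphlib import TopologicalSorter

rng = random.Random(0)  # stdlib RNG, stands in for np.random in this sketch


def ancestral_sample(nodes, edges):
    """nodes: name -> (sampling_fn, sample_size); edges: (parent, child) pairs."""
    parents = {name: [] for name in nodes}
    for parent, child in edges:
        parents[child].append(parent)

    draws = {}
    # static_order() yields each node only after all of its parents
    for name in TopologicalSorter(parents).static_order():
        sampling_fn, sample_size = nodes[name]
        reps = sample_size()  # assumption: drawn once per node and simulation
        combos = itertools.product(*(draws[p] for p in parents[name]))
        draws[name] = [
            # merge the parents' output dicts into keyword arguments
            sampling_fn(**{k: v for d in combo for k, v in d.items()})
            for combo in combos
            for _ in range(reps)
        ]
    return draws


def sample_tau():
    return dict(tau=abs(rng.gauss(0, 1)))


def sample_omega():
    return dict(omega=abs(rng.gauss(0, 1)))


def sample_lambda_j(tau):
    return dict(lambda_j=rng.gauss(0, tau))


def sample_x_ij(lambda_j, omega):
    return dict(x_ij=rng.gauss(lambda_j, omega))


nodes = {
    "tau": (sample_tau, lambda: 1),
    "omega": (sample_omega, lambda: 1),
    "lambda_j": (sample_lambda_j, lambda: rng.randint(5, 9)),
    "x_ij": (sample_x_ij, lambda: rng.randint(1, 9)),
}
edges = [("tau", "lambda_j"), ("lambda_j", "x_ij"), ("omega", "x_ij")]

draws = ancestral_sample(nodes, edges)
```

In this sketch every group gets the same number of observations within one simulation; drawing a separate size per group would be an equally plausible reading of sample_size.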

Design space

There is still a long list of design choices:

How to determine how often each node is executed for each batch?

For multilevel models, we want to vary the number of groups and observations within each group for each batch. Currently, this is achieved by the sample_size function argument, which expects a callable returning an integer.
One question is whether we even need such an argument, or whether we could remove it by relying on something like the current meta_fn.

If we go the meta_fn route, do we have a single meta_fn for each node or a global one?

How can we represent more exotic models or non-DAG structures, like state space models?

How do we handle which nodes return observed data?

This becomes important when talking about graph inversions. Currently, we can attach arbitrary metadata to each node and the graph inversion algorithm searches for an "observed" keyword, but from a user perspective this should probably be improved. We might not even need to care about this at all, because the adapter already defines summary_conditions or inference_conditions.

How is all of this represented internally?

The resulting data structure is non-rectangular because each batch might have a different number of calls for each node.
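As one possible way to handle the non-rectangular output, variable-length draws can be padded to a common length and accompanied by a validity mask. The sketch below is only an illustration of that idea, not what this PR does internally; the helper pad_ragged and its pad_value argument are hypothetical names.

```python
def pad_ragged(groups, pad_value=0.0):
    """Pad variable-length lists to a common length and return a validity mask."""
    max_len = max(len(g) for g in groups)
    padded = [list(g) + [pad_value] * (max_len - len(g)) for g in groups]
    # 1 marks a real draw, 0 marks padding
    mask = [[1] * len(g) + [0] * (max_len - len(g)) for g in groups]
    return padded, mask


# three simulations in which a node was called 2, 3 and 1 times
padded, mask = pad_ragged([[0.1, 0.2], [0.3, 0.4, 0.5], [0.6]])
```

The alternative of storing flat values plus per-simulation offsets avoids wasted memory but is harder to feed into batched network code.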

daniel-habermann added the feature (New feature or request) and draft (Draft Pull Request, Work in Progress) labels May 23, 2025
daniel-habermann marked this pull request as draft May 23, 2025 14:29

codecov bot commented May 23, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

LarsKue (Contributor) commented May 23, 2025

I just thought of an alternative interface and would be glad for some feedback:

import numpy as np

# we could make passing graph optional later
def sample_tau(graph):
    tau = np.abs(np.random.normal())
    return tau


def sample_omega(graph):
    omega = np.abs(np.random.normal())
    return omega


def sample_lambda(graph):
    ncol = graph.sample_int("ncol", 5, 10)
    tau = graph.sample_var("tau", shape=(ncol,))
    lamb = np.random.normal(loc=0, scale=tau)
    return lamb


def sample_x(graph):
    nrow = graph.sample_int("nrow", 1, 10)
    ncol = graph.sample_int("ncol", 5, 10)  # cached between sample_lambda and sample_x
    lamb = graph.sample_var("lambda", shape=(1, ncol))
    omega = graph.sample_var("omega", shape=(1, 1))

    x = np.random.normal(loc=lamb, scale=omega, size=(nrow, ncol))
    return x


graph = GraphSimulator()

# each node returns exactly one variable
graph.add_node("tau", sample_tau)
graph.add_node("omega", sample_omega)
graph.add_node("lambda", sample_lambda)
graph.add_node("x", sample_x)

# returns a list of dicts
samples = graph.sample(10)

I already have a working implementation for this, but it might not exactly fit the needs of the rest of the library, so I would like to discuss first.
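For readers trying to picture the "cached between sample_lambda and sample_x" comment, the following is a minimal sketch of how such caching could work. It is not the implementation mentioned above, which is not shown in this thread: the class name CachingGraph, the reset method, and the exclusive upper bound of sample_int are all assumptions.

```python
import random


class CachingGraph:
    """Hypothetical stand-in for the `graph` argument; only sample_int is sketched."""

    def __init__(self, seed=None):
        self._rng = random.Random(seed)
        self._ints = {}

    def sample_int(self, name, low, high):
        # draw once per name (upper bound exclusive), then reuse the value
        if name not in self._ints:
            self._ints[name] = self._rng.randrange(low, high)
        return self._ints[name]

    def reset(self):
        # forget cached sizes before the next simulation
        self._ints.clear()


graph = CachingGraph(seed=0)
ncol_a = graph.sample_int("ncol", 5, 10)
ncol_b = graph.sample_int("ncol", 5, 10)  # served from the cache
```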

paul-buerkner (Contributor) commented May 24, 2025

Thank you @daniel-habermann for your PR! I will review it (specifically the interface) in the next couple of days.

@LarsKue since Daniel has already spent quite a bit of time and thought on this PR, I would like to go with his implementation for now. If we end up not liking it for some reason, we can still discuss alternatives then.

daniel-habermann (Contributor, Author) commented May 26, 2025

I'm not against radically changing the suggested interface, but one design consideration is consistency: In the current interface, a user can return dictionaries with arbitrary keys, so it would be quite difficult to explain why this is possible when using bf.simulators.make_simulator, but not when building a graphical simulator.

The same is true for sample_var. The current interface passes inputs as function parameters, so having to request parameters from within a function would break consistency with it. The approach outlined in the first post of this PR still has some inconsistencies of its own; for example, sampling_function should at least be sampling_fn to stay consistent with meta_fn in the current interface.

What problem were you trying to resolve with your suggestion? If the concern is boilerplate and having to add edges manually, I expect we can resolve almost all cases with code introspection: all a user has to provide are function definitions, as in the usual make_simulator call. The graph can then be inferred from these functions, because each function is a node and the edges can be retrieved by matching the dictionary outputs to input arguments.
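A rough sketch of what I mean by introspection, under the simplifying assumption that the output key of each function is known up front (here it is given as the dictionary key; in practice it would have to be declared or discovered by running the functions). The helper infer_edges is a hypothetical name and not part of the current interface.

```python
import inspect


def infer_edges(sampling_fns):
    """sampling_fns maps each output key to the function that produces it."""
    edges = []
    for child, fn in sampling_fns.items():
        # every argument that matches another node's output key is a parent
        for arg in inspect.signature(fn).parameters:
            if arg in sampling_fns:
                edges.append((arg, child))
    return edges


def sample_tau():
    return dict(tau=1.0)  # placeholder body, only the signature matters here


def sample_omega():
    return dict(omega=1.0)


def sample_lambda_j(tau):
    return dict(lambda_j=0.0)


def sample_x_ij(lambda_j, omega):
    return dict(x_ij=0.0)


edges = infer_edges({
    "tau": sample_tau,
    "omega": sample_omega,
    "lambda_j": sample_lambda_j,
    "x_ij": sample_x_ij,
})
```

For the example model this recovers the same three edges that are currently added by hand with add_edge.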
