-
Notifications
You must be signed in to change notification settings - Fork 69
[WIP] Add support for graphical simulators #487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
[WIP] Add support for graphical simulators #487
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅ |
I just thought about an alternative interface, would be glad for some feedback: # we could make passing graph optional later
def sample_tau(graph):
tau = np.abs(np.random.normal())
return tau
def sample_omega(graph):
omega = np.abs(np.random.normal())
return omega
def sample_lambda(graph):
ncol = graph.sample_int("ncol", 5, 10)
tau = graph.sample_var("tau", shape=(ncol,))
lamb = np.random.normal(loc=0, scale=tau)
return lamb
def sample_x(graph):
nrow = graph.sample_int("nrow", 1, 10)
ncol = graph.sample_int("ncol", 5, 10) # cached between sample_lambda and sample_x
lamb = graph.sample_var("lambda", shape=(1, ncol))
omega = graph.sample_var("omega", shape=(1, 1))
x = np.random.normal(loc=lamb, scale=omega, size=(nrow, ncol))
return x
graph = GraphSimulator()
# each node returns exactly one variable
graph.add_node("tau", sample_tau)
graph.add_node("omega", sample_omega)
graph.add_node("lambda", sample_lambda)
graph.add_node("x", sample_x)
# returns a list of dicts
samples = graph.sample(10) I already have a working implementation for this, but it might not exactly fit the needs of the rest of the library, so I would like to discuss first. |
Thank you @daniel-habermann for your PR! I will review it (specifically the interface) in the next couple of days. @LarsKue since Daniel already spend quite a bit of time and thought for this PR, I would like to go with his implementation for now. If we end up not liking it for some reason, we can still discuss alternatives then. |
I'm not against radically changing the suggested interface, but one design consideration is consistency: In the current interface, a user can return dictionaries with arbitrary keys, so it would be quite difficult to explain why this is possible when using The same is true for What problem were you trying to resolve with your suggestion? If the concern is boilerplate and adding edges to the networks, I expect we can resolve almost all cases with code introspection, i.e. all a user has to provide are function definitions as in the usual |
WIP pull request to add support for graphical simulators. I'm going to update this message and tag some people once it has reached a state where it makes sense to read further. All discussion and feedback welcome!
Summary and Motivation
This PR introduces initial support for graphical simulators. The main idea is to represent a complex simulation program as a directed acyclic graph (DAG), where nodes represent sets of parameters and edges denote conditional dependencies.$p(\theta)$ can often be expressed in the form of some factorization along a DAG $p(\theta_1, \dots, \theta_N) = \prod_{i=1}^{N} p(\theta_i | \text{Parents}(\theta_i))$ .
Such a structure is a natural representation for many Bayesian models because the joint distribution of parameters
The benefit of making these dependency structures explicit is that the converse is also true: By stating the conditional dependencies, a corresponding DAG also encodes conditional independencies implied by the distribution, which we can then use to automatically build efficient network architectures, for example for multilevel models.
Current Implementation
Consider a standard two-level hierarchical model:
Such a model can be represented by the following diagram:
where the dashed boxes denote that parameters are exchangeable. Currently, such a diagram would be implemented like this:
Design space
There is still a long list of design choices:
How to determine how often each node is executed for each batch?
For multilevel models, we want to vary the number of groups and observations within each group for each batch. Currently, this is achieved by the
sample_size
function argument, which expects a callable returning an integer.One question is if we even need such an argument, or could remove it by relying on something like the current
meta_fn
.If we go the
meta_fn
route, do we have a singlemeta_fn
for each node or a global oneHow can we represent more exotic models or non-DAG structures, like state space models
How do we handle which nodes return observed data?
This becomes important when talking about graph inversions. Currently, we can attach arbitrary metadata to each node and the graph inversion algorithm searches for an "observed" keyword, but from a user perspective this should probably be improved. We might even not care about this at all because the adapter defines
summary_conditions
orinference_conditions
.How is all of this represented internally?
The resulting data structure is non-rectangular because each batch might have a different number of calls for each node.