Training interface #766

@michaeldeistler

In order to give more flexibility, we could expose the entire training loop to the user:

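from torch.optim import Adam

from sbi.inference import nre_loss, ratio_estimator_based_potential, MCMCPosterior
from sbi.utils import classifier_nn

# Proposed interface: nre_loss is the standalone loss suggested here, and
# classifier_nn is assumed to return a ready-to-train network.
# prior, x_o, and epochs are defined elsewhere.
net = classifier_nn("resnet")
optim = Adam(net.parameters())

data_loader = ...
for e in range(epochs):
    for theta, x in data_loader:
        optim.zero_grad()
        loss = nre_loss(net, theta, x)
        loss.backward()
        optim.step()

potential = ratio_estimator_based_potential(net, prior, x_o)
posterior = MCMCPosterior(potential, proposal=prior)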

This would require two things:

  • a good way to handle z-scoring of theta and x when the user owns the training loop
  • separating the loss functions from the rest of the training code
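
To make the two requirements concrete, here is a minimal sketch in plain PyTorch. It is not sbi's API: `ZScore`, `RatioClassifier`, and this `nre_loss` are hypothetical stand-ins, and the toy simulator is made up for illustration. It assumes z-scoring can live inside the network as fixed buffers, and that the loss is a binary cross-entropy between joint and shuffled (theta, x) pairs, in the style of NRE-A.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ZScore(nn.Module):
    """Standardize inputs with a fixed mean and std estimated from training data."""

    def __init__(self, data: torch.Tensor, eps: float = 1e-8):
        super().__init__()
        # Buffers are stored with the module but are not updated by the optimizer.
        self.register_buffer("mean", data.mean(dim=0))
        self.register_buffer("std", data.std(dim=0) + eps)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return (inputs - self.mean) / self.std


class RatioClassifier(nn.Module):
    """Hypothetical classifier: z-scores theta and x, then returns one logit per pair."""

    def __init__(self, theta: torch.Tensor, x: torch.Tensor, hidden: int = 32):
        super().__init__()
        self.z_theta = ZScore(theta)
        self.z_x = ZScore(x)
        self.mlp = nn.Sequential(
            nn.Linear(theta.shape[1] + x.shape[1], hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        inputs = torch.cat([self.z_theta(theta), self.z_x(x)], dim=1)
        return self.mlp(inputs).squeeze(-1)


def nre_loss(net: nn.Module, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Stand-in loss: classify joint pairs (label 1) against shuffled pairs (label 0)."""
    # A random permutation can leave a few pairs unshuffled; acceptable for a sketch.
    perm = torch.randperm(theta.shape[0])
    logits_joint = net(theta, x)
    logits_marginal = net(theta[perm], x)
    logits = torch.cat([logits_joint, logits_marginal])
    labels = torch.cat(
        [torch.ones_like(logits_joint), torch.zeros_like(logits_marginal)]
    )
    return nn.functional.binary_cross_entropy_with_logits(logits, labels)


# Toy data (assumption): x is theta plus Gaussian noise.
torch.manual_seed(0)
theta = torch.randn(256, 2)
x = theta + 0.1 * torch.randn(256, 2)

net = RatioClassifier(theta, x)
optim = torch.optim.Adam(net.parameters(), lr=1e-2)
loader = DataLoader(TensorDataset(theta, x), batch_size=64, shuffle=True)

# The user-owned training loop from the proposal.
for epoch in range(20):
    for theta_batch, x_batch in loader:
        optim.zero_grad()
        loss = nre_loss(net, theta_batch, x_batch)
        loss.backward()
        optim.step()
```

Keeping the z-scoring statistics as buffers means the loss function only ever sees raw `theta` and `x`, so the loss can be separated from the network without passing normalization constants around. Whether sbi should take this route or keep z-scoring in the network builders is an open design question.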

Labels

architecture (internal changes without API consequences), enhancement (new feature or request)
