Closed
Labels: architecture (Internal changes without API consequences), enhancement (New feature or request)
Description
To give users more flexibility, we could expose the entire training loop:
import torch
from sbi.inference import nre_loss, ratio_estimator_based_potential, MCMCPosterior
from sbi.utils import classifier_nn

net = classifier_nn("resnet")  # classifier-based ratio estimator
data_loader = ...
optim = torch.optim.Adam(net.parameters())  # any torch optimizer

for epoch in range(epochs):
    for theta, x in data_loader:
        optim.zero_grad()
        loss = nre_loss(net, theta, x)  # proposed standalone loss function
        loss.backward()
        optim.step()

potential = ratio_estimator_based_potential(net, prior, x_o)
posterior = MCMCPosterior(potential, proposal=prior)
This would require two things:
- a good way to deal with z-scoring
- separating the loss functions from the rest of the inference code (see the sketch below)
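For the second point, here is a minimal sketch of what a standalone `nre_loss` could look like, with z-scoring handled by a small module that gets prepended to the network rather than living in the training loop. The names and signatures (`Standardize`, a classifier that takes the concatenation of theta and x) are illustrative assumptions, not the current sbi API:

import torch
from torch import nn


class Standardize(nn.Module):
    # Z-scores inputs with mean/std estimated once from the training set.
    # Prepending this to the classifier (or its embedding nets) keeps
    # z-scoring inside the network instead of inside the training loop.
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, inputs):
        return (inputs - self.mean) / self.std


def nre_loss(net, theta, x):
    # Standalone NRE loss: binary cross-entropy for classifying joint
    # pairs (theta, x) against pairs with shuffled theta (marginals).
    # Assumes `net` maps the concatenation [theta, x] to a single logit.
    perm = torch.randperm(theta.shape[0])
    logits_joint = net(torch.cat([theta, x], dim=-1))
    logits_marginal = net(torch.cat([theta[perm], x], dim=-1))
    bce = nn.functional.binary_cross_entropy_with_logits
    return bce(logits_joint, torch.ones_like(logits_joint)) + bce(
        logits_marginal, torch.zeros_like(logits_marginal)
    )

With something along these lines, the loop above would only need to fit the standardizing statistics once on the training data and wrap the classifier, while the loss function stays agnostic to how the data were scaled.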