Description
It would be nice to be able to compute the gradient of the log probability (in the new BayesFlow version).
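For example, with the JAX backend, I would like to be able to do something along these lines:

```python
from jax import grad

def partial_objective(theta):
    # `approximator` and the observed data `x` are defined elsewhere; x is held fixed.
    log_prob = approximator.log_prob(
        data={"theta": theta, "x": x},
        batch_size=1,
    )
    # grad requires a scalar output, so return the single batch element.
    return log_prob[0]

grad(partial_objective)(theta)
```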
The main reason this does not work at the moment is the adapter:

```
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1].
```
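A minimal, self-contained sketch of what this error means, independent of BayesFlow: under `jax.grad`, inputs are JAX tracers, and any step that converts them to a NumPy array (as a NumPy-based preprocessing step would) fails, while the same computation written with `jax.numpy` stays differentiable. The two toy functions below are hypothetical stand-ins, not the adapter's actual code.

```python
import numpy as np
import jax
import jax.numpy as jnp


def log_density_numpy(theta):
    # np.asarray calls __array__() on the tracer, which JAX forbids under grad.
    t = np.asarray(theta)
    return -0.5 * np.sum(t**2)


def log_density_jnp(theta):
    # Staying within jax.numpy keeps the computation traceable.
    return -0.5 * jnp.sum(theta**2)


theta = jnp.array([1.0, -2.0])

try:
    jax.grad(log_density_numpy)(theta)
except jax.errors.TracerArrayConversionError as err:
    print("NumPy path fails with", type(err).__name__)

# The gradient of -0.5 * sum(theta**2) is -theta.
print(jax.grad(log_density_jnp)(theta))
```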
Any idea how this could be done?
Edit: Also, merely evaluating `log_prob` currently works only if one does not use `standardize` in the adapter. I think this is related to #233. However, differentiating `approximator._log_prob(...)` with `grad`, which bypasses the adapter, works.
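One possible direction, sketched under loud assumptions: if the standardization of `theta` is an affine map with fixed statistics, it can be written in `jax.numpy` and composed with the network's log density, so that `grad` flows through both. Here `MU`, `SIGMA`, `base_log_prob`, and `log_prob_theta` are hypothetical; `base_log_prob` is a standard-normal stand-in for `approximator._log_prob`, and whether BayesFlow's adapter applies the log-Jacobian term below is not confirmed here. The correction is needed only for standardized inference variables like `theta`; standardizing the conditions `x` would need just the transform.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed standardization statistics (e.g. estimated once from training data).
MU = jnp.array([0.5, -1.0])
SIGMA = jnp.array([2.0, 0.5])


def base_log_prob(z):
    # Stand-in for the network's log density in standardized space:
    # an isotropic standard normal.
    return jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2.0 * jnp.pi))


def log_prob_theta(theta):
    # Standardize with jax.numpy so the transform stays differentiable.
    z = (theta - MU) / SIGMA
    # Change of variables: log p(theta) = log p_z(z) - sum(log SIGMA).
    return base_log_prob(z) - jnp.sum(jnp.log(SIGMA))


theta = jnp.array([1.0, 0.0])
grad_theta = jax.grad(log_prob_theta)(theta)
# For the standard-normal stand-in, the gradient equals -(theta - MU) / SIGMA**2.
print(grad_theta)
```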