Gradient of the log probability #244

Open
@arrjon

Description

It would be nice to be able to compute the gradient of the log probability (in the new BayesFlow version).

For example, with the JAX backend, I would like to be able to do something along these lines:

from jax import grad

def partial_objective(theta):
    # Evaluate the log probability of theta given the observation x.
    log_prob = approximator.log_prob(
        data={'theta': theta, 'x': x},
        batch_size=1,
    )
    return log_prob[0]

# Differentiate the log probability with respect to theta.
grad(partial_objective)(theta)

The main reason this does not work at the moment is the adapter:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[1].
Any idea how this could be done?
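
For context, this is the generic JAX failure mode when something (here, the adapter) converts a traced array to a numpy array. A minimal, BayesFlow-independent reproduction:

import numpy as np
import jax.numpy as jnp
from jax import grad

def squared_norm(x):
    # The adapter internally converts inputs to numpy arrays; doing the
    # same on a traced array reproduces the error above.
    return jnp.sum(np.asarray(x) ** 2)

grad(squared_norm)(jnp.ones(1))  # raises TracerArrayConversionError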

Edit: at the moment, even just evaluating log_prob only works if one does not use standardize in the adapter. I think this is related to #233. However, calling grad on approximator._log_prob(...), circumventing the adapter, works, as sketched below.
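
For illustration, a minimal sketch of that workaround. Note this assumes the internal, underscore-prefixed _log_prob accepts the raw arrays as keyword arguments and that any adapter transformations have already been applied; the exact signature is an assumption, not the documented API:

from jax import grad

def partial_objective(theta):
    # Hypothetical call: pass the traced theta directly to the internal
    # _log_prob, bypassing the adapter and its numpy conversions.
    log_prob = approximator._log_prob(theta=theta, x=x)
    return log_prob[0]

# Works because no numpy conversion is triggered on the traced array.
grad(partial_objective)(theta)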

Labels: feature (New feature or request)
