You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
chore: replace legacy jax.random.PRNGKey with modern jax.random.key (#2134)
* chore: replace legacy `jax.random.PRNGKey` with modern `jax.random.key`
* fix: files reformatted with `ruff==0.15.0`
* test: update PRNG key test
* test: update test effected by the `optax==0.2.7` release
* Revert "test: update test effected by the `optax==0.2.7` release"
This reverts commit 14d3ee8.
* Revert "fix: files reformatted with `ruff==0.15.0`"
This reverts commit 2482781.
* test: pickle `jax.random.key`
* test: use legacy keys in `test/test_pickle.py` until
jax-ml/jax#35065 is resolved
1. Unlike in Pyro, `numpyro.sample('x', dist.Normal(0, 1))` does not work. Why?
288
288
289
-
You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNGKey](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.
289
+
You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNG Key](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.key)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.
290
290
291
291
Your options are:
292
292
293
-
- Call the distribution directly and provide a `PRNGKey`, e.g. `dist.Normal(0, 1).sample(PRNGKey(0))`
294
-
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))`.
293
+
- Call the distribution directly and provide a PRNG key, e.g. `dist.Normal(0, 1).sample(key(0))`
294
+
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=key(0))`.
295
295
- Wrap the code in a `seed` handler, used either as a context manager or as a function that wraps over the original callable. e.g.
296
296
297
297
```python
298
-
with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used
299
-
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0)
300
-
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one
298
+
with handlers.seed(rng_seed=0): # random.key(0) is used
299
+
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNG key split from random.key(0)
300
+
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNG key split from the last one
0 commit comments