Skip to content

Commit d5598e7

Browse files
authored
chore: replace legacy jax.random.PRNGKey with modern jax.random.key (#2134)
* chore: replace legacy `jax.random.PRNGKey` with modern `jax.random.key` * fix: files reformatted with `ruff==0.15.0` * test: update PRNG key test * test: update test affected by the `optax==0.2.7` release * Revert "test: update test affected by the `optax==0.2.7` release" This reverts commit 14d3ee8. * Revert "fix: files reformatted with `ruff==0.15.0`" This reverts commit 2482781. * test: pickle `jax.random.key` * test: use legacy keys in `test/test_pickle.py` until jax-ml/jax#35065 is resolved
1 parent 115eabd commit d5598e7

File tree

119 files changed

+20764
-20792
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+20764
-20792
lines changed

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Let us infer the values of the unknown parameters in our model by running MCMC u
6262

6363
>>> nuts_kernel = NUTS(eight_schools)
6464
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
65-
>>> rng_key = random.PRNGKey(0)
65+
>>> rng_key = random.key(0)
6666
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
6767

6868
```
@@ -111,7 +111,7 @@ The values above 1 for the split Gelman Rubin diagnostic (`r_hat`) indicate tha
111111

112112
>>> nuts_kernel = NUTS(eight_schools_noncentered)
113113
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
114-
>>> rng_key = random.PRNGKey(0)
114+
>>> rng_key = random.key(0)
115115
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
116116
>>> mcmc.print_summary(exclude_deterministic=False) # doctest: +SKIP
117117

@@ -161,7 +161,7 @@ Now, let us assume that we have a new school for which we have not observed any
161161
... return numpyro.sample('obs', dist.Normal(mu, tau))
162162

163163
>>> predictive = Predictive(new_school, mcmc.get_samples())
164-
>>> samples_predictive = predictive(random.PRNGKey(1))
164+
>>> samples_predictive = predictive(random.key(1))
165165
>>> print(np.mean(samples_predictive['obs'])) # doctest: +SKIP
166166
3.9886456
167167

@@ -286,18 +286,18 @@ conda install -c conda-forge numpyro
286286

287287
1. Unlike in Pyro, `numpyro.sample('x', dist.Normal(0, 1))` does not work. Why?
288288

289-
You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNGKey](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.PRNGKey)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.
289+
You are most likely using a `numpyro.sample` statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key ([PRNG Key](https://jax.readthedocs.io/en/latest/jax.random.html#jax.random.key)) to generate samples from. NumPyro's inference algorithms use the [seed](https://num.pyro.ai/en/latest/handlers.html#seed) handler to thread in a random number generator key, behind the scenes.
290290

291291
Your options are:
292292

293-
- Call the distribution directly and provide a `PRNGKey`, e.g. `dist.Normal(0, 1).sample(PRNGKey(0))`
294-
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))`.
293+
- Call the distribution directly and provide a PRNG key, e.g. `dist.Normal(0, 1).sample(key(0))`
294+
- Provide the `rng_key` argument to `numpyro.sample`. e.g. `numpyro.sample('x', dist.Normal(0, 1), rng_key=key(0))`.
295295
- Wrap the code in a `seed` handler, used either as a context manager or as a function that wraps over the original callable. e.g.
296296

297297
```python
298-
with handlers.seed(rng_seed=0): # random.PRNGKey(0) is used
299-
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNGKey split from random.PRNGKey(0)
300-
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNGKey split from the last one
298+
with handlers.seed(rng_seed=0): # random.key(0) is used
299+
x = numpyro.sample('x', dist.Beta(1, 1)) # uses a PRNG key split from random.key(0)
300+
y = numpyro.sample('y', dist.Bernoulli(x)) # uses different PRNG key split from the last one
301301
```
302302

303303
, or as a higher order function:

docs/source/contrib.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ Other kernels can be used similarly.
183183
... return x, y_obs
184184
185185
186-
>>> rng_key = random.PRNGKey(seed=42)
186+
>>> rng_key = random.key(seed=42)
187187
>>> rng_key, rng_subkey = random.split(rng_key)
188188
>>> x, y_obs = generate_synthetic_data(
189189
... rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3

examples/annotation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,12 @@ def main(args):
284284
num_chains=args.num_chains,
285285
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
286286
)
287-
mcmc.run(random.PRNGKey(0), *data)
287+
mcmc.run(random.key(0), *data)
288288
mcmc.print_summary()
289289

290290
posterior_samples = mcmc.get_samples()
291291
predictive = Predictive(model, posterior_samples, infer_discrete=True)
292-
discrete_samples = predictive(random.PRNGKey(1), *data)
292+
discrete_samples = predictive(random.key(1), *data)
293293

294294
item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
295295
discrete_samples["c"].squeeze(-1)

examples/ar2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def run_inference(model, args, rng_key, y):
105105
def main(args):
106106
# generate artificial dataset
107107
num_data = args.num_data
108-
rng_key = jax.random.PRNGKey(0)
108+
rng_key = jax.random.key(0)
109109
t = jnp.arange(0, num_data)
110110
y = jnp.sin(t) + random.normal(rng_key, (num_data,)) * 0.1
111111

examples/baseball.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def main(args):
195195
for i, model in enumerate(
196196
(fully_pooled, not_pooled, partially_pooled, partially_pooled_with_logit)
197197
):
198-
rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
198+
rng_key, rng_key_predict = random.split(random.key(i + 1))
199199
zs = run_inference(model, at_bats, hits, rng_key, args)
200200
predict(model, at_bats, hits, zs, rng_key_predict, player_names)
201201
predict(

examples/bnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def main(args):
126126
X, Y, X_test = get_data(N=N, D_X=D_X)
127127

128128
# do inference
129-
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
129+
rng_key, rng_key_predict = random.split(random.key(0))
130130
samples = run_inference(model, args, rng_key, X, Y, D_H)
131131

132132
# predict Y_test at inputs X_test

examples/capture_recapture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def main(args):
343343
)
344344

345345
model = models[args.model]
346-
rng_key = random.PRNGKey(args.rng_seed)
346+
rng_key = random.key(args.rng_seed)
347347
run_inference(model, capture_history, sex, rng_key, args)
348348

349349

examples/covtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def model(data, labels, subsample_size=None):
7272

7373

7474
def benchmark_hmc(args, features, labels):
75-
rng_key = random.PRNGKey(1)
75+
rng_key = random.key(1)
7676
start = time.time()
7777
# a MAP estimate at the following source
7878
# https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
@@ -174,7 +174,7 @@ def benchmark_hmc(args, features, labels):
174174
subsample_size = 1000
175175
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
176176
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
177-
svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
177+
svi_result = svi.run(random.key(2), 2000, features, labels)
178178
params, losses = svi_result.params, svi_result.losses
179179
plt.plot(losses)
180180
plt.show()

examples/cvae-flax/train_baseline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def create_train_state(model, x, learning_rate_fn):
13-
params = model.init(random.PRNGKey(0), x)
13+
params = model.init(random.key(0), x)
1414
tx = optax.adam(learning_rate_fn)
1515
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
1616
return state
@@ -65,7 +65,7 @@ def train_baseline(
6565
):
6666
state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003)
6767

68-
rng = random.PRNGKey(0)
68+
rng = random.key(0)
6969
best_val_loss = jnp.inf
7070
best_state = state
7171
for i in range(n_epochs):

examples/cvae-flax/train_cvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ def train_cvae(
7878
n_epochs=100,
7979
):
8080
svi, state = create_train_state(
81-
random.PRNGKey(23), model, guide, train_fetch, baseline_params, 0.003
81+
random.key(23), model, guide, train_fetch, baseline_params, 0.003
8282
)
8383

8484
p1 = baseline_params.unfreeze()["params"]["Dense_0"]["kernel"]
8585
p2 = state.optim_state[1][0]["baseline$params"]["Dense_0"]["kernel"]
8686
assert jnp.all(p1 == p2)
8787

88-
rng = random.PRNGKey(0)
88+
rng = random.key(0)
8989
best_val_loss = jnp.inf
9090
best_state = state
9191
for i in range(n_epochs):

0 commit comments

Comments
 (0)