Skip to content

Commit 039b277

Browse files
reubenharryjunpenglaoJakobRobnik
authored
Add MCLMC sampler (#586)
* initial draft of mclmc * refactor * wip * wip * wip * wip * wip * fix pre-commit * remove dim from class * add docstrings * add mclmc to init * move minimal_norm to integrators * move update pos and momentum * remove params * Infer the shape from inverse_mass_matrix outside the function step * use tree_map * integration now aligned with mclmc repo * dE and logdensity align too (fixed sign error) * make L and step size arguments to kernel * rough draft of tuning: works * remove inv mass matrix * almost correct * almost correct * move tuning to adaptation * tuning works in this commit * clean up 1 * remove sigma from tuning * wip * fix linting * rename T and V * uniformity wip * make uniform implementation of integrators * make uniform implementation of integrators * fix minimal norm integrator * add warning to tune3 * Refactor integrators.py to make it more general. Also add momentum update based on Esh dynamics Co-authored-by: Reuben Cohn-Gordon <reubenharry@gmail.com> * temp: explore * Refactor to use integrator generation functions * Additional refactoring Also add test for esh momentum update. Co-authored-by: Reuben Cohn-Gordon <reubenharry@gmail.com> * Minor clean up. * Use standard JAX ops * new integrator * add references * flake * temporarily add 'explore' * temporarily add 'explore' * Adding a test for energy preservation. Co-authored-by: Reuben Cohn-Gordon <reubenharry@gmail.com> * fix formatting * wip: tests * use pytrees for partially_refresh_momentum, and add test * update docstring * remove 'explore' * fix pre-commit * adding randomized MCHMC * wip checkpoint on tuning * align blackjax and mclmc repos, for tuning * use effective_sample_size * patial rename * rename * clean up tuning * clean up tuning * RANDOMIZE KEYS * ADD TEST * ADD TEST * MERGE MAIN * INCREASE CODE COVERAGE * REMOVE REDUNDANT LINE * ADD NAME 'mclmc' * SPLIT KEYS AND FIX DOCSTRING * FIX MINOR ERRORS * FIX MINOR ERRORS * RANDOMIZE KEYS (reversion) * PRECOMMIT CLEAN UP * ADD KWARGS FOR DEFAULT HYPERPARAMS * UPDATE ESS * NAME CHANGES * NAME CHANGES * MINOR FIXES --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com> Co-authored-by: jakob.robnik <jakob.robnik@gmail.com>
1 parent f49945d commit 039b277

File tree

8 files changed

+567
-1
lines changed

8 files changed

+567
-1
lines changed

blackjax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from blackjax._version import __version__
22

33
from .adaptation.chees_adaptation import chees_adaptation
4+
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
45
from .adaptation.meads_adaptation import meads_adaptation
56
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
67
from .adaptation.window_adaptation import window_adaptation
@@ -12,6 +13,7 @@
1213
from .mcmc.hmc import dynamic_hmc, hmc
1314
from .mcmc.mala import mala
1415
from .mcmc.marginal_latent_gaussian import mgrad_gaussian
16+
from .mcmc.mclmc import mclmc
1517
from .mcmc.nuts import nuts
1618
from .mcmc.periodic_orbital import orbital_hmc
1719
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
@@ -40,6 +42,7 @@
4042
"additive_step_random_walk",
4143
"rmh",
4244
"irmh",
45+
"mclmc",
4346
"elliptical_slice",
4447
"ghmc",
4548
"barker_proposal",
@@ -51,6 +54,7 @@
5154
"meads_adaptation",
5255
"chees_adaptation",
5356
"pathfinder_adaptation",
57+
"mclmc_find_L_and_step_size", # mclmc adaptation
5458
"adaptive_tempered_smc", # smc
5559
"tempered_smc",
5660
"meanfield_vi", # variational inference

blackjax/adaptation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from . import (
22
chees_adaptation,
3+
mclmc_adaptation,
34
meads_adaptation,
45
pathfinder_adaptation,
56
window_adaptation,
@@ -10,4 +11,5 @@
1011
"meads_adaptation",
1112
"window_adaptation",
1213
"pathfinder_adaptation",
14+
"mclmc_adaptation",
1315
]
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright 2020- The Blackjax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L.
15+
16+
"""
17+
18+
from typing import NamedTuple
19+
20+
import jax
21+
import jax.numpy as jnp
22+
from jax.flatten_util import ravel_pytree
23+
24+
from blackjax.diagnostics import effective_sample_size # type: ignore
25+
from blackjax.util import pytree_size
26+
27+
28+
class MCLMCAdaptationState(NamedTuple):
29+
"""Represents the tunable parameters for MCLMC adaptation.
30+
31+
Attributes:
32+
L (float): The momentum decoherent rate for the MCLMC algorithm.
33+
step_size (float): The step size used for the MCLMC algorithm.
34+
"""
35+
36+
L: float
37+
step_size: float
38+
39+
40+
def mclmc_find_L_and_step_size(
41+
mclmc_kernel,
42+
num_steps,
43+
state,
44+
rng_key,
45+
frac_tune1=0.1,
46+
frac_tune2=0.1,
47+
frac_tune3=0.1,
48+
desired_energy_var=5e-4,
49+
trust_in_estimate=1.5,
50+
num_effective_samples=150,
51+
):
52+
"""
53+
Finds the optimal value of the parameters for the MCLMC algorithm.
54+
55+
Args:
56+
mclmc_kernel (callable): The kernel function used for the MCMC algorithm.
57+
num_steps (int): The number of MCMC steps that will subsequently be run, after tuning.
58+
state (MCMCState): The initial state of the MCMC algorithm.
59+
rng_key (jax.random.PRNGKey): The random number generator key.
60+
frac_tune1 (float): The fraction of tuning for the first step of the adaptation.
61+
frac_tune2 (float): The fraction of tuning for the second step of the adaptation.
62+
frac_tune3 (float): The fraction of tuning for the third step of the adaptation.
63+
desired_energy_var (float): The desired energy variance for the MCMC algorithm.
64+
trust_in_estimate (float): The trust in the estimate of optimal stepsize.
65+
num_effective_samples (int): The number of effective samples for the MCMC algorithm.
66+
67+
Returns:
68+
tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
69+
70+
Raises:
71+
None
72+
73+
Examples:
74+
# Define the kernel function
75+
def kernel(x):
76+
return x ** 2
77+
78+
# Define the initial state
79+
initial_state = MCMCState(position=0, momentum=1)
80+
81+
# Generate a random number generator key
82+
rng_key = jax.random.PRNGKey(0)
83+
84+
# Find the optimal parameters for the MCLMC algorithm
85+
final_state, final_params = mclmc_find_L_and_step_size(
86+
mclmc_kernel=kernel,
87+
num_steps=1000,
88+
state=initial_state,
89+
rng_key=rng_key,
90+
frac_tune1=0.2,
91+
frac_tune2=0.3,
92+
frac_tune3=0.1,
93+
desired_energy_var=1e-4,
94+
trust_in_estimate=2.0,
95+
num_effective_samples=200,
96+
)
97+
"""
98+
dim = pytree_size(state.position)
99+
params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
100+
part1_key, part2_key = jax.random.split(rng_key, 2)
101+
102+
state, params = make_L_step_size_adaptation(
103+
kernel=mclmc_kernel,
104+
dim=dim,
105+
frac_tune1=frac_tune1,
106+
frac_tune2=frac_tune2,
107+
desired_energy_var=desired_energy_var,
108+
trust_in_estimate=trust_in_estimate,
109+
num_effective_samples=num_effective_samples,
110+
)(state, params, num_steps, part1_key)
111+
112+
if frac_tune3 != 0:
113+
state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
114+
state, params, num_steps, part2_key
115+
)
116+
117+
return state, params
118+
119+
120+
def make_L_step_size_adaptation(
121+
kernel,
122+
dim,
123+
frac_tune1,
124+
frac_tune2,
125+
desired_energy_var=1e-3,
126+
trust_in_estimate=1.5,
127+
num_effective_samples=150,
128+
):
129+
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
130+
131+
decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)
132+
133+
def predictor(previous_state, params, adaptive_state, rng_key):
134+
"""does one step with the dynamics and updates the prediction for the optimal stepsize
135+
Designed for the unadjusted MCHMC"""
136+
137+
time, x_average, step_size_max = adaptive_state
138+
139+
# dynamics
140+
next_state, info = kernel(
141+
rng_key=rng_key,
142+
state=previous_state,
143+
L=params.L,
144+
step_size=params.step_size,
145+
)
146+
# step updating
147+
success, state, step_size_max, energy_change = handle_nans(
148+
previous_state,
149+
next_state,
150+
params.step_size,
151+
step_size_max,
152+
info.energy_change,
153+
)
154+
155+
# Warning: var = 0 if there were nans, but we will give it a very small weight
156+
xi = (
157+
jnp.square(energy_change) / (dim * desired_energy_var)
158+
) + 1e-8 # 1e-8 is added to avoid divergences in log xi
159+
weight = jnp.exp(
160+
-0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate))
161+
) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one.
162+
163+
x_average = decay_rate * x_average + weight * (
164+
xi / jnp.power(params.step_size, 6.0)
165+
)
166+
time = decay_rate * time + weight
167+
step_size = jnp.power(
168+
x_average / time, -1.0 / 6.0
169+
) # We use the Var[E] = O(eps^6) relation here.
170+
step_size = (step_size < step_size_max) * step_size + (
171+
step_size > step_size_max
172+
) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences
173+
params_new = params._replace(step_size=step_size)
174+
175+
return state, params_new, params_new, (time, x_average, step_size_max), success
176+
177+
def update_kalman(x, state, outer_weight, success, step_size):
178+
"""kalman filter to estimate the size of the posterior"""
179+
time, x_average, x_squared_average = state
180+
weight = outer_weight * step_size * success
181+
zero_prevention = 1 - outer_weight
182+
x_average = (time * x_average + weight * x) / (
183+
time + weight + zero_prevention
184+
) # Update <f(x)> with a Kalman filter
185+
x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / (
186+
time + weight + zero_prevention
187+
) # Update <f(x)> with a Kalman filter
188+
time += weight
189+
return (time, x_average, x_squared_average)
190+
191+
adap0 = (0.0, 0.0, jnp.inf)
192+
193+
def step(iteration_state, weight_and_key):
194+
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
195+
196+
outer_weight, rng_key = weight_and_key
197+
state, params, adaptive_state, kalman_state = iteration_state
198+
state, params, params_final, adaptive_state, success = predictor(
199+
state, params, adaptive_state, rng_key
200+
)
201+
position, _ = ravel_pytree(state.position)
202+
kalman_state = update_kalman(
203+
position, kalman_state, outer_weight, success, params.step_size
204+
)
205+
206+
return (state, params_final, adaptive_state, kalman_state), None
207+
208+
def L_step_size_adaptation(state, params, num_steps, rng_key):
209+
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
210+
num_steps * frac_tune2
211+
)
212+
L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)
213+
214+
# we use the last num_steps2 to compute the diagonal preconditioner
215+
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
216+
217+
# initial state of the kalman filter
218+
kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))
219+
220+
# run the steps
221+
kalman_state = jax.lax.scan(
222+
step,
223+
init=(state, params, adap0, kalman_state),
224+
xs=(outer_weights, L_step_size_adaptation_keys),
225+
length=num_steps1 + num_steps2,
226+
)[0]
227+
state, params, _, kalman_state_output = kalman_state
228+
229+
L = params.L
230+
# determine L
231+
if num_steps2 != 0.0:
232+
_, F1, F2 = kalman_state_output
233+
variances = F2 - jnp.square(F1)
234+
L = jnp.sqrt(jnp.sum(variances))
235+
236+
return state, MCLMCAdaptationState(L, params.step_size)
237+
238+
return L_step_size_adaptation
239+
240+
241+
def make_adaptation_L(kernel, frac, Lfactor):
242+
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""
243+
244+
def adaptation_L(state, params, num_steps, key):
245+
num_steps = int(num_steps * frac)
246+
adaptation_L_keys = jax.random.split(key, num_steps)
247+
248+
# run kernel in the normal way
249+
state, info = jax.lax.scan(
250+
f=lambda s, k: (
251+
kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
252+
),
253+
init=state,
254+
xs=adaptation_L_keys,
255+
)
256+
samples = info.transformed_position # tranform is the identity here
257+
flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
258+
flat_samples = flat_samples.reshape(2, num_steps // 2, -1)
259+
ESS = effective_sample_size(flat_samples)
260+
261+
return state, params._replace(
262+
L=Lfactor * params.step_size * jnp.mean(num_steps / ESS)
263+
)
264+
265+
return adaptation_L
266+
267+
268+
def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
269+
"""if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case."""
270+
271+
reduced_step_size = 0.8
272+
p, unravel_fn = ravel_pytree(next_state.position)
273+
nonans = jnp.all(jnp.isfinite(p))
274+
state, step_size, kinetic_change = jax.tree_util.tree_map(
275+
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
276+
(next_state, step_size_max, kinetic_change),
277+
(previous_state, step_size * reduced_step_size, 0.0),
278+
)
279+
280+
return nonans, state, step_size, kinetic_change

blackjax/mcmc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
hmc,
66
mala,
77
marginal_latent_gaussian,
8+
mclmc,
89
nuts,
910
periodic_orbital,
1011
random_walk,
@@ -20,4 +21,5 @@
2021
"periodic_orbital",
2122
"marginal_latent_gaussian",
2223
"random_walk",
24+
"mclmc",
2325
]

blackjax/mcmc/integrators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,5 +365,5 @@ def noneuclidean_integrator(
365365

366366

367367
noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
368-
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)
369368
noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
369+
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)

0 commit comments

Comments
 (0)