# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L. |
| 15 | +
|
| 16 | +""" |
| 17 | + |
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size  # type: ignore
from blackjax.util import pytree_size


class MCLMCAdaptationState(NamedTuple):
    """Represents the tunable parameters for MCLMC adaptation.

    Attributes:
        L (float): The momentum decoherence rate for the MCLMC algorithm.
        step_size (float): The step size used for the MCLMC algorithm.
    """

    L: float
    step_size: float


def mclmc_find_L_and_step_size(
    mclmc_kernel,
    num_steps,
    state,
    rng_key,
    frac_tune1=0.1,
    frac_tune2=0.1,
    frac_tune3=0.1,
    desired_energy_var=5e-4,
    trust_in_estimate=1.5,
    num_effective_samples=150,
):
    """
    Finds the optimal value of the parameters for the MCLMC algorithm.

    Args:
        mclmc_kernel (callable): The MCLMC kernel. It must have the signature
            kernel(rng_key, state, L, step_size) -> (next_state, info), where
            info carries the energy change and the (transformed) position of the step.
        num_steps (int): The number of MCMC steps that will subsequently be run, after tuning.
        state (MCMCState): The initial state of the MCMC algorithm.
        rng_key (jax.random.PRNGKey): The random number generator key.
        frac_tune1 (float): The fraction of num_steps used for the first phase of adaptation, which tunes the step size.
        frac_tune2 (float): The fraction of num_steps used for the second phase, which tunes both the step size and L (from the estimated posterior variances).
        frac_tune3 (float): The fraction of num_steps used for the third phase, which refines L using the effective sample size.
        desired_energy_var (float): The per-dimension energy variance that the step size adaptation targets.
        trust_in_estimate (float): Controls how strongly step sizes far from the current estimate are down-weighted.
        num_effective_samples (int): The effective number of past samples used to set the decay rate of the adaptation averages.

    Returns:
        tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters.

    Examples:
        # Define the kernel function; it must accept (rng_key, state, L, step_size)
        # and return (next_state, info), e.g. a kernel built with
        # blackjax.mcmc.mclmc.build_kernel.
        kernel = ...

        # Define the initial state (e.g. from the MCLMC init function)
        initial_state = ...

        # Generate a random number generator key
        rng_key = jax.random.PRNGKey(0)

        # Find the optimal parameters for the MCLMC algorithm
        final_state, final_params = mclmc_find_L_and_step_size(
            mclmc_kernel=kernel,
            num_steps=1000,
            state=initial_state,
            rng_key=rng_key,
            frac_tune1=0.2,
            frac_tune2=0.3,
            frac_tune3=0.1,
            desired_energy_var=1e-4,
            trust_in_estimate=2.0,
            num_effective_samples=200,
        )
    """
    dim = pytree_size(state.position)
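    # Heuristic initial guesses, refined by the adaptation below:
    # L = sqrt(dim) and step_size = sqrt(dim) / 4.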
    params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
    part1_key, part2_key = jax.random.split(rng_key, 2)

    state, params = make_L_step_size_adaptation(
        kernel=mclmc_kernel,
        dim=dim,
        frac_tune1=frac_tune1,
        frac_tune2=frac_tune2,
        desired_energy_var=desired_energy_var,
        trust_in_estimate=trust_in_estimate,
        num_effective_samples=num_effective_samples,
    )(state, params, num_steps, part1_key)

    if frac_tune3 != 0:
        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
            state, params, num_steps, part2_key
        )

    return state, params


def make_L_step_size_adaptation(
    kernel,
    dim,
    frac_tune1,
    frac_tune2,
    desired_energy_var=1e-3,
    trust_in_estimate=1.5,
    num_effective_samples=150,
):
    """Adapts the step size and L of the MCLMC kernel. Designed for the unadjusted MCLMC."""

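    # Decay rate of the exponential moving averages used below; roughly the last
    # `num_effective_samples` steps contribute to the running estimates.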
    decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

    def predictor(previous_state, params, adaptive_state, rng_key):
        """Does one step of the dynamics and updates the prediction for the optimal step size.
        Designed for the unadjusted MCHMC."""

        time, x_average, step_size_max = adaptive_state

        # dynamics
        next_state, info = kernel(
            rng_key=rng_key,
            state=previous_state,
            L=params.L,
            step_size=params.step_size,
        )
        # step updating
        success, state, step_size_max, energy_change = handle_nans(
            previous_state,
            next_state,
            params.step_size,
            step_size_max,
            info.energy_change,
        )

        # Warning: var = 0 if there were NaNs, but we will give it a very small weight
        xi = (
            jnp.square(energy_change) / (dim * desired_energy_var)
        ) + 1e-8  # 1e-8 is added to avoid divergences in log xi
        weight = jnp.exp(
            -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate))
        )  # the weight reduces the impact of step sizes which are much larger or much smaller than the desired one

        x_average = decay_rate * x_average + weight * (
            xi / jnp.power(params.step_size, 6.0)
        )
        time = decay_rate * time + weight
        step_size = jnp.power(
            x_average / time, -1.0 / 6.0
        )  # We use the Var[E] = O(eps^6) relation here.
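        # The new step size solves <xi / eps^6> * eps_new^6 = 1: since the energy
        # error scales as Var[E] = O(eps^6), the ratio xi / eps^6 is roughly
        # independent of eps, so eps_new is the step size at which Var[E] / dim is
        # predicted to equal desired_energy_var.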
        step_size = (step_size < step_size_max) * step_size + (
            step_size > step_size_max
        ) * step_size_max  # cap the proposal at step_size_max, above which we have seen divergences
        params_new = params._replace(step_size=step_size)

        return state, params_new, params_new, (time, x_average, step_size_max), success

    def update_kalman(x, state, outer_weight, success, step_size):
        """Kalman filter to estimate the size of the posterior."""
        time, x_average, x_squared_average = state
        weight = outer_weight * step_size * success
        zero_prevention = 1 - outer_weight
        x_average = (time * x_average + weight * x) / (
            time + weight + zero_prevention
        )  # Update <f(x)> with a Kalman filter
        x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / (
            time + weight + zero_prevention
        )  # Update <f(x)^2> with a Kalman filter
        time += weight
        return (time, x_average, x_squared_average)

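    # Initial adaptive state for the predictor: (time, x_average, step_size_max)
    # = (0, 0, inf); no steps have been taken yet and no divergences have been
    # seen, so the step size is not yet capped.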
    adap0 = (0.0, 0.0, jnp.inf)

    def step(iteration_state, weight_and_key):
        """Does one step of the dynamics and updates the estimates of the posterior size and optimal step size."""

        outer_weight, rng_key = weight_and_key
        state, params, adaptive_state, kalman_state = iteration_state
        state, params, params_final, adaptive_state, success = predictor(
            state, params, adaptive_state, rng_key
        )
        position, _ = ravel_pytree(state.position)
        kalman_state = update_kalman(
            position, kalman_state, outer_weight, success, params.step_size
        )

        return (state, params_final, adaptive_state, kalman_state), None

    def L_step_size_adaptation(state, params, num_steps, rng_key):
        num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
            num_steps * frac_tune2
        )
        L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)

        # we use the last num_steps2 steps to estimate the posterior variances (which set L)
        outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

        # initial state of the Kalman filter
        kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))

        # run the steps
        kalman_state = jax.lax.scan(
            step,
            init=(state, params, adap0, kalman_state),
            xs=(outer_weights, L_step_size_adaptation_keys),
            length=num_steps1 + num_steps2,
        )[0]
        state, params, _, kalman_state_output = kalman_state

        L = params.L
        # determine L
        if num_steps2 != 0.0:
            _, F1, F2 = kalman_state_output
            variances = F2 - jnp.square(F1)
            L = jnp.sqrt(jnp.sum(variances))
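            # i.e. L is set to the overall scale of the posterior: the square
            # root of the trace of the covariance estimated by the Kalman filter.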

        return state, MCLMCAdaptationState(L, params.step_size)

    return L_step_size_adaptation


def make_adaptation_L(kernel, frac, Lfactor):
    """Determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)."""

    def adaptation_L(state, params, num_steps, key):
        num_steps = int(num_steps * frac)
        adaptation_L_keys = jax.random.split(key, num_steps)

        # run the kernel in the normal way
        state, info = jax.lax.scan(
            f=lambda s, k: (
                kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
            ),
            init=state,
            xs=adaptation_L_keys,
        )
        samples = info.transformed_position  # the transform is the identity here
        flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
        flat_samples = flat_samples.reshape(2, num_steps // 2, -1)
        ESS = effective_sample_size(flat_samples)

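        # num_steps / ESS is the autocorrelation time in number of steps;
        # multiplying by the step size converts it to a trajectory length, and
        # Lfactor rescales it to give the new momentum decoherence length L.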
        return state, params._replace(
            L=Lfactor * params.step_size * jnp.mean(num_steps / ESS)
        )

    return adaptation_L


def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
    """If there are NaNs, reduce the step size and do not update the state; the old state is returned in that case."""

    reduced_step_size = 0.8
    p, unravel_fn = ravel_pytree(next_state.position)
    nonans = jnp.all(jnp.isfinite(p))
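    # If any coordinate of the proposed position is non-finite, keep the previous
    # state, cap the allowed step size at 80% of the current step size, and report
    # a zero energy change.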
    state, step_size, kinetic_change = jax.tree_util.tree_map(
        lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
        (next_state, step_size_max, kinetic_change),
        (previous_state, step_size * reduced_step_size, 0.0),
    )

    return nonans, state, step_size, kinetic_change