Skip to content

Commit f49945d

Browse files
authored
Refactor proposal.py (#603)
* Refactor proposal.py * Fix test * Fix test 2 * Fix test 3
1 parent 41f47d5 commit f49945d

File tree

11 files changed

+168
-284
lines changed

11 files changed

+168
-284
lines changed

blackjax/adaptation/chees_adaptation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,8 @@ def one_step(carry, rng_key):
406406
new_states, info = jax.vmap(_step_fn)(keys, states)
407407
new_adaptation_state = update(
408408
adaptation_state,
409-
info.proposal.state.position,
410-
info.proposal.state.momentum,
409+
info.proposal.position,
410+
info.proposal.momentum,
411411
states.position,
412412
info.acceptance_rate,
413413
info.is_divergent,

blackjax/mcmc/barker.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from jax.tree_util import tree_leaves, tree_map
2222

2323
from blackjax.base import SamplingAlgorithm
24+
from blackjax.mcmc.proposal import static_binomial_sampling
2425
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
2526

2627
__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
@@ -99,9 +100,7 @@ def ratio_proposal_nd(y, x, log_y, log_x):
99100
state.logdensity_grad,
100101
)
101102
ratio_proposal = sum(tree_leaves(ratios_proposals))
102-
log_p_accept = proposal.logdensity - state.logdensity + ratio_proposal
103-
p_accept = jnp.exp(log_p_accept)
104-
return jnp.minimum(1.0, p_accept)
103+
return proposal.logdensity - state.logdensity + ratio_proposal
105104

106105
def kernel(
107106
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
@@ -119,13 +118,12 @@ def kernel(
119118
proposed_pos, proposed_logdensity, proposed_logdensity_grad
120119
)
121120

122-
p_accept = _compute_acceptance_probability(state, proposed_state)
123-
124-
accept = jax.random.uniform(key_rmh) < p_accept
125-
126-
state = jax.lax.cond(accept, lambda: proposed_state, lambda: state)
127-
info = BarkerInfo(p_accept, accept, proposed_state)
128-
return state, info
121+
log_p_accept = _compute_acceptance_probability(state, proposed_state)
122+
accepted_state, info = static_binomial_sampling(
123+
key_rmh, log_p_accept, state, proposed_state
124+
)
125+
do_accept, p_accept, _ = info
126+
return accepted_state, BarkerInfo(p_accept, do_accept, proposed_state)
129127

130128
return kernel
131129

blackjax/mcmc/ghmc.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import blackjax.mcmc.hmc as hmc
2121
import blackjax.mcmc.integrators as integrators
2222
import blackjax.mcmc.metrics as metrics
23-
import blackjax.mcmc.proposal as proposal
2423
from blackjax.base import SamplingAlgorithm
24+
from blackjax.mcmc.proposal import nonreversible_slice_sampling
2525
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
2626
from blackjax.util import generate_gaussian_noise
2727

@@ -94,7 +94,6 @@ def build_kernel(
9494
returns a new state of the chain along with information about the transition.
9595
9696
"""
97-
sample_proposal = proposal.nonreversible_slice_sampling
9897

9998
def kernel(
10099
rng_key: PRNGKey,
@@ -143,7 +142,7 @@ def kernel(
143142
kinetic_energy_fn,
144143
step_size,
145144
divergence_threshold=divergence_threshold,
146-
sample_proposal=sample_proposal,
145+
sample_proposal=nonreversible_slice_sampling,
147146
)
148147

149148
key_momentum, key_noise = jax.random.split(rng_key)
@@ -158,14 +157,14 @@ def kernel(
158157
)
159158
# Note that ghmc use nonreversible_slice_sampling, which overloads the pattern
160159
# of SampleProposal and do not actually return the acceptance rate.
161-
proposal, info = proposal_generator(slice, integrator_state)
160+
proposal, info, slice_next = proposal_generator(slice, integrator_state)
162161
proposal = hmc.flip_momentum(proposal)
163162
state = GHMCState(
164163
position=proposal.position,
165164
momentum=proposal.momentum,
166165
logdensity=proposal.logdensity,
167166
logdensity_grad=proposal.logdensity_grad,
168-
slice=info.acceptance_rate,
167+
slice=slice_next,
169168
)
170169

171170
return state, info

blackjax/mcmc/hmc.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import blackjax.mcmc.integrators as integrators
2020
import blackjax.mcmc.metrics as metrics
21-
import blackjax.mcmc.proposal as proposal
2221
import blackjax.mcmc.trajectory as trajectory
2322
from blackjax.base import SamplingAlgorithm
23+
from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling
2424
from blackjax.mcmc.trajectory import hmc_energy
2525
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
2626

@@ -166,7 +166,7 @@ def kernel(
166166
integrator_state = integrators.IntegratorState(
167167
position, momentum, logdensity, logdensity_grad
168168
)
169-
proposal, info = proposal_generator(key_integrator, integrator_state)
169+
proposal, info, _ = proposal_generator(key_integrator, integrator_state)
170170
proposal = HMCState(
171171
proposal.position, proposal.logdensity, proposal.logdensity_grad
172172
)
@@ -404,7 +404,7 @@ def hmc_proposal(
404404
num_integration_steps: int = 1,
405405
divergence_threshold: float = 1000,
406406
*,
407-
sample_proposal: Callable = proposal.static_binomial_sampling,
407+
sample_proposal: Callable = static_binomial_sampling,
408408
) -> Callable:
409409
"""Vanilla HMC algorithm.
410410
@@ -433,33 +433,32 @@ def hmc_proposal(
433433
434434
"""
435435
build_trajectory = trajectory.static_integration(integrator)
436-
init_proposal, generate_proposal = proposal.proposal_generator(
437-
hmc_energy(kinetic_energy)
438-
)
436+
hmc_energy_fn = hmc_energy(kinetic_energy)
439437

440438
def generate(
441439
rng_key, state: integrators.IntegratorState
442-
) -> tuple[integrators.IntegratorState, HMCInfo]:
440+
) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
443441
"""Generate a new chain state."""
444442
end_state = build_trajectory(state, step_size, num_integration_steps)
445443
end_state = flip_momentum(end_state)
446-
proposal = init_proposal(state)
447-
new_proposal = generate_proposal(proposal.energy, end_state)
448-
is_diverging = -new_proposal.weight > divergence_threshold
449-
sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
450-
do_accept, p_accept = info
444+
proposal_energy = hmc_energy_fn(state)
445+
new_energy = hmc_energy_fn(end_state)
446+
delta_energy = safe_energy_diff(proposal_energy, new_energy)
447+
is_diverging = -delta_energy > divergence_threshold
448+
sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
449+
do_accept, p_accept, other_proposal_info = info
451450

452451
info = HMCInfo(
453452
state.momentum,
454453
p_accept,
455454
do_accept,
456455
is_diverging,
457-
new_proposal.energy,
458-
new_proposal,
456+
new_energy,
457+
end_state,
459458
num_integration_steps,
460459
)
461460

462-
return sampled_proposal.state, info
461+
return sampled_state, info, other_proposal_info
463462

464463
return generate
465464

blackjax/mcmc/mala.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def transition_energy(state, new_state, step_size):
8989
)
9090
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot
9191

92-
init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
92+
compute_acceptance_ratio = proposal.compute_asymmetric_acceptance_ratio(
9393
transition_energy
9494
)
9595
sample_proposal = proposal.static_binomial_sampling
@@ -106,15 +106,13 @@ def kernel(
106106
new_state = integrator(key_integrator, state, step_size)
107107
new_state = MALAState(*new_state)
108108

109-
proposal = init_proposal(state)
110-
new_proposal = generate_proposal(state, new_state, step_size=step_size)
111-
sampled_proposal, do_accept, p_accept = sample_proposal(
112-
key_rmh, proposal, new_proposal
113-
)
109+
log_p_accept = compute_acceptance_ratio(state, new_state, step_size=step_size)
110+
accepted_state, info = sample_proposal(key_rmh, log_p_accept, state, new_state)
111+
do_accept, p_accept, _ = info
114112

115113
info = MALAInfo(p_accept, do_accept)
116114

117-
return sampled_proposal.state, info
115+
return accepted_state, info
118116

119117
return kernel
120118

blackjax/mcmc/marginal_latent_gaussian.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax.scipy.linalg as linalg
2020

2121
from blackjax.base import SamplingAlgorithm
22+
from blackjax.mcmc.proposal import static_binomial_sampling
2223
from blackjax.types import Array, PRNGKey
2324

2425
__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"]
@@ -121,13 +122,14 @@ def step(key: PRNGKey, state: MarginalState, delta):
121122
hxy = jnp.dot(U_x - temp_y, Gamma_3 * U_grad_y)
122123
hyx = jnp.dot(U_y - temp_x, Gamma_3 * U_grad_x)
123124

124-
alpha = jnp.minimum(1, jnp.exp(log_p_y - logdensity + hxy - hyx))
125-
accept = jax.random.uniform(u_key) < alpha
126-
125+
log_p_accept = log_p_y - logdensity + hxy - hyx
127126
proposed_state = MarginalState(y, log_p_y, grad_y, U_y, U_grad_y)
128-
state = jax.lax.cond(accept, lambda _: proposed_state, lambda _: state, None)
129-
info = MarginalInfo(alpha, accept, proposed_state)
130-
return state, info
127+
accepted_state, info = static_binomial_sampling(
128+
u_key, log_p_accept, state, proposed_state
129+
)
130+
do_accept, p_accept, _ = info
131+
info = MarginalInfo(p_accept, do_accept, proposed_state)
132+
return accepted_state, info
131133

132134
def init(position):
133135
logdensity, logdensity_grad = val_and_grad(position)

0 commit comments

Comments
 (0)