Skip to content

Commit ffffa70

Browse files
committed
feat: clean up walnuts a little bit
1 parent 10bd9b9 commit ffffa70

File tree

4 files changed

+135
-127
lines changed

4 files changed

+135
-127
lines changed

src/euclidean_hamiltonian.rs

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -334,16 +334,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
334334
if !logp_error.is_recoverable() {
335335
return LeapfrogResult::Err(logp_error);
336336
}
337-
let div_info = DivergenceInfo {
338-
logp_function_error: Some(Arc::new(Box::new(logp_error))),
339-
start_location: Some(math.box_array(start.point().position())),
340-
start_gradient: Some(math.box_array(&start.point().gradient)),
341-
start_momentum: Some(math.box_array(&start.point().momentum)),
342-
end_location: None,
343-
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
344-
end_idx_in_trajectory: None,
345-
energy_error: None,
346-
};
337+
let error = Arc::new(Box::new(logp_error));
338+
let div_info = DivergenceInfo::new_logp_function_error(math, start, error);
347339
collector.register_leapfrog(math, start, &out, Some(&div_info));
348340
return LeapfrogResult::Divergence(div_info);
349341
}
@@ -357,18 +349,10 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
357349

358350
start.point().set_psum(math, out_point, dir);
359351

352+
// TODO: energy error measured relative to initial point or previous point?
360353
let energy_error = out_point.energy_error();
361354
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
362-
let divergence_info = DivergenceInfo {
363-
logp_function_error: None,
364-
start_location: Some(math.box_array(start.point().position())),
365-
start_gradient: Some(math.box_array(start.point().gradient())),
366-
end_location: Some(math.box_array(&out_point.position)),
367-
start_momentum: Some(math.box_array(&out_point.momentum)),
368-
start_idx_in_trajectory: Some(start.index_in_trajectory()),
369-
end_idx_in_trajectory: Some(out.index_in_trajectory()),
370-
energy_error: Some(energy_error),
371-
};
355+
let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
372356
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
373357
return LeapfrogResult::Divergence(divergence_info);
374358
}
@@ -447,4 +431,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
447431
fn step_size_mut(&mut self) -> &mut f64 {
448432
&mut self.step_size
449433
}
434+
435+
fn max_energy_error(&self) -> f64 {
436+
self.max_energy_error
437+
}
450438
}

src/hamiltonian.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::{
1616
/// a cutoff value or nan.
1717
/// - The logp function caused a recoverable error (eg if an ODE solver
1818
/// failed)
19+
#[non_exhaustive]
1920
#[derive(Debug, Clone)]
2021
pub struct DivergenceInfo {
2122
pub start_momentum: Option<Box<[f64]>>,
@@ -41,6 +42,58 @@ impl DivergenceInfo {
4142
logp_function_error: None,
4243
}
4344
}
45+
46+
pub fn new_energy_error_too_large<M: Math>(
47+
math: &mut M,
48+
start: &State<M, impl Point<M>>,
49+
stop: &State<M, impl Point<M>>,
50+
) -> Self {
51+
DivergenceInfo {
52+
logp_function_error: None,
53+
start_location: Some(math.box_array(start.point().position())),
54+
start_gradient: Some(math.box_array(start.point().gradient())),
55+
// TODO
56+
start_momentum: None,
57+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
58+
end_location: Some(math.box_array(&stop.point().position())),
59+
end_idx_in_trajectory: Some(stop.index_in_trajectory()),
60+
// TODO
61+
energy_error: None,
62+
}
63+
}
64+
65+
pub fn new_logp_function_error<M: Math>(
66+
math: &mut M,
67+
start: &State<M, impl Point<M>>,
68+
logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
69+
) -> Self {
70+
DivergenceInfo {
71+
logp_function_error: Some(logp_function_error),
72+
start_location: Some(math.box_array(start.point().position())),
73+
start_gradient: Some(math.box_array(start.point().gradient())),
74+
// TODO
75+
start_momentum: None,
76+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
77+
end_location: None,
78+
end_idx_in_trajectory: None,
79+
energy_error: None,
80+
}
81+
}
82+
83+
pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
84+
// TODO add into about what went wrong
85+
DivergenceInfo {
86+
logp_function_error: None,
87+
start_location: Some(math.box_array(start.point().position())),
88+
start_gradient: Some(math.box_array(start.point().gradient())),
89+
// TODO
90+
start_momentum: None,
91+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
92+
end_location: None,
93+
end_idx_in_trajectory: None,
94+
energy_error: None,
95+
}
96+
}
4497
}
4598

4699
#[derive(Debug, Copy, Clone)]
@@ -110,6 +163,40 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
110163
collector: &mut C,
111164
) -> LeapfrogResult<M, Self::Point>;
112165

166+
fn split_leapfrog<C: Collector<M, Self::Point>>(
167+
&mut self,
168+
math: &mut M,
169+
start: &State<M, Self::Point>,
170+
dir: Direction,
171+
num_splits: usize,
172+
collector: &mut C,
173+
) -> LeapfrogResult<M, Self::Point> {
174+
let step_size_factor = 1.0 / (num_splits as f64);
175+
let mut state = start.clone();
176+
177+
let mut min_energy = start.energy();
178+
let mut max_energy = min_energy;
179+
180+
for _ in 0..num_splits {
181+
state = match self.leapfrog(math, &state, dir, step_size_factor, collector) {
182+
LeapfrogResult::Ok(state) => state,
183+
LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
184+
LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
185+
};
186+
187+
let energy = state.energy();
188+
min_energy = min_energy.min(energy);
189+
max_energy = max_energy.max(energy);
190+
191+
// TODO: walnuts papers says to use abs, but c++ code doesn't?
192+
if (max_energy - min_energy) > self.max_energy_error() {
193+
let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
194+
return LeapfrogResult::Divergence(info);
195+
}
196+
}
197+
LeapfrogResult::Ok(state)
198+
}
199+
113200
fn is_turning(
114201
&self,
115202
math: &mut M,
@@ -141,4 +228,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
141228

142229
fn step_size(&self) -> f64;
143230
fn step_size_mut(&mut self) -> &mut f64;
231+
232+
fn max_energy_error(&self) -> f64;
144233
}

src/nuts.rs

Lines changed: 31 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -246,100 +246,46 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
246246
Some(ref options) => {
247247
// Walnuts implementation
248248
// TODO: Shouldn't all be in this one big function...
249-
let mut step_size_factor = 1.0;
250249
let mut num_steps = 1;
251250
let mut current = start.clone();
252251

253-
let mut success = false;
254-
255-
'step_size_search: for _ in 0..options.max_step_size_halvings {
256-
current = start.clone();
257-
let mut min_energy = current.energy();
258-
let mut max_energy = min_energy;
259-
260-
for _ in 0..num_steps {
261-
current = match hamiltonian.leapfrog(
262-
math,
263-
&current,
264-
direction,
265-
step_size_factor,
266-
collector,
267-
) {
268-
LeapfrogResult::Ok(state) => state,
269-
LeapfrogResult::Divergence(_) => {
270-
num_steps *= 2;
271-
step_size_factor *= 0.5;
272-
continue 'step_size_search;
273-
}
274-
LeapfrogResult::Err(err) => {
275-
return Err(NutsError::LogpFailure(err.into()));
276-
}
277-
};
278-
279-
// Update min/max energies
280-
let current_energy = current.energy();
281-
min_energy = min_energy.min(current_energy);
282-
max_energy = max_energy.max(current_energy);
283-
}
284-
285-
if max_energy - min_energy > options.max_energy_error {
286-
num_steps *= 2;
287-
step_size_factor *= 0.5;
288-
continue 'step_size_search;
289-
}
290-
291-
success = true;
292-
break 'step_size_search;
252+
let mut last_divergence = None;
253+
254+
for _ in 0..options.max_step_size_halvings {
255+
current = match hamiltonian
256+
.split_leapfrog(math, start, direction, num_steps, collector)
257+
{
258+
LeapfrogResult::Ok(state) => state,
259+
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
260+
LeapfrogResult::Divergence(info) => {
261+
num_steps *= 2;
262+
last_divergence = Some(info);
263+
continue;
264+
}
265+
};
266+
break;
293267
}
294268

295-
if !success {
296-
// TODO: More info
297-
return Ok(Err(DivergenceInfo::new()));
269+
if let Some(info) = last_divergence {
270+
return Ok(Err(info));
298271
}
299272

300-
// TODO
301273
let back = direction.reverse();
302-
let mut current_backward;
303-
304274
let mut reversible = true;
305275

306-
'rev_step_size: while num_steps >= 2 {
276+
while num_steps >= 2 {
307277
num_steps /= 2;
308-
step_size_factor *= 0.5;
309-
310-
// TODO: Can we share code for the micro steps in the two directions?
311-
current_backward = current.clone();
312-
313-
let mut min_energy = current_backward.energy();
314-
let mut max_energy = min_energy;
315-
316-
for _ in 0..num_steps {
317-
current_backward = match hamiltonian.leapfrog(
318-
math,
319-
&current_backward,
320-
back,
321-
step_size_factor,
322-
collector,
323-
) {
324-
LeapfrogResult::Ok(state) => state,
325-
LeapfrogResult::Divergence(_) => {
326-
// We also reject in the backward direction, all is good so far...
327-
continue 'rev_step_size;
328-
}
329-
LeapfrogResult::Err(err) => {
330-
return Err(NutsError::LogpFailure(err.into()));
331-
}
332-
};
333-
334-
// Update min/max energies
335-
let current_energy = current_backward.energy();
336-
min_energy = min_energy.min(current_energy);
337-
max_energy = max_energy.max(current_energy);
338-
if max_energy - min_energy > options.max_energy_error {
339-
// We reject also in the backward direction, all good so far...
340-
continue 'rev_step_size;
278+
279+
match hamiltonian.split_leapfrog(math, &current, back, num_steps, collector) {
280+
LeapfrogResult::Ok(_) => (),
281+
LeapfrogResult::Divergence(_) => {
282+
// We also reject in the backward direction, all is good so far...
283+
continue;
284+
}
285+
LeapfrogResult::Err(err) => {
286+
return Err(NutsError::LogpFailure(err.into()));
341287
}
342-
}
288+
};
343289

344290
// We did not reject in the backward direction, so we are not reversible
345291
reversible = false;
@@ -350,13 +296,12 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
350296
let log_size = -current.point().energy_error();
351297
(log_size, current)
352298
} else {
353-
// TODO: More info
354-
return Ok(Err(DivergenceInfo::new()));
299+
return Ok(Err(DivergenceInfo::new_not_reversible(math, start)));
355300
}
356301
}
357302
None => {
358-
// Classical NUTS
359-
//
303+
// Classical NUTS.
304+
// TODO Is equivalent to walnuts with max_step_size_halvings = 0?
360305
let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) {
361306
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
362307
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
@@ -393,7 +338,6 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
393338

394339
#[derive(Debug, Clone, Copy)]
395340
pub struct WalnutsOptions {
396-
pub max_energy_error: f64,
397341
pub max_step_size_halvings: u64,
398342
}
399343

src/transformed_hamiltonian.rs

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,8 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
481481
if !logp_error.is_recoverable() {
482482
return LeapfrogResult::Err(logp_error);
483483
}
484-
let div_info = DivergenceInfo {
485-
logp_function_error: Some(Arc::new(Box::new(logp_error))),
486-
start_location: Some(math.box_array(start.point().position())),
487-
start_gradient: Some(math.box_array(start.point().gradient())),
488-
start_momentum: None,
489-
end_location: None,
490-
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
491-
end_idx_in_trajectory: None,
492-
energy_error: None,
493-
};
484+
let logp_error = Arc::new(Box::new(logp_error));
485+
let div_info = DivergenceInfo::new_logp_function_error(math, start, logp_error);
494486
collector.register_leapfrog(math, start, &out, Some(&div_info));
495487
return LeapfrogResult::Divergence(div_info);
496488
}
@@ -502,16 +494,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
502494

503495
let energy_error = out_point.energy_error();
504496
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
505-
let divergence_info = DivergenceInfo {
506-
logp_function_error: None,
507-
start_location: Some(math.box_array(start.point().position())),
508-
start_gradient: Some(math.box_array(start.point().gradient())),
509-
end_location: Some(math.box_array(out_point.position())),
510-
start_momentum: None,
511-
start_idx_in_trajectory: Some(start.index_in_trajectory()),
512-
end_idx_in_trajectory: Some(out.index_in_trajectory()),
513-
energy_error: Some(energy_error),
514-
};
497+
let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
515498
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
516499
return LeapfrogResult::Divergence(divergence_info);
517500
}
@@ -618,4 +601,8 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
618601
fn step_size_mut(&mut self) -> &mut f64 {
619602
&mut self.step_size
620603
}
604+
605+
fn max_energy_error(&self) -> f64 {
606+
self.max_energy_error
607+
}
621608
}

0 commit comments

Comments
 (0)