Skip to content

Commit 3e0efa2

Browse files
committed
feat: clean up walnuts a little bit
1 parent 47bd2ca commit 3e0efa2

File tree

10 files changed

+219
-154
lines changed

10 files changed

+219
-154
lines changed

src/adapt_strategy.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,12 @@ where
291291
start: &State<M, P>,
292292
end: &State<M, P>,
293293
divergence_info: Option<&DivergenceInfo>,
294+
num_substeps: u64,
294295
) {
295296
self.collector1
296-
.register_leapfrog(math, start, end, divergence_info);
297+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
297298
self.collector2
298-
.register_leapfrog(math, start, end, divergence_info);
299+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
299300
}
300301

301302
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {

src/chain.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ where
157157
&mut self.hamiltonian,
158158
&self.options,
159159
&mut self.collector,
160+
self.draw_count < 70,
160161
)?;
161162
let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
162163
state.write_position(math, &mut position);
@@ -235,6 +236,7 @@ pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>>
235236
pub divergence_end: Option<Vec<f64>>,
236237
#[storable(dims("unconstrained_parameter"))]
237238
pub divergence_momentum: Option<Vec<f64>>,
239+
non_reversible: Option<bool>,
238240
//pub divergence_message: Option<String>,
239241
#[storable(ignore)]
240242
_phantom: PhantomData<fn() -> P>,
@@ -303,7 +305,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M
303305
.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
304306
divergence_momentum: div_info
305307
.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
306-
//divergence_message: self.divergence_msg.clone(),
308+
non_reversible: div_info.and_then(|d| Some(d.non_reversible)),
307309
_phantom: PhantomData,
308310
}
309311
}

src/euclidean_hamiltonian.rs

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
225225
math: &mut M,
226226
start: &State<M, Self::Point>,
227227
dir: Direction,
228-
step_size_factor: f64,
228+
step_size_splits: u64,
229229
collector: &mut C,
230230
) -> LeapfrogResult<M, Self::Point> {
231231
let mut out = self.pool().new_state(math);
@@ -238,7 +238,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
238238
Direction::Backward => -1,
239239
};
240240

241-
let epsilon = (sign as f64) * self.step_size * step_size_factor;
241+
let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
242242

243243
start
244244
.point()
@@ -250,17 +250,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
250250
if !logp_error.is_recoverable() {
251251
return LeapfrogResult::Err(logp_error);
252252
}
253-
let div_info = DivergenceInfo {
254-
logp_function_error: Some(Arc::new(Box::new(logp_error))),
255-
start_location: Some(math.box_array(start.point().position())),
256-
start_gradient: Some(math.box_array(&start.point().gradient)),
257-
start_momentum: Some(math.box_array(&start.point().momentum)),
258-
end_location: None,
259-
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
260-
end_idx_in_trajectory: None,
261-
energy_error: None,
262-
};
263-
collector.register_leapfrog(math, start, &out, Some(&div_info));
253+
let error = Arc::new(Box::new(logp_error));
254+
let div_info = DivergenceInfo::new_logp_function_error(math, start, error);
255+
collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
264256
return LeapfrogResult::Divergence(div_info);
265257
}
266258

@@ -273,23 +265,21 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
273265

274266
start.point().set_psum(math, out_point, dir);
275267

268+
// TODO: energy error measured relative to initial point or previous point?
276269
let energy_error = out_point.energy_error();
277270
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
278-
let divergence_info = DivergenceInfo {
279-
logp_function_error: None,
280-
start_location: Some(math.box_array(start.point().position())),
281-
start_gradient: Some(math.box_array(start.point().gradient())),
282-
end_location: Some(math.box_array(&out_point.position)),
283-
start_momentum: Some(math.box_array(&out_point.momentum)),
284-
start_idx_in_trajectory: Some(start.index_in_trajectory()),
285-
end_idx_in_trajectory: Some(out.index_in_trajectory()),
286-
energy_error: Some(energy_error),
287-
};
288-
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
271+
let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
272+
collector.register_leapfrog(
273+
math,
274+
start,
275+
&out,
276+
Some(&divergence_info),
277+
step_size_splits,
278+
);
289279
return LeapfrogResult::Divergence(divergence_info);
290280
}
291281

292-
collector.register_leapfrog(math, start, &out, None);
282+
collector.register_leapfrog(math, start, &out, None, step_size_splits);
293283

294284
LeapfrogResult::Ok(out)
295285
}
@@ -363,4 +353,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
363353
fn step_size_mut(&mut self) -> &mut f64 {
364354
&mut self.step_size
365355
}
356+
357+
fn max_energy_error(&self) -> f64 {
358+
self.max_energy_error
359+
}
366360
}

src/hamiltonian.rs

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::{
1616
/// a cutoff value or nan.
1717
/// - The logp function caused a recoverable error (eg if an ODE solver
1818
/// failed)
19+
#[non_exhaustive]
1920
#[derive(Debug, Clone)]
2021
pub struct DivergenceInfo {
2122
pub start_momentum: Option<Box<[f64]>>,
@@ -26,6 +27,7 @@ pub struct DivergenceInfo {
2627
pub end_idx_in_trajectory: Option<i64>,
2728
pub start_idx_in_trajectory: Option<i64>,
2829
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
30+
pub non_reversible: bool,
2931
}
3032

3133
impl DivergenceInfo {
@@ -39,8 +41,67 @@ impl DivergenceInfo {
3941
end_idx_in_trajectory: None,
4042
start_idx_in_trajectory: None,
4143
logp_function_error: None,
44+
non_reversible: false,
4245
}
4346
}
47+
48+
pub fn new_energy_error_too_large<M: Math>(
49+
math: &mut M,
50+
start: &State<M, impl Point<M>>,
51+
stop: &State<M, impl Point<M>>,
52+
) -> Self {
53+
DivergenceInfo {
54+
logp_function_error: None,
55+
start_location: Some(math.box_array(start.point().position())),
56+
start_gradient: Some(math.box_array(start.point().gradient())),
57+
// TODO
58+
start_momentum: None,
59+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
60+
end_location: Some(math.box_array(&stop.point().position())),
61+
end_idx_in_trajectory: Some(stop.index_in_trajectory()),
62+
// TODO
63+
energy_error: None,
64+
non_reversible: false,
65+
}
66+
}
67+
68+
pub fn new_logp_function_error<M: Math>(
69+
math: &mut M,
70+
start: &State<M, impl Point<M>>,
71+
logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
72+
) -> Self {
73+
DivergenceInfo {
74+
logp_function_error: Some(logp_function_error),
75+
start_location: Some(math.box_array(start.point().position())),
76+
start_gradient: Some(math.box_array(start.point().gradient())),
77+
// TODO
78+
start_momentum: None,
79+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
80+
end_location: None,
81+
end_idx_in_trajectory: None,
82+
energy_error: None,
83+
non_reversible: false,
84+
}
85+
}
86+
87+
pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
88+
// TODO add info about what went wrong
89+
DivergenceInfo {
90+
logp_function_error: None,
91+
start_location: Some(math.box_array(start.point().position())),
92+
start_gradient: Some(math.box_array(start.point().gradient())),
93+
// TODO
94+
start_momentum: None,
95+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
96+
end_location: None,
97+
end_idx_in_trajectory: None,
98+
energy_error: None,
99+
non_reversible: true,
100+
}
101+
}
102+
pub fn new_max_step_size_halvings<M: Math>(math: &mut M, num_steps: u64, info: Self) -> Self {
103+
info // TODO
104+
}
44105
}
45106

46107
#[derive(Debug, Copy, Clone)]
@@ -106,10 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
106167
math: &mut M,
107168
start: &State<M, Self::Point>,
108169
dir: Direction,
109-
step_size_factor: f64,
170+
step_size_splits: u64,
110171
collector: &mut C,
111172
) -> LeapfrogResult<M, Self::Point>;
112173

174+
fn split_leapfrog<C: Collector<M, Self::Point>>(
175+
&mut self,
176+
math: &mut M,
177+
start: &State<M, Self::Point>,
178+
dir: Direction,
179+
num_steps: u64,
180+
collector: &mut C,
181+
max_error: f64,
182+
) -> LeapfrogResult<M, Self::Point> {
183+
let mut state = start.clone();
184+
185+
let mut min_energy = start.energy();
186+
let mut max_energy = min_energy;
187+
188+
for _ in 0..num_steps {
189+
state = match self.leapfrog(math, &state, dir, num_steps, collector) {
190+
LeapfrogResult::Ok(state) => state,
191+
LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
192+
LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
193+
};
194+
let energy = state.energy();
195+
min_energy = min_energy.min(energy);
196+
max_energy = max_energy.max(energy);
197+
198+
// TODO: walnuts papers says to use abs, but c++ code doesn't?
199+
if max_energy - min_energy > max_error {
200+
let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
201+
return LeapfrogResult::Divergence(info);
202+
}
203+
}
204+
205+
LeapfrogResult::Ok(state)
206+
}
207+
113208
fn is_turning(
114209
&self,
115210
math: &mut M,
@@ -141,4 +236,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
141236

142237
fn step_size(&self) -> f64;
143238
fn step_size_mut(&mut self) -> &mut f64;
239+
240+
fn max_energy_error(&self) -> f64;
144241
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError};
125125
pub use hamiltonian::DivergenceInfo;
126126
pub use math_base::{LogpError, Math};
127127
pub use model::Model;
128-
pub use nuts::NutsError;
128+
pub use nuts::{NutsError, WalnutsOptions};
129129
pub use sampler::{
130130
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress,
131131
ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,

0 commit comments

Comments
 (0)