@@ -16,6 +16,7 @@ use crate::{
16
16
/// a cutoff value or nan.
17
17
/// - The logp function caused a recoverable error (eg if an ODE solver
18
18
/// failed)
19
+ #[ non_exhaustive]
19
20
#[ derive( Debug , Clone ) ]
20
21
pub struct DivergenceInfo {
21
22
pub start_momentum : Option < Box < [ f64 ] > > ,
@@ -26,6 +27,7 @@ pub struct DivergenceInfo {
26
27
pub end_idx_in_trajectory : Option < i64 > ,
27
28
pub start_idx_in_trajectory : Option < i64 > ,
28
29
pub logp_function_error : Option < Arc < dyn std:: error:: Error + Send + Sync > > ,
30
+ pub non_reversible : bool ,
29
31
}
30
32
31
33
impl DivergenceInfo {
@@ -39,8 +41,67 @@ impl DivergenceInfo {
39
41
end_idx_in_trajectory : None ,
40
42
start_idx_in_trajectory : None ,
41
43
logp_function_error : None ,
44
+ non_reversible : false ,
42
45
}
43
46
}
47
+
48
+ pub fn new_energy_error_too_large < M : Math > (
49
+ math : & mut M ,
50
+ start : & State < M , impl Point < M > > ,
51
+ stop : & State < M , impl Point < M > > ,
52
+ ) -> Self {
53
+ DivergenceInfo {
54
+ logp_function_error : None ,
55
+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
56
+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
57
+ // TODO
58
+ start_momentum : None ,
59
+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
60
+ end_location : Some ( math. box_array ( & stop. point ( ) . position ( ) ) ) ,
61
+ end_idx_in_trajectory : Some ( stop. index_in_trajectory ( ) ) ,
62
+ // TODO
63
+ energy_error : None ,
64
+ non_reversible : false ,
65
+ }
66
+ }
67
+
68
+ pub fn new_logp_function_error < M : Math > (
69
+ math : & mut M ,
70
+ start : & State < M , impl Point < M > > ,
71
+ logp_function_error : Arc < dyn std:: error:: Error + Send + Sync > ,
72
+ ) -> Self {
73
+ DivergenceInfo {
74
+ logp_function_error : Some ( logp_function_error) ,
75
+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
76
+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
77
+ // TODO
78
+ start_momentum : None ,
79
+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
80
+ end_location : None ,
81
+ end_idx_in_trajectory : None ,
82
+ energy_error : None ,
83
+ non_reversible : false ,
84
+ }
85
+ }
86
+
87
+ pub fn new_not_reversible < M : Math > ( math : & mut M , start : & State < M , impl Point < M > > ) -> Self {
88
+ // TODO add info about what went wrong
89
+ DivergenceInfo {
90
+ logp_function_error : None ,
91
+ start_location : Some ( math. box_array ( start. point ( ) . position ( ) ) ) ,
92
+ start_gradient : Some ( math. box_array ( start. point ( ) . gradient ( ) ) ) ,
93
+ // TODO
94
+ start_momentum : None ,
95
+ start_idx_in_trajectory : Some ( start. index_in_trajectory ( ) ) ,
96
+ end_location : None ,
97
+ end_idx_in_trajectory : None ,
98
+ energy_error : None ,
99
+ non_reversible : true ,
100
+ }
101
+ }
102
+ pub fn new_max_step_size_halvings < M : Math > ( math : & mut M , num_steps : u64 , info : Self ) -> Self {
103
+ info // TODO
104
+ }
44
105
}
45
106
46
107
#[ derive( Debug , Copy , Clone ) ]
@@ -106,10 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
106
167
math : & mut M ,
107
168
start : & State < M , Self :: Point > ,
108
169
dir : Direction ,
109
- step_size_factor : f64 ,
170
+ step_size_splits : u64 ,
110
171
collector : & mut C ,
111
172
) -> LeapfrogResult < M , Self :: Point > ;
112
173
174
+ fn split_leapfrog < C : Collector < M , Self :: Point > > (
175
+ & mut self ,
176
+ math : & mut M ,
177
+ start : & State < M , Self :: Point > ,
178
+ dir : Direction ,
179
+ num_steps : u64 ,
180
+ collector : & mut C ,
181
+ max_error : f64 ,
182
+ ) -> LeapfrogResult < M , Self :: Point > {
183
+ let mut state = start. clone ( ) ;
184
+
185
+ let mut min_energy = start. energy ( ) ;
186
+ let mut max_energy = min_energy;
187
+
188
+ for _ in 0 ..num_steps {
189
+ state = match self . leapfrog ( math, & state, dir, num_steps, collector) {
190
+ LeapfrogResult :: Ok ( state) => state,
191
+ LeapfrogResult :: Divergence ( info) => return LeapfrogResult :: Divergence ( info) ,
192
+ LeapfrogResult :: Err ( err) => return LeapfrogResult :: Err ( err) ,
193
+ } ;
194
+ let energy = state. energy ( ) ;
195
+ min_energy = min_energy. min ( energy) ;
196
+ max_energy = max_energy. max ( energy) ;
197
+
198
+ // TODO: walnuts papers says to use abs, but c++ code doesn't?
199
+ if max_energy - min_energy > max_error {
200
+ let info = DivergenceInfo :: new_energy_error_too_large ( math, start, & state) ;
201
+ return LeapfrogResult :: Divergence ( info) ;
202
+ }
203
+ }
204
+
205
+ LeapfrogResult :: Ok ( state)
206
+ }
207
+
113
208
fn is_turning (
114
209
& self ,
115
210
math : & mut M ,
@@ -141,4 +236,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
141
236
142
237
fn step_size ( & self ) -> f64 ;
143
238
fn step_size_mut ( & mut self ) -> & mut f64 ;
239
+
240
+ fn max_energy_error ( & self ) -> f64 ;
144
241
}
0 commit comments