@@ -120,7 +120,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120120 H : Hamiltonian < M > ,
121121 R : rand:: Rng + ?Sized ,
122122 {
123- let mut other = match self . single_step ( math, hamiltonian, direction, collector) {
123+ let mut other = match self . single_step ( math, hamiltonian, direction, options , collector) {
124124 Ok ( Ok ( tree) ) => tree,
125125 Ok ( Err ( info) ) => return ExtendResult :: Diverging ( self , info) ,
126126 Err ( err) => return ExtendResult :: Err ( err) ,
@@ -213,19 +213,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213213 math : & mut M ,
214214 hamiltonian : & mut H ,
215215 direction : Direction ,
216+ options : & NutsOptions ,
216217 collector : & mut C ,
217218 ) -> Result < std:: result:: Result < NutsTree < M , H , C > , DivergenceInfo > > {
218219 let start = match direction {
219220 Direction :: Forward => & self . right ,
220221 Direction :: Backward => & self . left ,
221222 } ;
222- let end = match hamiltonian. leapfrog ( math, start, direction, collector) {
223- LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
224- LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
225- LeapfrogResult :: Ok ( end) => end,
223+
224+ let ( log_size, end) = match options. walnuts_options {
225+ Some ( ref options) => {
226+ // Walnuts implementation
227+ // TODO: Shouldn't all be in this one big function...
228+ let mut step_size_factor = 1.0 ;
229+ let mut num_steps = 1 ;
230+ let mut current = start. clone ( ) ;
231+
232+ let mut success = false ;
233+
234+ ' step_size_search: for _ in 0 ..options. max_step_size_halvings {
235+ current = start. clone ( ) ;
236+ let mut min_energy = current. energy ( ) ;
237+ let mut max_energy = min_energy;
238+
239+ for _ in 0 ..num_steps {
240+ current = match hamiltonian. leapfrog (
241+ math,
242+ & current,
243+ direction,
244+ step_size_factor,
245+ collector,
246+ ) {
247+ LeapfrogResult :: Ok ( state) => state,
248+ LeapfrogResult :: Divergence ( _) => {
249+ num_steps *= 2 ;
250+ step_size_factor *= 0.5 ;
251+ continue ' step_size_search;
252+ }
253+ LeapfrogResult :: Err ( err) => {
254+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
255+ }
256+ } ;
257+
258+ // Update min/max energies
259+ let current_energy = current. energy ( ) ;
260+ min_energy = min_energy. min ( current_energy) ;
261+ max_energy = max_energy. max ( current_energy) ;
262+ }
263+
264+ if max_energy - min_energy > options. max_energy_error {
265+ num_steps *= 2 ;
266+ step_size_factor *= 0.5 ;
267+ continue ' step_size_search;
268+ }
269+
270+ success = true ;
271+ break ' step_size_search;
272+ }
273+
274+ if !success {
275+ // TODO: More info
276+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
277+ }
278+
279+ // TODO
280+ let back = direction. reverse ( ) ;
281+ let mut current_backward;
282+
283+ let mut reversible = true ;
284+
285+ ' rev_step_size: while num_steps >= 2 {
286+ num_steps /= 2 ;
287+ step_size_factor *= 0.5 ;
288+
289+ // TODO: Can we share code for the micro steps in the two directions?
290+ current_backward = current. clone ( ) ;
291+
292+ let mut min_energy = current_backward. energy ( ) ;
293+ let mut max_energy = min_energy;
294+
295+ for _ in 0 ..num_steps {
296+ current_backward = match hamiltonian. leapfrog (
297+ math,
298+ & current_backward,
299+ back,
300+ step_size_factor,
301+ collector,
302+ ) {
303+ LeapfrogResult :: Ok ( state) => state,
304+ LeapfrogResult :: Divergence ( _) => {
305+ // We also reject in the backward direction, all is good so far...
306+ continue ' rev_step_size;
307+ }
308+ LeapfrogResult :: Err ( err) => {
309+ return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ;
310+ }
311+ } ;
312+
313+ // Update min/max energies
314+ let current_energy = current_backward. energy ( ) ;
315+ min_energy = min_energy. min ( current_energy) ;
316+ max_energy = max_energy. max ( current_energy) ;
317+ if max_energy - min_energy > options. max_energy_error {
318+ // We reject also in the backward direction, all good so far...
319+ continue ' rev_step_size;
320+ }
321+ }
322+
323+ // We did not reject in the backward direction, so we are not reversible
324+ reversible = false ;
325+ break ;
326+ }
327+
328+ if reversible {
329+ let log_size = -current. point ( ) . energy_error ( ) ;
330+ ( log_size, current)
331+ } else {
332+ // TODO: More info
333+ return Ok ( Err ( DivergenceInfo :: new ( ) ) ) ;
334+ }
335+ }
336+ None => {
337+ // Classical NUTS
338+ //
339+ let end = match hamiltonian. leapfrog ( math, start, direction, 1.0 , collector) {
340+ LeapfrogResult :: Divergence ( info) => return Ok ( Err ( info) ) ,
341+ LeapfrogResult :: Err ( err) => return Err ( NutsError :: LogpFailure ( err. into ( ) ) ) ,
342+ LeapfrogResult :: Ok ( end) => end,
343+ } ;
344+
345+ let log_size = -end. point ( ) . energy_error ( ) ;
346+
347+ ( log_size, end)
348+ }
226349 } ;
227350
228- let log_size = -end. point ( ) . energy_error ( ) ;
229351 Ok ( Ok ( NutsTree {
230352 right : end. clone ( ) ,
231353 left : end. clone ( ) ,
@@ -248,12 +370,21 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248370 }
249371}
250372
373+ #[ derive( Debug , Clone , Copy ) ]
374+ pub struct WalnutsOptions {
375+ pub max_energy_error : f64 ,
376+ pub max_step_size_halvings : u64 ,
377+ }
378+
379+ #[ derive( Debug , Clone , Copy ) ]
251380pub struct NutsOptions {
252381 pub maxdepth : u64 ,
253382 pub store_gradient : bool ,
254383 pub store_unconstrained : bool ,
255384 pub check_turning : bool ,
256385 pub store_divergences : bool ,
386+
387+ pub walnuts_options : Option < WalnutsOptions > ,
257388}
258389
259390pub ( crate ) fn draw < M , H , R , C > (
0 commit comments