@@ -59,22 +59,41 @@ pub struct SampleInfo {
5959}
6060
6161/// A part of the trajectory tree during NUTS sampling.
62+ ///
63+ /// Corresponds to SpanW in walnuts C++ code
6264struct NutsTree < M : Math , H : Hamiltonian < M > , C : Collector < M , H :: Point > > {
6365 /// The left position of the tree.
6466 ///
6567 /// The left side always has the smaller index_in_trajectory.
6668 /// Leapfrogs in backward direction will replace the left.
69+ ///
70+ /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code
6771 left : State < M , H :: Point > ,
72+
73+ /// The right position of the tree.
74+ ///
75+ /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code
6876 right : State < M , H :: Point > ,
6977
7078 /// A draw from the trajectory between left and right using
7179 /// multinomial sampling.
80+ ///
81+ /// theta_select_ in C++ code
7282 draw : State < M , H :: Point > ,
83+
84+ /// Constant for acceptance probability
85+ ///
86+ /// logp_ in C++ code
7387 log_size : f64 ,
88+
89+ /// The depth of the tree
7490 depth : u64 ,
7591
7692 /// A tree is the main tree if it contains the initial point
7793 /// of the trajectory.
94+ ///
95+ /// This is used to determine whether to use Metropolis
96+ /// accptance or Barker
7897 is_main : bool ,
7998 _phantom2 : PhantomData < C > ,
8099}
@@ -171,6 +190,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
171190 }
172191 }
173192
193+ // `combine` in C++ code
174194 fn merge_into < R : rand:: Rng + ?Sized > (
175195 & mut self ,
176196 _math : & mut M ,
@@ -208,6 +228,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
208228 self . log_size = log_size;
209229 }
210230
231+ // Corresponds to `build_leaf` in C++ code
211232 fn single_step (
212233 & self ,
213234 math : & mut M ,
0 commit comments