2525import org .neo4j .procedure .UserFunction ;
2626import org .neo4j .values .storable .Values ;
2727
28+ import java .util .ArrayList ;
2829import java .util .Comparator ;
2930import java .util .HashSet ;
3031import java .util .List ;
@@ -152,20 +153,18 @@ private int validateLength(List<Number> vector1, List<Number> vector2) {
152153 * @return The jaccard score, the intersection divided by the union of the input lists.
153154 */
154155 private double jaccard (List <Number > vector1 , List <Number > vector2 ) {
155- vector1 .removeIf (IS_NULL );
156- vector2 .removeIf (IS_NULL );
157- vector1 .sort (NUMBER_COMPARATOR );
158- vector2 .sort (NUMBER_COMPARATOR );
156+ var sortedVector1 = removeNullsAndSort (vector1 );
157+ var sortedVector2 = removeNullsAndSort (vector2 );
159158
160159 int index1 = 0 ;
161160 int index2 = 0 ;
162161
163162 int intersection = 0 ;
164163 double union = 0 ;
165164
166- while (index1 < vector1 .size () && index2 < vector2 .size ()) {
167- Number val1 = vector1 .get (index1 );
168- Number val2 = vector2 .get (index2 );
165+ while (index1 < sortedVector1 .size () && index2 < sortedVector2 .size ()) {
166+ Number val1 = sortedVector1 .get (index1 );
167+ Number val2 = sortedVector2 .get (index2 );
169168 int compare = NUMBER_COMPARATOR .compare (val1 , val2 );
170169
171170 if (compare == 0 ) {
@@ -183,11 +182,23 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
183182 }
184183
185184 // the remainder, if any, is never shared so add to the union
186- union += (vector1 .size () - index1 ) + (vector2 .size () - index2 );
185+ union += (sortedVector1 .size () - index1 ) + (sortedVector2 .size () - index2 );
187186
188187 return union == 0 ? 1 : intersection / union ;
189188 }
190189
190+ private static List <Number > removeNulls (List <Number > input ) {
191+ var output = new ArrayList <>(input );
192+ output .removeIf (IS_NULL );
193+ return output ;
194+ }
195+
196+ private static List <Number > removeNullsAndSort (List <Number > input ) {
197+ var output = removeNulls (input );
198+ output .sort (NUMBER_COMPARATOR );
199+ return output ;
200+ }
201+
191202 private static double getDoubleValue (Number value ) {
192203 return Optional .ofNullable (value ).map (Number ::doubleValue ).orElse (0D );
193204 }
0 commit comments