Skip to content

Commit a0f08dc

Browse files
committed
Fix Jaccard similarity function when input is unmodifiable
1 parent ef3f0ec commit a0f08dc

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

algo/src/main/java/org/neo4j/gds/similarity/SimilaritiesFunc.java

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.procedure.UserFunction;
2626
import org.neo4j.values.storable.Values;
2727

28+
import java.util.ArrayList;
2829
import java.util.Comparator;
2930
import java.util.HashSet;
3031
import java.util.List;
@@ -152,20 +153,18 @@ private int validateLength(List<Number> vector1, List<Number> vector2) {
152153
* @return The jaccard score, the intersection divided by the union of the input lists.
153154
*/
154155
private double jaccard(List<Number> vector1, List<Number> vector2) {
155-
vector1.removeIf(IS_NULL);
156-
vector2.removeIf(IS_NULL);
157-
vector1.sort(NUMBER_COMPARATOR);
158-
vector2.sort(NUMBER_COMPARATOR);
156+
var sortedVector1 = removeNullsAndSort(vector1);
157+
var sortedVector2 = removeNullsAndSort(vector2);
159158

160159
int index1 = 0;
161160
int index2 = 0;
162161

163162
int intersection = 0;
164163
double union = 0;
165164

166-
while (index1 < vector1.size() && index2 < vector2.size()) {
167-
Number val1 = vector1.get(index1);
168-
Number val2 = vector2.get(index2);
165+
while (index1 < sortedVector1.size() && index2 < sortedVector2.size()) {
166+
Number val1 = sortedVector1.get(index1);
167+
Number val2 = sortedVector2.get(index2);
169168
int compare = NUMBER_COMPARATOR.compare(val1, val2);
170169

171170
if (compare == 0) {
@@ -183,11 +182,23 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
183182
}
184183

185184
// the remainder, if any, is never shared so add to the union
186-
union += (vector1.size() - index1) + (vector2.size() - index2);
185+
union += (sortedVector1.size() - index1) + (sortedVector2.size() - index2);
187186

188187
return union == 0 ? 1 : intersection / union;
189188
}
190189

190+
private static List<Number> removeNulls(List<Number> input) {
191+
var output = new ArrayList<>(input);
192+
output.removeIf(IS_NULL);
193+
return output;
194+
}
195+
196+
private static List<Number> removeNullsAndSort(List<Number> input) {
197+
var output = removeNulls(input);
198+
output.sort(NUMBER_COMPARATOR);
199+
return output;
200+
}
201+
191202
private static double getDoubleValue(Number value) {
192203
return Optional.ofNullable(value).map(Number::doubleValue).orElse(0D);
193204
}

0 commit comments

Comments
 (0)