Skip to content

Commit ef3f0ec

Browse files
committed
Add tests to demonstrate the problem
1 parent fb6d4e0 commit ef3f0ec

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity;
21+
22+
import org.junit.jupiter.params.ParameterizedTest;
23+
import org.junit.jupiter.params.provider.Arguments;
24+
import org.junit.jupiter.params.provider.MethodSource;
25+
26+
import java.util.Arrays;
27+
import java.util.List;
28+
import java.util.stream.Collector;
29+
import java.util.stream.Collectors;
30+
import java.util.stream.Stream;
31+
32+
import static org.assertj.core.api.Assertions.assertThat;
33+
34+
class JaccardSimilarityTest {
35+
36+
@ParameterizedTest(name = "{2}")
37+
@MethodSource("listCollectors")
38+
void shouldPassAtAllCasesOfListInput(
39+
Collector<Number, ?, List<Number>> firstListCollector,
40+
Collector<Number, ?, List<Number>> secondListCollector,
41+
String label
42+
) {
43+
var arr1 = new int[]{1,2,3};
44+
var arr2 = new int[]{1,2,3};
45+
List<Number> l1 = Arrays.stream(arr1).boxed().collect(firstListCollector);
46+
List<Number> l2 = Arrays.stream(arr2).boxed().collect(secondListCollector);
47+
48+
var similarities = new SimilaritiesFunc();
49+
var jaccarded = similarities.jaccardSimilarity(l1, l2);
50+
assertThat(jaccarded).isEqualTo(1);
51+
}
52+
53+
private static Stream<Arguments> listCollectors() {
54+
return Stream.of(
55+
Arguments.of(
56+
Collectors.toUnmodifiableList(), Collectors.toUnmodifiableList(), "Unmodifiable, Unmodifiable"
57+
),
58+
Arguments.of(
59+
Collectors.toUnmodifiableList(), Collectors.toList(), "Unmodifiable, Modifiable"
60+
),
61+
Arguments.of(
62+
Collectors.toList(), Collectors.toList(), "Modifiable, Modifiable"
63+
),
64+
Arguments.of(
65+
Collectors.toList(), Collectors.toUnmodifiableList(), "Modifiable, Unmodifiable"
66+
)
67+
);
68+
}
69+
70+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseTest;
25+
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.Assertions.assertThatNoException;
29+
import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE;
30+
import static org.junit.jupiter.api.Assertions.assertEquals;
31+
32+
class JaccardWithCypherTest extends BaseTest {
33+
34+
@BeforeEach
35+
void setUp() throws Exception {
36+
GraphDatabaseApiProxy.registerFunctions(db, SimilaritiesFunc.class);
37+
}
38+
39+
@Test
40+
void testJaccardFunctionWithInputFromDatabase() {
41+
assertThatNoException().isThrownBy(
42+
() -> runQueryWithResultConsumer(
43+
"CREATE (t:Test {listone: [1, 5], listtwo: [5, 5]}) RETURN gds.similarity.jaccard(t.listone, t.listtwo) AS score",
44+
result -> {
45+
assertThat(result.hasNext()).isTrue();
46+
var score = result.next().get("score");
47+
assertThat(score)
48+
.asInstanceOf(DOUBLE)
49+
.isEqualTo(1.0 / 3.0);
50+
}
51+
)
52+
);
53+
}
54+
55+
@Test
56+
void testJaccardFunction() {
57+
assertThatNoException().isThrownBy(
58+
() ->
59+
runQueryWithResultConsumer(
60+
"RETURN gds.similarity.jaccard([1, 5], [5, 5]) AS score",
61+
result -> assertEquals(1.0 / 3.0, result.next().get("score"))
62+
)
63+
);
64+
}
65+
}

0 commit comments

Comments
 (0)