Skip to content

Commit 3ebfaff

Browse files
Fix incorrect workload calculation for filtered node similarity
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent 0756066 commit 3ebfaff

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,14 @@ private LongStream checkProgress(LongStream stream) {
437437
}
438438

439439
private long calculateWorkload() {
440-
long workload = nodesToCompare * nodesToCompare;
441-
if (concurrency == 1) {
440+
//for each source node, examine all their target nodes
441+
//if no filter then sourceNodes == targetNodes
442+
long workload = sourceNodes.cardinality() * targetNodes.cardinality();
443+
444+
// With a concurrency of 1 and no node filtering, each node is compared only with higher-indexed nodes,
445+
// so the work is halved. This does not hold for filtered similarity, since target nodes may have lower indices than the source node.
446+
boolean isNotFiltered = sourceNodes.equals(NodeFilter.noOp) && targetNodeFilter.equals(NodeFilter.noOp);
447+
if (concurrency == 1 && isNotFiltered) {
442448
workload = workload / 2;
443449
}
444450
return workload;

algo/src/test/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityTest.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
package org.neo4j.gds.similarity.filterednodesim;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.params.ParameterizedTest;
24+
import org.junit.jupiter.params.provider.ValueSource;
25+
import org.neo4j.gds.TestProgressTracker;
26+
import org.neo4j.gds.compat.Neo4jProxy;
27+
import org.neo4j.gds.compat.TestLog;
28+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
2329
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2430
import org.neo4j.gds.extension.GdlExtension;
2531
import org.neo4j.gds.extension.GdlGraph;
@@ -30,6 +36,8 @@
3036
import java.util.List;
3137

3238
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
40+
import static org.neo4j.gds.compat.TestLog.INFO;
3341

3442
@GdlExtension
3543
class FilteredNodeSimilarityTest {
@@ -154,4 +162,55 @@ void shouldSurviveIoannisFurtherObjections() {
154162

155163
nodeSimilarity.release();
156164
}
165+
166+
@ParameterizedTest
167+
@ValueSource(ints = {1, 2})
168+
void shouldLogProgressAccurately(int concurrency) {
169+
var sourceNodeFilter = List.of(2L, 3L);
170+
171+
var config = ImmutableFilteredNodeSimilarityStreamConfig.builder()
172+
.sourceNodeFilter(NodeFilterSpecFactory.create(sourceNodeFilter))
173+
.concurrency(concurrency)
174+
.topK(0)
175+
.topN(10)
176+
.build();
177+
var progressTask = new FilteredNodeSimilarityFactory<>().progressTask(graph, config);
178+
TestLog log = Neo4jProxy.testLog();
179+
var progressTracker = new TestProgressTracker(
180+
progressTask,
181+
log,
182+
concurrency,
183+
EmptyTaskRegistryFactory.INSTANCE
184+
);
185+
186+
187+
new FilteredNodeSimilarityFactory<>().build(
188+
graph,
189+
config,
190+
progressTracker
191+
).compute();
192+
193+
194+
assertThat(log.getMessages(INFO))
195+
.extracting(removingThreadId())
196+
.containsExactly(
197+
"FilteredNodeSimilarity :: Start",
198+
"FilteredNodeSimilarity :: prepare :: Start",
199+
"FilteredNodeSimilarity :: prepare 33%",
200+
"FilteredNodeSimilarity :: prepare 55%",
201+
"FilteredNodeSimilarity :: prepare 66%",
202+
"FilteredNodeSimilarity :: prepare 100%",
203+
"FilteredNodeSimilarity :: prepare :: Finished",
204+
"FilteredNodeSimilarity :: compare node pairs :: Start",
205+
"FilteredNodeSimilarity :: compare node pairs 12%",
206+
"FilteredNodeSimilarity :: compare node pairs 25%",
207+
"FilteredNodeSimilarity :: compare node pairs 37%",
208+
"FilteredNodeSimilarity :: compare node pairs 50%",
209+
"FilteredNodeSimilarity :: compare node pairs 62%",
210+
"FilteredNodeSimilarity :: compare node pairs 75%",
211+
"FilteredNodeSimilarity :: compare node pairs 100%",
212+
"FilteredNodeSimilarity :: compare node pairs :: Finished",
213+
"FilteredNodeSimilarity :: Finished"
214+
);
215+
}
157216
}

0 commit comments

Comments
 (0)