Skip to content

Commit a1eaf3e

Browse files
authored
Handle tiny graphs for sampling (#9991)
1 parent def904a commit a1eaf3e

File tree

5 files changed, with 87 additions and 3 deletions.

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/SeenNodes.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ static SeenNodes create(
5858

5959
return new SeenNodes.GlobalSeenNodes(
6060
HugeAtomicBitSet.create(inputGraph.nodeCount()),
61-
Math.round(inputGraph.nodeCount() * samplingRatio)
61+
Math.max(1, Math.round(inputGraph.nodeCount() * samplingRatio))
6262
);
6363
}
6464

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/rw/cnarw/CommonNeighbourAwareRandomWalk.java

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,12 @@ public HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracke
7777
config.samplingRatio()
7878
);
7979

80+
if (seenNodes.totalExpectedNodes() == 0) {
81+
progressTracker.endSubTask("Sample nodes");
82+
return seenNodes.sampledNodes();
83+
}
84+
85+
8086
progressTracker.beginSubTask("Do common neighbour aware random walks");
8187
progressTracker.setSteps(seenNodes.totalExpectedNodes());
8288

graph-sampling/src/main/java/org/neo4j/gds/graphsampling/samplers/rw/rwr/RandomWalkWithRestarts.java

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,11 @@ public HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracke
7575
config.samplingRatio()
7676
);
7777

78+
if (seenNodes.totalExpectedNodes() == 0) {
79+
progressTracker.endSubTask("Sample nodes");
80+
return seenNodes.sampledNodes();
81+
}
82+
7883
progressTracker.beginSubTask("Do random walks");
7984
progressTracker.setSteps(seenNodes.totalExpectedNodes());
8085

@@ -116,16 +121,17 @@ public HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracke
116121

117122
@Override
118123
public Task progressTask(GraphStore graphStore) {
124+
long sampledNodes = 10 * Math.round(graphStore.nodeCount() * config.samplingRatio());
119125
if (config.nodeLabelStratification()) {
120126
return Tasks.task(
121127
"Sample nodes",
122128
Tasks.leaf("Count node labels", graphStore.nodeCount()),
123-
Tasks.leaf("Do random walks", 10 * Math.round(graphStore.nodeCount() * config.samplingRatio()))
129+
Tasks.leaf("Do random walks", sampledNodes)
124130
);
125131
} else {
126132
return Tasks.task(
127133
"Sample nodes",
128-
Tasks.leaf("Do random walks", 10 * Math.round(graphStore.nodeCount() * config.samplingRatio()))
134+
Tasks.leaf("Do random walks", sampledNodes)
129135
);
130136
}
131137
}

graph-sampling/src/test/java/org/neo4j/gds/graphsampling/samplers/rw/cnarw/CommonNeighbourAwareRandomWalkTest.java

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,20 +22,27 @@
2222
import org.assertj.core.data.Offset;
2323
import org.assertj.core.data.Percentage;
2424
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.params.ParameterizedTest;
26+
import org.junit.jupiter.params.provider.ValueSource;
2527
import org.neo4j.gds.NodeLabel;
28+
import org.neo4j.gds.TestProgressTracker;
29+
import org.neo4j.gds.TestTaskStore;
2630
import org.neo4j.gds.api.Graph;
2731
import org.neo4j.gds.api.GraphStore;
2832
import org.neo4j.gds.core.GraphDimensions;
2933
import org.neo4j.gds.core.concurrency.Concurrency;
3034
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
35+
import org.neo4j.gds.core.utils.progress.LocalTaskRegistryFactory;
3136
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
37+
import org.neo4j.gds.core.utils.progress.tasks.Task;
3238
import org.neo4j.gds.extension.GdlExtension;
3339
import org.neo4j.gds.extension.GdlGraph;
3440
import org.neo4j.gds.extension.IdFunction;
3541
import org.neo4j.gds.extension.Inject;
3642
import org.neo4j.gds.extension.TestGraph;
3743
import org.neo4j.gds.graphsampling.config.CommonNeighbourAwareRandomWalkConfig;
3844
import org.neo4j.gds.graphsampling.config.CommonNeighbourAwareRandomWalkConfigImpl;
45+
import org.neo4j.gds.logging.GdsTestLog;
3946
import org.neo4j.gds.mem.MemoryRange;
4047
import org.neo4j.gds.termination.TerminatedException;
4148
import org.neo4j.gds.termination.TerminationFlag;
@@ -172,6 +179,15 @@ class CommonNeighbourAwareRandomWalkTest {
172179
@Inject
173180
private TestGraph naturalUnionGraph;
174181

182+
@GdlGraph(graphNamePrefix = "tiny")
183+
private static String TINY_GRAPH = "()-->()";
184+
185+
@Inject
186+
private Graph tinyGraph;
187+
188+
@Inject
189+
private GraphStore tinyGraphStore;
190+
175191

176192
private Graph getGraph(CommonNeighbourAwareRandomWalkConfig config) {
177193
return graphStore.getGraph(
@@ -670,4 +686,23 @@ void checkTerminationFlag() {
670686
.isInstanceOf(TerminatedException.class)
671687
.hasMessageContaining("The execution has been terminated.");
672688
}
689+
690+
@ParameterizedTest
691+
@ValueSource(booleans = {true, false})
692+
void progressTrackingTinyGraph(boolean nodeLabelStratification) {
693+
var config = CommonNeighbourAwareRandomWalkConfigImpl.builder()
694+
.nodeLabelStratification(nodeLabelStratification)
695+
.build();
696+
697+
var cnar = new CommonNeighbourAwareRandomWalk(config);
698+
Task task = cnar.progressTask(tinyGraphStore);
699+
700+
TestTaskStore taskStore = new TestTaskStore();
701+
var taskRegistryFactory = new LocalTaskRegistryFactory("user", taskStore);
702+
var tracker = new TestProgressTracker(task, new GdsTestLog(), new Concurrency(4), taskRegistryFactory);
703+
704+
cnar.compute(tinyGraph, tracker);
705+
706+
assertThat(taskStore.tasks()).isEmpty();
707+
}
673708
}

graph-sampling/src/test/java/org/neo4j/gds/graphsampling/samplers/rw/rwr/RandomWalkWithRestartsTest.java

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,16 +21,24 @@
2121

2222
import org.assertj.core.data.Offset;
2323
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.ValueSource;
2426
import org.neo4j.gds.NodeLabel;
27+
import org.neo4j.gds.TestProgressTracker;
28+
import org.neo4j.gds.TestTaskStore;
2529
import org.neo4j.gds.api.Graph;
2630
import org.neo4j.gds.api.GraphStore;
31+
import org.neo4j.gds.core.concurrency.Concurrency;
32+
import org.neo4j.gds.core.utils.progress.LocalTaskRegistryFactory;
2733
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.core.utils.progress.tasks.Task;
2835
import org.neo4j.gds.extension.GdlExtension;
2936
import org.neo4j.gds.extension.GdlGraph;
3037
import org.neo4j.gds.extension.IdFunction;
3138
import org.neo4j.gds.extension.Inject;
3239
import org.neo4j.gds.graphsampling.config.RandomWalkWithRestartsConfig;
3340
import org.neo4j.gds.graphsampling.config.RandomWalkWithRestartsConfigImpl;
41+
import org.neo4j.gds.logging.GdsTestLog;
3442
import org.neo4j.gds.termination.TerminatedException;
3543
import org.neo4j.gds.termination.TerminationFlag;
3644

@@ -79,6 +87,15 @@ class RandomWalkWithRestartsTest {
7987
@Inject
8088
private GraphStore graphStore;
8189

90+
@GdlGraph(graphNamePrefix = "tiny")
91+
private static String TINY_GRAPH = "()-->()";
92+
93+
@Inject
94+
private Graph tinyGraph;
95+
96+
@Inject
97+
private GraphStore tinyGraphStore;
98+
8299
Graph getGraph(RandomWalkWithRestartsConfig config) {
83100
return graphStore.getGraph(
84101
config.nodeLabelIdentifiers(graphStore),
@@ -400,4 +417,24 @@ void checkTerminationFlag() {
400417
.isInstanceOf(TerminatedException.class)
401418
.hasMessageContaining("The execution has been terminated.");
402419
}
420+
421+
@ParameterizedTest
422+
@ValueSource(booleans = {true, false})
423+
void progressTrackingTinyGraph(boolean nodeLabelStratification) {
424+
var config = RandomWalkWithRestartsConfigImpl.builder()
425+
.nodeLabelStratification(nodeLabelStratification)
426+
.build();
427+
428+
var rwr = new RandomWalkWithRestarts(config);
429+
Task task = rwr.progressTask(tinyGraphStore);
430+
431+
TestTaskStore taskStore = new TestTaskStore();
432+
var taskRegistryFactory = new LocalTaskRegistryFactory("user", taskStore);
433+
var tracker = new TestProgressTracker(task, new GdsTestLog(), new Concurrency(4), taskRegistryFactory);
434+
435+
rwr.compute(tinyGraph, tracker);
436+
437+
assertThat(taskStore.tasks()).isEmpty();
438+
}
439+
403440
}

0 commit comments

Comments
 (0)