Skip to content

Commit d54c320

Browse files
chris1011011IoannisPanagiotas
authored andcommitted
track progress based on sampling size
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent 6cb8870 commit d54c320

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,6 @@ private static MemoryEstimations.Builder bcTaskMemoryEstimationBuilder(boolean w
121121

122122
@Override
123123
public Task progressTask(Graph graph, CONFIG config) {
124-
return Tasks.leaf(taskName(), graph.nodeCount());
124+
return Tasks.leaf(taskName(), config.samplingSize().orElse(graph.nodeCount()));
125125
}
126126
}

algo/src/test/java/org/neo4j/gds/betweenness/BetweennessCentralityTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,34 @@
1919
*/
2020
package org.neo4j.gds.betweenness;
2121

22+
import org.junit.jupiter.api.Test;
2223
import org.junit.jupiter.params.ParameterizedTest;
2324
import org.junit.jupiter.params.provider.Arguments;
2425
import org.junit.jupiter.params.provider.MethodSource;
2526
import org.junit.jupiter.params.provider.ValueSource;
27+
import org.neo4j.gds.TestProgressTracker;
28+
import org.neo4j.gds.compat.Neo4jProxy;
29+
import org.neo4j.gds.compat.TestLog;
2630
import org.neo4j.gds.core.CypherMapWrapper;
2731
import org.neo4j.gds.core.concurrency.Pools;
2832
import org.neo4j.gds.core.utils.mem.MemoryRange;
2933
import org.neo4j.gds.core.utils.paged.HugeAtomicDoubleArray;
34+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
3035
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3136
import org.neo4j.gds.extension.TestGraph;
3237

3338
import java.util.Map;
3439
import java.util.Optional;
3540
import java.util.stream.Stream;
3641

42+
import static org.assertj.core.api.Assertions.assertThat;
3743
import static org.junit.jupiter.api.Assertions.assertEquals;
3844
import static org.neo4j.gds.Orientation.UNDIRECTED;
3945
import static org.neo4j.gds.TestSupport.assertMemoryEstimation;
4046
import static org.neo4j.gds.TestSupport.crossArguments;
4147
import static org.neo4j.gds.TestSupport.fromGdl;
48+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
49+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
4250

4351
class BetweennessCentralityTest {
4452

@@ -182,4 +190,59 @@ void testMemoryEstimation(int concurrency, long expectedMinBytes, long expectedM
182190
MemoryRange.of(expectedMinBytes, expectedMaxBytes)
183191
);
184192
}
193+
194+
@Test
195+
void testShouldLogProgress() {
196+
var config = BetweennessCentralityStreamConfigImpl.builder().samplingSize(2L).build();
197+
var factory = new BetweennessCentralityFactory<>();
198+
var log = Neo4jProxy.testLog();
199+
var testGraph = fromGdl(DIAMOND, "diamond");
200+
var progressTracker = new TestProgressTracker(
201+
factory.progressTask(testGraph, config),
202+
log,
203+
4,
204+
EmptyTaskRegistryFactory.INSTANCE
205+
);
206+
factory.build(testGraph, config, progressTracker).compute();
207+
208+
assertThat(log.getMessages(TestLog.INFO))
209+
.extracting(removingThreadId())
210+
.extracting(replaceTimings())
211+
.containsExactly(
212+
"BetweennessCentrality :: Start",
213+
"BetweennessCentrality 50%",
214+
"BetweennessCentrality 100%",
215+
"BetweennessCentrality :: Finished"
216+
);
217+
}
218+
219+
@Test
220+
void testShouldLogProgressNoSampling() {
221+
var config = BetweennessCentralityStreamConfigImpl.builder().build();
222+
var factory = new BetweennessCentralityFactory<>();
223+
var log = Neo4jProxy.testLog();
224+
var testGraph = fromGdl(DIAMOND, "diamond");
225+
var progressTracker = new TestProgressTracker(
226+
factory.progressTask(testGraph, config),
227+
log,
228+
4,
229+
EmptyTaskRegistryFactory.INSTANCE
230+
);
231+
factory.build(testGraph, config, progressTracker).compute();
232+
233+
assertThat(log.getMessages(TestLog.INFO))
234+
.extracting(removingThreadId())
235+
.extracting(replaceTimings())
236+
.containsExactly(
237+
"BetweennessCentrality :: Start",
238+
"BetweennessCentrality 14%",
239+
"BetweennessCentrality 28%",
240+
"BetweennessCentrality 42%",
241+
"BetweennessCentrality 57%",
242+
"BetweennessCentrality 71%",
243+
"BetweennessCentrality 85%",
244+
"BetweennessCentrality 100%",
245+
"BetweennessCentrality :: Finished"
246+
);
247+
}
185248
}

0 commit comments

Comments
 (0)