Skip to content

Commit 1d5b748

Browse files
Progress tracking and zero-comparisons
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 97af180 commit 1d5b748

File tree

7 files changed

+125
-27
lines changed

7 files changed

+125
-27
lines changed

algo/src/main/java/org/neo4j/gds/maxflow/DischargeTask.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public class DischargeTask implements Runnable {
4949
private PHASE phase;
5050
private long localWork;
5151

52+
double TOLERANCE = 1e-9;
53+
5254
public DischargeTask(
5355
FlowGraph flowGraph,
5456
HugeDoubleArray excess,
@@ -110,19 +112,17 @@ void discharge(long v) {
110112
tempLabel.set(v, label.get(v));
111113
final var e = new MutableDouble(excess.get(v));
112114

113-
while (e.doubleValue() > 0) {
115+
while (e.doubleValue() > TOLERANCE) {
114116
final var newLabel = new MutableLong(nodeCount);
115117
final var breakOuter = new MutableBoolean(false);
116118

117-
//todo: Check to improve zero comparisons!
118-
119119
ResidualEdgeConsumer consumer = (long s, long t, long relIdx, double residualCapacity, boolean isReverse) -> {
120-
if (residualCapacity <= 0.0) {
120+
if (residualCapacity < TOLERANCE) {
121121
return true; //skip
122122
}
123123
var admissible = (tempLabel.get(s) == label.get(t) + 1);
124124
if (admissible) {
125-
if (excess.get(t) > 0.0) {
125+
if (excess.get(t) > TOLERANCE) {
126126
boolean win = (label.get(s) == label.get(t) + 1 || label.get(s) + 1 < label.get(t) || (label.get(
127127
s) == label.get(t) && s < t));
128128
if (!win) {
@@ -138,7 +138,7 @@ void discharge(long v) {
138138
if (!isDiscovered.getAndSet(t)) {
139139
localDiscoveredVertices.add(t);
140140
}
141-
if (e.doubleValue() <= 0.0) {
141+
if (e.doubleValue() < TOLERANCE) {
142142
breakOuter.setTrue();
143143
return false;
144144
}
@@ -161,7 +161,7 @@ void discharge(long v) {
161161
}
162162
}
163163
addedExcess.getAndAdd(v, (e.doubleValue() - excess.get(v)));
164-
if (e.doubleValue() > 0.0 && !isDiscovered.getAndSet(v)) {
164+
if (e.doubleValue() > TOLERANCE && !isDiscovered.getAndSet(v)) {
165165
localDiscoveredVertices.add(v);
166166
}
167167
}

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ public MaxFlow(
5151
}
5252

5353
public FlowResult compute() {
54+
progressTracker.beginSubTask();
5455
var preflow = initPreflow();
5556
var superSource = preflow.flowGraph().superSource();
5657
var superTarget = preflow.flowGraph().superTarget();
5758
maximizeFlow(preflow, superSource, superTarget);
5859
maximizeFlow(preflow, superTarget, superSource);
60+
progressTracker.endSubTask();
5961
return preflow.flowGraph().createFlowResult();
6062
}
6163

@@ -88,7 +90,9 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
8890
ParallelDoublePageCreator.passThrough(parameters.concurrency())
8991
);
9092
var workingSet = new AtomicWorkingSet(nodeCount);
93+
var initialTotalExcess = 0D;
9194
for (var nodeId = 0; nodeId < nodeCount; nodeId++) {
95+
initialTotalExcess += excess.get(nodeId);
9296
if (excess.get(nodeId) > 0.0) {
9397
workingSet.push(nodeId);
9498
}
@@ -123,14 +127,19 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
123127
parameters.concurrency(),
124128
threadQueues,
125129
terminationFlag
126-
);
130+
);
127131

132+
var excessAtDestinations = excess.get(sourceNode) + excess.get(targetNode);
128133
while (!workingSet.isEmpty()) {
129134
if (parameters.freq() * workSinceLastGR.doubleValue() > parameters.alpha() * nodeCount + edgeCount) {
130135
globalRelabeling.globalRelabeling();
131136
workSinceLastGR.set(0L);
132137
}
133138
discharging.processWorkingSet();
139+
140+
var newExcessAtDestinations = excess.get(sourceNode) + excess.get(targetNode);
141+
progressTracker.logProgress((long) ( Math.ceil(newExcessAtDestinations * progressTracker.currentVolume() / initialTotalExcess) - Math.ceil(excessAtDestinations * progressTracker.currentVolume() / initialTotalExcess) ));
142+
excessAtDestinations = newExcessAtDestinations;
134143
}
135144
}
136145
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
23+
import org.neo4j.gds.core.utils.progress.tasks.Task;
24+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
25+
26+
public final class MaxFlowTask {
27+
28+
private MaxFlowTask() {}
29+
30+
public static Task create() {
31+
return Tasks.leaf(AlgorithmLabel.MaxFlow.asString(), 100);
32+
}
33+
}

algo/src/test/java/org/neo4j/gds/maxflow/MaxFlowMemoryEstimateDefinitionTest.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class MaxFlowMemoryEstimateDefinitionTest {
2929
@ParameterizedTest
3030
@CsvSource(
3131
{
32-
"1_000, 1_000, 110_288",
33-
"1_000, 10_000, 794_288",
34-
"1_000_000, 1_000_000, 109_252_248",
35-
"1_000_000, 10_000_000, 793_263_232"
32+
"1_000, 1_000, 110_312",
33+
"1_000, 10_000, 794_312",
34+
"1_000_000, 1_000_000, 109_252_272",
35+
"1_000_000, 10_000_000, 793_263_256"
3636
}
3737
)
3838
void shouldEstimateMemoryWithChangingGraphDimensionsCorrectly(long nodeCount, long relationshipCount, long expected){
@@ -47,10 +47,10 @@ void shouldEstimateMemoryWithChangingGraphDimensionsCorrectly(long nodeCount, lo
4747
@ParameterizedTest
4848
@CsvSource(
4949
{
50-
"1_000, 1, 110_288",
51-
"1_000, 4, 113_744",
52-
"100_000, 1, 10_926_160",
53-
"100_000, 4, 11_226_616"
50+
"1_000, 1, 110_312",
51+
"1_000, 4, 113_792",
52+
"100_000, 1, 10_926_184",
53+
"100_000, 4, 11_226_664"
5454
}
5555
)
5656
void shouldEstimateMemoryWithChangingConcurrencyCorrectly(long nodeAndRelCount, int concurrency, long expected){
@@ -65,10 +65,10 @@ void shouldEstimateMemoryWithChangingConcurrencyCorrectly(long nodeAndRelCount,
6565
@ParameterizedTest
6666
@CsvSource(
6767
{
68-
"1_000,1,1, 110_288",
69-
"1_000,10,1, 110_864",
70-
"1_000,1,10, 110_864",
71-
"1_000,10,10, 111_440"
68+
"1_000,1,1, 110_312",
69+
"1_000,10,1, 110_888",
70+
"1_000,1,10, 110_888",
71+
"1_000,10,10, 111_464"
7272

7373
}
7474
)

algo/src/test/java/org/neo4j/gds/maxflow/MaxFlowTest.java

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,29 @@
2626
import org.neo4j.gds.ListInputNodes;
2727
import org.neo4j.gds.MapInputNodes;
2828
import org.neo4j.gds.Orientation;
29-
import org.neo4j.gds.TestSupport;
29+
import org.neo4j.gds.TestProgressTracker;
3030
import org.neo4j.gds.api.Graph;
3131
import org.neo4j.gds.beta.generator.PropertyProducer;
3232
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
33+
import org.neo4j.gds.compat.TestLog;
3334
import org.neo4j.gds.core.concurrency.Concurrency;
35+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
36+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
37+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3438
import org.neo4j.gds.extension.GdlExtension;
3539
import org.neo4j.gds.extension.GdlGraph;
3640
import org.neo4j.gds.extension.Inject;
3741
import org.neo4j.gds.extension.TestGraph;
42+
import org.neo4j.gds.logging.GdsTestLog;
43+
import org.neo4j.gds.termination.TerminationFlag;
3844

3945
import java.util.List;
4046
import java.util.Map;
4147

4248
import static org.assertj.core.api.Assertions.assertThat;
49+
import static org.neo4j.gds.TestSupport.fromGdl;
50+
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
51+
import static org.neo4j.gds.assertj.Extractors.replaceTimings;
4352
import static org.neo4j.gds.beta.generator.RelationshipDistribution.UNIFORM;
4453
import static org.neo4j.gds.maxflow.MaxFlow.ALPHA;
4554
import static org.neo4j.gds.maxflow.MaxFlow.BETA;
@@ -82,7 +91,7 @@ Graph generateUniform(long nodeCount, int avgDegree) {
8291

8392
void testGraph(Graph graph, InputNodes sourceNodes, InputNodes targetNodes, double expectedFlow, int concurrency) {
8493
var params = new MaxFlowParameters(sourceNodes, targetNodes, new Concurrency(concurrency), ALPHA, BETA, FREQ);
85-
var x = new MaxFlow(graph, params, null, null);
94+
var x = new MaxFlow(graph, params, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE);
8695
var result = x.compute();
8796
assertThat(result.totalFlow()).isCloseTo(expectedFlow, Offset.offset(TOLERANCE));
8897
}
@@ -105,7 +114,7 @@ void test() {
105114

106115
@Test
107116
void test2a() {
108-
var graph = TestSupport.fromGdl(
117+
var graph = fromGdl(
109118
"""
110119
CREATE
111120
(a0)-[:R {capacity: 91}]->(a1),
@@ -123,7 +132,7 @@ void test2a() {
123132

124133
@Test
125134
void test2() {
126-
var graph = TestSupport.fromGdl(
135+
var graph = fromGdl(
127136
"""
128137
CREATE
129138
(a0)-[:R {capacity: 50}]->(a5),
@@ -160,7 +169,7 @@ void test2() {
160169

161170
@Test
162171
void test3() {
163-
var graph = TestSupport.fromGdl(
172+
var graph = fromGdl(
164173
"""
165174
CREATE
166175
(a0)-[:R {capacity: 50}]->(a10),
@@ -342,4 +351,47 @@ void test5() {
342351
362.79999999999995,
343352
4);
344353
}
354+
355+
@Test
356+
void shouldLogProgress() {
357+
var graph = generateUniform(100L, 10);
358+
var log = new GdsTestLog();
359+
var testTracker = new TestProgressTracker(
360+
MaxFlowTask.create(),
361+
new LoggerForProgressTrackingAdapter(log),
362+
new Concurrency(4),
363+
EmptyTaskRegistryFactory.INSTANCE
364+
);
365+
366+
new MaxFlow(
367+
graph,
368+
new MaxFlowParameters(
369+
new ListInputNodes(List.of(0L)),
370+
new ListInputNodes(List.of(2L)),
371+
new Concurrency(4),
372+
6L,
373+
12L,
374+
.5D
375+
),
376+
testTracker,
377+
TerminationFlag.RUNNING_TRUE
378+
).compute();
379+
380+
assertThat(log.getMessages(TestLog.INFO))
381+
.extracting(removingThreadId())
382+
.extracting(replaceTimings())
383+
.containsExactly(
384+
"MaxFlow :: Start",
385+
"MaxFlow 7%",
386+
"MaxFlow 40%",
387+
"MaxFlow 61%",
388+
"MaxFlow 68%",
389+
"MaxFlow 81%",
390+
"MaxFlow 88%",
391+
"MaxFlow 89%",
392+
"MaxFlow 94%",
393+
"MaxFlow 100%",
394+
"MaxFlow :: Finished"
395+
);
396+
}
345397
}

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsBusinessFacade.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.neo4j.gds.kspanningtree.KSpanningTreeTask;
4242
import org.neo4j.gds.maxflow.FlowResult;
4343
import org.neo4j.gds.maxflow.MaxFlowBaseConfig;
44+
import org.neo4j.gds.maxflow.MaxFlowTask;
4445
import org.neo4j.gds.paths.RelationshipCountProgressTaskFactory;
4546
import org.neo4j.gds.paths.astar.config.ShortestPathAStarBaseConfig;
4647
import org.neo4j.gds.paths.bellmanford.AllShortestPathsBellmanFordBaseConfig;
@@ -215,7 +216,10 @@ PathFindingResult longestPath(Graph graph, DagLongestPathBaseConfig configuratio
215216
}
216217

217218
FlowResult maxFlow(Graph graph, MaxFlowBaseConfig configuration) {
218-
var progressTracker = ProgressTracker.NULL_TRACKER;
219+
var progressTracker = createProgressTracker(
220+
MaxFlowTask.create(),
221+
configuration
222+
);
219223

220224
return algorithmMachinery.getResult(
221225
() -> algorithms.maxFlow(

doc/modules/ROOT/pages/algorithms/max-flow.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ RETURN nodeCount, relationshipCount, bytesMin, bytesMax, requiredMemory
258258
[opts="header"]
259259
|===
260260
| nodeCount | relationshipCount | bytesMin | bytesMax | requiredMemory
261-
| 6 | 7 | 2256 | 2256 | "2256 Bytes"
261+
| 6 | 7 | 2304 | 2304 | "2304 Bytes"
262262
|===
263263
--
264264

0 commit comments

Comments
 (0)