Skip to content

Commit 9b239d8

Browse files
Merge pull request #11193 from neo-technology/max-flow-memory-estimation
Max-flow memory estimation
2 parents d641a9c + 7dac4a0 commit 9b239d8

File tree

27 files changed

+506
-115
lines changed

27 files changed

+506
-115
lines changed

algo-params/common/src/main/java/org/neo4j/gds/InputNodes.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ public interface InputNodes {
2626
InputNodes EMPTY_INPUT_NODES = new ListInputNodes(List.of());
2727

2828
Collection<Long> inputNodes();
29+
30+
default int size(){
31+
return inputNodes().size();
32+
}
2933
}

algo/src/main/java/org/neo4j/gds/maxflow/DischargeTask.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ public class DischargeTask implements Runnable {
4949
private PHASE phase;
5050
private long localWork;
5151

52-
53-
5452
public DischargeTask(
5553
FlowGraph flowGraph,
5654
HugeDoubleArray excess,

algo/src/main/java/org/neo4j/gds/maxflow/Discharging.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,17 @@
3232
import java.util.List;
3333
import java.util.concurrent.atomic.AtomicLong;
3434

35-
public class Discharging {
35+
final class Discharging {
3636
private final AtomicWorkingSet workingSet;
3737
private final Concurrency concurrency;
3838
private final Collection<Runnable> dischargeTasks;
3939

40-
public static Discharging createDischarging(
40+
41+
static Discharging createDischarging(
4142
FlowGraph flowGraph,
4243
HugeDoubleArray excess,
4344
HugeLongArray label,
44-
HugeLongArray tempLabel,
4545
HugeAtomicDoubleArray addedExcess,
46-
HugeAtomicBitSet isDiscovered,
4746
AtomicWorkingSet workingSet,
4847
long targetNode,
4948
long beta,
@@ -52,6 +51,10 @@ public static Discharging createDischarging(
5251
HugeLongArrayQueue[] threadQueues
5352
)
5453
{
54+
55+
var tempLabel = HugeLongArray.newArray(flowGraph.nodeCount());
56+
var isDiscovered = HugeAtomicBitSet.create(flowGraph.nodeCount());
57+
5558
List<Runnable> dischargeTasks = new ArrayList<>();
5659
for (int i = 0; i < concurrency.value(); i++) {
5760
dischargeTasks.add(new DischargeTask(
@@ -78,7 +81,7 @@ private Discharging(AtomicWorkingSet workingSet, Collection<Runnable> dischargeT
7881
this.concurrency = concurrency;
7982
}
8083

81-
public void processWorkingSet() {
84+
void processWorkingSet() {
8285
//Discharge working set
8386
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
8487
workingSet.resetIdx();

algo/src/main/java/org/neo4j/gds/maxflow/FlowGraph.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ public FlowGraph concurrentCopy() {
159159
}
160160

161161
private void forEachOriginalRelationship(long nodeId, ResidualEdgeConsumer consumer) {
162-
//todo: Rename original, since it also includes 'non-reverse' edges from superNodes
163162
var relIdx = new MutableLong(indPtr.get(nodeId));
164163
RelationshipWithPropertyConsumer originalConsumer = (s, t, capacity) -> {
165164
var residualCapacity = capacity - flow.get(relIdx.longValue());
@@ -209,15 +208,15 @@ public void push(long relIdx, double delta, boolean isReverse) {
209208
}
210209
}
211210

212-
long originalEdgeCount() {
211+
private long originalEdgeCount() {
213212
return graph.relationshipCount();
214213
}
215214

216215
long edgeCount() {
217216
return graph.relationshipCount() + supply.length + demand.length;
218217
}
219218

220-
public long originalNodeCount() {
219+
long originalNodeCount() {
221220
return graph.nodeCount();
222221
}
223222

algo/src/main/java/org/neo4j/gds/maxflow/GlobalRelabeling.java

Lines changed: 1 addition & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,7 @@
2929
import java.util.Collection;
3030
import java.util.List;
3131

32-
enum Phase {
33-
TRAVERSE,
34-
SYNC
35-
}
36-
37-
public class GlobalRelabeling {
32+
public final class GlobalRelabeling {
3833
private final long nodeCount;
3934
private final HugeLongArray label;
4035
private final AtomicWorkingSet frontier;
@@ -91,77 +86,3 @@ public void globalRelabeling() {
9186

9287
}
9388
}
94-
95-
class GlobalRelabellingBFSTask implements Runnable {
96-
private final FlowGraph flowGraph;
97-
private final AtomicWorkingSet frontier;
98-
private final HugeLongArrayQueue localDiscoveredVertices;
99-
private final HugeAtomicBitSet verticesIsDiscovered;
100-
private final HugeLongArray label;
101-
private final long batchSize;
102-
private final long LOCAL_QUEUE_BOUND = 128L;
103-
private Phase phase;
104-
105-
GlobalRelabellingBFSTask(
106-
FlowGraph flowGraph,
107-
AtomicWorkingSet frontier,
108-
HugeAtomicBitSet vertexIsDiscovered,
109-
HugeLongArray label,
110-
HugeLongArrayQueue localDiscoveredVertices
111-
) {
112-
this.flowGraph = flowGraph;
113-
this.frontier = frontier;
114-
// this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount()); //think //fixme: Don't allocate every time
115-
this.localDiscoveredVertices = localDiscoveredVertices;
116-
this.verticesIsDiscovered = vertexIsDiscovered;
117-
this.label = label;
118-
this.phase = Phase.TRAVERSE;
119-
this.batchSize = 1024L;
120-
}
121-
122-
@Override
123-
public void run() {
124-
if (phase == Phase.TRAVERSE) {
125-
traverse();
126-
} else {
127-
addToFrontier();
128-
}
129-
}
130-
131-
private void singleTraverse(long v) {
132-
var newLabel = label.get(v) + 1;
133-
flowGraph.forEachRelationship(
134-
v, (s, t, relIdx, residualCapacity, isReverse) -> {
135-
//(s)-->(t) //want t-->s to have free capacity. (can push from t to s)
136-
if (flowGraph.residualCapacity(relIdx, isReverse) <= 0.0) {
137-
return true;
138-
}
139-
if (!verticesIsDiscovered.getAndSet(t)) {
140-
localDiscoveredVertices.add(t);
141-
label.set(t, newLabel);
142-
}
143-
return true;
144-
}
145-
);
146-
}
147-
148-
public void traverse() {
149-
long oldIdx;
150-
while ((oldIdx = frontier.getAndAdd(batchSize)) < frontier.size()) {
151-
long toIdx = Math.min(oldIdx + batchSize, frontier.size());
152-
frontier.consumeBatch(oldIdx, toIdx, this::singleTraverse);
153-
}
154-
155-
//do some local processing if the localQueue is small enough
156-
while (!localDiscoveredVertices.isEmpty() && localDiscoveredVertices.size() < LOCAL_QUEUE_BOUND) {
157-
long nodeId = localDiscoveredVertices.remove();
158-
singleTraverse(nodeId);
159-
}
160-
phase = Phase.SYNC;
161-
}
162-
163-
public void addToFrontier() {
164-
frontier.batchPush(localDiscoveredVertices);
165-
phase = Phase.TRAVERSE;
166-
}
167-
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
24+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
25+
26+
class GlobalRelabellingBFSTask implements Runnable {
27+
private final FlowGraph flowGraph;
28+
private final AtomicWorkingSet frontier;
29+
private final HugeLongArrayQueue localDiscoveredVertices;
30+
private final HugeAtomicBitSet verticesIsDiscovered;
31+
private final HugeLongArray label;
32+
private final long batchSize;
33+
private final long LOCAL_QUEUE_BOUND = 128L;
34+
private Phase phase;
35+
36+
GlobalRelabellingBFSTask(
37+
FlowGraph flowGraph,
38+
AtomicWorkingSet frontier,
39+
HugeAtomicBitSet vertexIsDiscovered,
40+
HugeLongArray label,
41+
HugeLongArrayQueue localDiscoveredVertices
42+
) {
43+
this.flowGraph = flowGraph;
44+
this.frontier = frontier;
45+
this.localDiscoveredVertices = localDiscoveredVertices;
46+
this.verticesIsDiscovered = vertexIsDiscovered;
47+
this.label = label;
48+
this.phase = Phase.TRAVERSE;
49+
this.batchSize = 1024L;
50+
}
51+
52+
@Override
53+
public void run() {
54+
if (phase == Phase.TRAVERSE) {
55+
traverse();
56+
} else {
57+
addToFrontier();
58+
}
59+
}
60+
61+
private void singleTraverse(long v) {
62+
var newLabel = label.get(v) + 1;
63+
flowGraph.forEachRelationship(
64+
v, (s, t, relIdx, residualCapacity, isReverse) -> {
65+
//(s)-->(t) //want t-->s to have free capacity. (can push from t to s)
66+
if (flowGraph.residualCapacity(relIdx, isReverse) <= 0.0) {
67+
return true;
68+
}
69+
if (!verticesIsDiscovered.getAndSet(t)) {
70+
localDiscoveredVertices.add(t);
71+
label.set(t, newLabel);
72+
}
73+
return true;
74+
}
75+
);
76+
}
77+
78+
public void traverse() {
79+
long oldIdx;
80+
while ((oldIdx = frontier.getAndAdd(batchSize)) < frontier.size()) {
81+
long toIdx = Math.min(oldIdx + batchSize, frontier.size());
82+
frontier.consumeBatch(oldIdx, toIdx, this::singleTraverse);
83+
}
84+
85+
//do some local processing if the localQueue is small enough
86+
while (!localDiscoveredVertices.isEmpty() && localDiscoveredVertices.size() < LOCAL_QUEUE_BOUND) {
87+
long nodeId = localDiscoveredVertices.remove();
88+
singleTraverse(nodeId);
89+
}
90+
phase = Phase.SYNC;
91+
}
92+
93+
private void addToFrontier() {
94+
frontier.batchPush(localDiscoveredVertices);
95+
phase = Phase.TRAVERSE;
96+
}
97+
98+
private enum Phase {
99+
TRAVERSE,
100+
SYNC
101+
}
102+
}

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2525
import org.neo4j.gds.collections.ha.HugeLongArray;
2626
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
27-
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
2827
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
2928
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
3029
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -87,9 +86,7 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
8786
var addedExcess = HugeAtomicDoubleArray.of(
8887
nodeCount,
8988
ParallelDoublePageCreator.passThrough(parameters.concurrency())
90-
); //fixme
91-
var tempLabel = HugeLongArray.newArray(nodeCount);
92-
var isDiscovered = HugeAtomicBitSet.create(nodeCount);
89+
);
9390
var workingSet = new AtomicWorkingSet(nodeCount);
9491
for (var nodeId = 0; nodeId < nodeCount; nodeId++) {
9592
if (excess.get(nodeId) > 0.0) {
@@ -117,9 +114,7 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
117114
flowGraph,
118115
excess,
119116
label,
120-
tempLabel,
121117
addedExcess,
122-
isDiscovered,
123118
workingSet,
124119
targetNode,
125120
parameters.beta(),

0 commit comments

Comments
 (0)