Skip to content

Commit ebb8722

Browse files
Create memory estimation for max flow
1 parent 496a320 commit ebb8722

File tree

8 files changed

+325
-87
lines changed

8 files changed

+325
-87
lines changed

algo/src/main/java/org/neo4j/gds/maxflow/Discharging.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,17 @@
3232
import java.util.List;
3333
import java.util.concurrent.atomic.AtomicLong;
3434

35-
public class Discharging {
35+
final class Discharging {
3636
private final AtomicWorkingSet workingSet;
3737
private final Concurrency concurrency;
3838
private final Collection<Runnable> dischargeTasks;
3939

40-
public static Discharging createDischarging(
40+
41+
static Discharging createDischarging(
4142
FlowGraph flowGraph,
4243
HugeDoubleArray excess,
4344
HugeLongArray label,
44-
HugeLongArray tempLabel,
4545
HugeAtomicDoubleArray addedExcess,
46-
HugeAtomicBitSet isDiscovered,
4746
AtomicWorkingSet workingSet,
4847
long targetNode,
4948
long beta,
@@ -52,6 +51,10 @@ public static Discharging createDischarging(
5251
HugeLongArrayQueue[] threadQueues
5352
)
5453
{
54+
55+
var tempLabel = HugeLongArray.newArray(flowGraph.originalNodeCount());
56+
var isDiscovered = HugeAtomicBitSet.create(flowGraph.originalNodeCount());
57+
5558
List<Runnable> dischargeTasks = new ArrayList<>();
5659
for (int i = 0; i < concurrency.value(); i++) {
5760
dischargeTasks.add(new DischargeTask(
@@ -78,7 +81,7 @@ private Discharging(AtomicWorkingSet workingSet, Collection<Runnable> dischargeT
7881
this.concurrency = concurrency;
7982
}
8083

81-
public void processWorkingSet() {
84+
void processWorkingSet() {
8285
//Discharge working set
8386
RunWithConcurrency.builder().concurrency(concurrency).tasks(dischargeTasks).build().run();
8487
workingSet.resetIdx();

algo/src/main/java/org/neo4j/gds/maxflow/FlowGraph.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,15 @@ public void push(long relIdx, double delta, boolean isReverse) {
209209
}
210210
}
211211

212-
long originalEdgeCount() {
212+
private long originalEdgeCount() {
213213
return graph.relationshipCount();
214214
}
215215

216216
/**
 * Returns the number of edges of the flow graph: the relationship count of the wrapped graph
 * plus one entry for every element of the {@code supply} and {@code demand} arrays.
 */
long edgeCount() {
217217
return graph.relationshipCount() + supply.length + demand.length;
218218
}
219219

220-
public long originalNodeCount() {
220+
long originalNodeCount() {
221221
return graph.nodeCount();
222222
}
223223

algo/src/main/java/org/neo4j/gds/maxflow/GlobalRelabeling.java

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,9 @@
2929
import java.util.Collection;
3030
import java.util.List;
3131

32-
enum Phase {
33-
TRAVERSE,
34-
SYNC
35-
}
3632

37-
public class GlobalRelabeling {
33+
34+
public final class GlobalRelabeling {
3835
private final long nodeCount;
3936
private final HugeLongArray label;
4037
private final AtomicWorkingSet frontier;
@@ -92,76 +89,3 @@ public void globalRelabeling() {
9289
}
9390
}
9491

95-
class GlobalRelabellingBFSTask implements Runnable {
96-
private final FlowGraph flowGraph;
97-
private final AtomicWorkingSet frontier;
98-
private final HugeLongArrayQueue localDiscoveredVertices;
99-
private final HugeAtomicBitSet verticesIsDiscovered;
100-
private final HugeLongArray label;
101-
private final long batchSize;
102-
private final long LOCAL_QUEUE_BOUND = 128L;
103-
private Phase phase;
104-
105-
GlobalRelabellingBFSTask(
106-
FlowGraph flowGraph,
107-
AtomicWorkingSet frontier,
108-
HugeAtomicBitSet vertexIsDiscovered,
109-
HugeLongArray label,
110-
HugeLongArrayQueue localDiscoveredVertices
111-
) {
112-
this.flowGraph = flowGraph;
113-
this.frontier = frontier;
114-
// this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount()); //think //fixme: Don't allocate every time
115-
this.localDiscoveredVertices = localDiscoveredVertices;
116-
this.verticesIsDiscovered = vertexIsDiscovered;
117-
this.label = label;
118-
this.phase = Phase.TRAVERSE;
119-
this.batchSize = 1024L;
120-
}
121-
122-
@Override
123-
public void run() {
124-
if (phase == Phase.TRAVERSE) {
125-
traverse();
126-
} else {
127-
addToFrontier();
128-
}
129-
}
130-
131-
private void singleTraverse(long v) {
132-
var newLabel = label.get(v) + 1;
133-
flowGraph.forEachRelationship(
134-
v, (s, t, relIdx, residualCapacity, isReverse) -> {
135-
//(s)-->(t) //want t-->s to have free capacity. (can push from t to s)
136-
if (flowGraph.residualCapacity(relIdx, isReverse) <= 0.0) {
137-
return true;
138-
}
139-
if (!verticesIsDiscovered.getAndSet(t)) {
140-
localDiscoveredVertices.add(t);
141-
label.set(t, newLabel);
142-
}
143-
return true;
144-
}
145-
);
146-
}
147-
148-
public void traverse() {
149-
long oldIdx;
150-
while ((oldIdx = frontier.getAndAdd(batchSize)) < frontier.size()) {
151-
long toIdx = Math.min(oldIdx + batchSize, frontier.size());
152-
frontier.consumeBatch(oldIdx, toIdx, this::singleTraverse);
153-
}
154-
155-
//do some local processing if the localQueue is small enough
156-
while (!localDiscoveredVertices.isEmpty() && localDiscoveredVertices.size() < LOCAL_QUEUE_BOUND) {
157-
long nodeId = localDiscoveredVertices.remove();
158-
singleTraverse(nodeId);
159-
}
160-
phase = Phase.SYNC;
161-
}
162-
163-
public void addToFrontier() {
164-
frontier.batchPush(localDiscoveredVertices);
165-
phase = Phase.TRAVERSE;
166-
}
167-
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
24+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
25+
26+
class GlobalRelabellingBFSTask implements Runnable {
27+
private final FlowGraph flowGraph;
28+
private final AtomicWorkingSet frontier;
29+
private final HugeLongArrayQueue localDiscoveredVertices;
30+
private final HugeAtomicBitSet verticesIsDiscovered;
31+
private final HugeLongArray label;
32+
private final long batchSize;
33+
private final long LOCAL_QUEUE_BOUND = 128L;
34+
private Phase phase;
35+
36+
GlobalRelabellingBFSTask(
37+
FlowGraph flowGraph,
38+
AtomicWorkingSet frontier,
39+
HugeAtomicBitSet vertexIsDiscovered,
40+
HugeLongArray label,
41+
HugeLongArrayQueue localDiscoveredVertices
42+
) {
43+
this.flowGraph = flowGraph;
44+
this.frontier = frontier;
45+
this.localDiscoveredVertices = localDiscoveredVertices;
46+
this.verticesIsDiscovered = vertexIsDiscovered;
47+
this.label = label;
48+
this.phase = Phase.TRAVERSE;
49+
this.batchSize = 1024L;
50+
}
51+
52+
@Override
53+
public void run() {
54+
if (phase == Phase.TRAVERSE) {
55+
traverse();
56+
} else {
57+
addToFrontier();
58+
}
59+
}
60+
61+
private void singleTraverse(long v) {
62+
var newLabel = label.get(v) + 1;
63+
flowGraph.forEachRelationship(
64+
v, (s, t, relIdx, residualCapacity, isReverse) -> {
65+
//(s)-->(t) //want t-->s to have free capacity. (can push from t to s)
66+
if (flowGraph.residualCapacity(relIdx, isReverse) <= 0.0) {
67+
return true;
68+
}
69+
if (!verticesIsDiscovered.getAndSet(t)) {
70+
localDiscoveredVertices.add(t);
71+
label.set(t, newLabel);
72+
}
73+
return true;
74+
}
75+
);
76+
}
77+
78+
public void traverse() {
79+
long oldIdx;
80+
while ((oldIdx = frontier.getAndAdd(batchSize)) < frontier.size()) {
81+
long toIdx = Math.min(oldIdx + batchSize, frontier.size());
82+
frontier.consumeBatch(oldIdx, toIdx, this::singleTraverse);
83+
}
84+
85+
//do some local processing if the localQueue is small enough
86+
while (!localDiscoveredVertices.isEmpty() && localDiscoveredVertices.size() < LOCAL_QUEUE_BOUND) {
87+
long nodeId = localDiscoveredVertices.remove();
88+
singleTraverse(nodeId);
89+
}
90+
phase = Phase.SYNC;
91+
}
92+
93+
private void addToFrontier() {
94+
frontier.batchPush(localDiscoveredVertices);
95+
phase = Phase.TRAVERSE;
96+
}
97+
98+
private enum Phase {
99+
TRAVERSE,
100+
SYNC
101+
}
102+
}

algo/src/main/java/org/neo4j/gds/maxflow/MaxFlow.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,7 @@ private void maximizeFlow(Preflow preflow, long sourceNode, long targetNode) { /
117117
flowGraph,
118118
excess,
119119
label,
120-
tempLabel,
121120
addedExcess,
122-
isDiscovered,
123121
workingSet,
124122
targetNode,
125123
parameters.beta(),
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
23+
import org.neo4j.gds.collections.ha.HugeLongArray;
24+
import org.neo4j.gds.collections.ha.HugeObjectArray;
25+
import org.neo4j.gds.core.GraphDimensions;
26+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
27+
import org.neo4j.gds.mem.Estimate;
28+
import org.neo4j.gds.mem.MemoryEstimateDefinition;
29+
import org.neo4j.gds.mem.MemoryEstimation;
30+
import org.neo4j.gds.mem.MemoryEstimations;
31+
import org.neo4j.gds.mem.MemoryRange;
32+
33+
import java.util.function.BiFunction;
34+
import java.util.function.Function;
35+
36+
public class MaxFlowMemoryEstimateDefinition implements MemoryEstimateDefinition {
37+
38+
private final long numberOfSinks;
39+
private final long numberOfTerminals;
40+
41+
public MaxFlowMemoryEstimateDefinition(long numberOfSinks, long numberOfTerminals) {
42+
this.numberOfSinks = numberOfSinks;
43+
this.numberOfTerminals = numberOfTerminals;
44+
}
45+
46+
private MemoryEstimation atomicWorkingSet(){
47+
return MemoryEstimations.builder(AtomicWorkingSet.class)
48+
.perNode("working set", HugeLongArray::memoryEstimation)
49+
.build();
50+
51+
}
52+
private MemoryEstimation globalRelabellingTask(){
53+
return MemoryEstimations.builder(GlobalRelabellingBFSTask.class).build();
54+
}
55+
private MemoryEstimation globalRelabelling(){
56+
return MemoryEstimations.builder(GlobalRelabeling.class)
57+
.perThread("Global Relabelling task",globalRelabellingTask())
58+
.add("frontier",atomicWorkingSet())
59+
.perNode("isDiscovered",HugeAtomicBitSet::memoryEstimation)
60+
.build();
61+
}
62+
63+
private MemoryEstimation dischargeTask(){
64+
return MemoryEstimations.builder(DischargeTask.class).build();
65+
}
66+
67+
private MemoryEstimation discharging(){
68+
return MemoryEstimations.builder(Discharging.class)
69+
.perThread("Discharge task", dischargeTask())
70+
.perNode("temp label", HugeLongArray::memoryEstimation)
71+
.perNode("isDiscovered",HugeAtomicBitSet::memoryEstimation)
72+
.build();
73+
}
74+
75+
private MemoryEstimation flowGraph(){
76+
BiFunction<GraphDimensions, Function<Long,Long>,MemoryRange> relConsumer =
77+
((graphDimensions, longMemoryRangeFunction) -> {
78+
var newRel = graphDimensions.relCountUpperBound() + numberOfSinks + numberOfTerminals;
79+
return MemoryRange.of(longMemoryRangeFunction.apply(newRel));
80+
});
81+
BiFunction<GraphDimensions, Function<Long,Long>,MemoryRange> nodeConsumer =
82+
((graphDimensions, longMemoryRangeFunction) -> {
83+
var newRel = graphDimensions.nodeCount() + 2;
84+
return MemoryRange.of(longMemoryRangeFunction.apply(newRel));
85+
});
86+
87+
//skip revDegree array during construction because it is used only during construction
88+
return MemoryEstimations.builder(FlowGraph.class)
89+
.perNode("index offset", HugeLongArray::memoryEstimation)
90+
.perGraphDimension("flow",((dimensions, ___) -> relConsumer.apply(dimensions, HugeDoubleArray::memoryEstimation)))
91+
.perGraphDimension("capacity",((dimensions, ___) -> relConsumer.apply(dimensions, HugeDoubleArray::memoryEstimation)))
92+
.perGraphDimension("reverse adjacency",((dimensions, ___) -> relConsumer.apply(dimensions, HugeLongArray::memoryEstimation)))
93+
.perGraphDimension("reverse index",((dimensions, ___) -> relConsumer.apply(dimensions, HugeLongArray::memoryEstimation)))
94+
.perGraphDimension("reverse offset",((dimensions, ___) -> nodeConsumer.apply(dimensions, HugeLongArray::memoryEstimation)))
95+
96+
.build();
97+
}
98+
99+
private MemoryEstimation flowResult() {
100+
return MemoryEstimations.builder(FlowResult.class)
101+
.perGraphDimension(
102+
"output", ((dimensions, ___) -> {
103+
var sizeOfFlowRelationship = Estimate.sizeOfInstance(FlowRelationship.class);
104+
return MemoryRange.of(HugeObjectArray.memoryEstimation(
105+
dimensions.relCountUpperBound(),
106+
sizeOfFlowRelationship
107+
));
108+
})
109+
).build();
110+
}
111+
112+
@Override
113+
public MemoryEstimation memoryEstimation() {
114+
return MemoryEstimations.builder(MaxFlow.class)
115+
.fixed("supply", Estimate.sizeOfInstance(NodeWithValue.class) * numberOfSinks)
116+
.fixed("demand", Estimate.sizeOfInstance(NodeWithValue.class) * numberOfTerminals)
117+
.perGraphDimension("thread queues", (dimensions,concurrency)-> MemoryRange.of(dimensions.nodeCount() * concurrency.value()))
118+
.add("flowGraph",flowGraph())
119+
.add("Discharging", discharging())
120+
.add("Global relabelling", globalRelabelling())
121+
.add("result", flowResult())
122+
.build();
123+
}
124+
}

algo/src/test/java/org/neo4j/gds/assertions/MemoryEstimationAssert.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ public MemoryRangeAssert memoryRange(long nodeCount) {
6666
return memoryRange(nodeCount, 0, new Concurrency(1));
6767
}
6868

69+
/**
 * Convenience overload that delegates to the three-argument {@code memoryRange}
 * with a {@code Concurrency} of 1.
 */
public MemoryRangeAssert memoryRange(long nodeCount,long relationshipCount) {
70+
return memoryRange(nodeCount, relationshipCount, new Concurrency(1));
71+
}
6972

7073
public MemoryTreeAssert memoryTree(GraphDimensions graphDimensions, Concurrency concurrency) {
7174
isNotNull();

0 commit comments

Comments
 (0)