Skip to content

Commit 146af9e

Browse files
Implement global relabelling
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neo4j.com>
1 parent 627c8a9 commit 146af9e

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
import org.neo4j.gds.core.concurrency.Concurrency;
24+
import org.neo4j.gds.core.concurrency.ParallelUtil;
25+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
26+
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
27+
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
28+
29+
public class GlobalRelabeling {
30+
public static void globalRelabeling(
31+
FlowGraph flowGraph,
32+
HugeLongArray label,
33+
long source,
34+
long target,
35+
Concurrency concurrency
36+
) {
37+
label.setAll((i) -> flowGraph.nodeCount());
38+
label.set(target, 0L);
39+
var vertexIsDiscovered = HugeAtomicBitSet.create(flowGraph.nodeCount());
40+
41+
var frontier = new AtomicWorkingSet(flowGraph.nodeCount());
42+
frontier.push(target);
43+
vertexIsDiscovered.set(target);
44+
vertexIsDiscovered.set(source);
45+
46+
47+
var tasks = ParallelUtil.tasks(
48+
concurrency,
49+
() -> new GlobalRelabellingBFSTask(flowGraph.concurrentCopy(), frontier, vertexIsDiscovered, label)
50+
);
51+
52+
while (!frontier.isEmpty()) {
53+
RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).build().run();
54+
frontier.reset();
55+
RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).build().run();
56+
}
57+
label.set(source, flowGraph.nodeCount());
58+
}
59+
}
60+
class GlobalRelabellingBFSTask implements Runnable {
61+
private final FlowGraph flowGraph;
62+
private final AtomicWorkingSet frontier;
63+
private final HugeLongArrayQueue localDiscoveredVertices;
64+
private final HugeAtomicBitSet verticesIsDiscovered;
65+
private final HugeLongArray label;
66+
private Phase phase;
67+
private final long batchSize;
68+
private final long LOCAL_QUEUE_BOUND = 128L;
69+
70+
GlobalRelabellingBFSTask(
71+
FlowGraph flowGraph,
72+
AtomicWorkingSet frontier,
73+
HugeAtomicBitSet vertexIsDiscovered,
74+
HugeLongArray label
75+
) {
76+
this.flowGraph = flowGraph;
77+
this.frontier = frontier;
78+
this.localDiscoveredVertices = HugeLongArrayQueue.newQueue(flowGraph.nodeCount()); //think
79+
this.verticesIsDiscovered = vertexIsDiscovered;
80+
this.label = label;
81+
this.phase = Phase.TRAVERSE;
82+
this.batchSize = 1024L;
83+
}
84+
85+
@Override
86+
public void run() {
87+
if (phase == Phase.TRAVERSE){
88+
traverse();
89+
} else {
90+
addToFrontier();
91+
}
92+
}
93+
94+
private void batchTraverse(long from, long to) {
95+
for(long idx = from; idx < to; idx++) {
96+
long v = frontier.unsafePeek(idx);
97+
singleTraverse(v);
98+
}
99+
}
100+
101+
private void singleTraverse(long v){
102+
var newLabel = label.get(v) + 1;
103+
flowGraph.forEachRelationship(v, (s, t, relIdx, residualCapacity, isReverse) -> {
104+
//(s)-->(t) //want t-->s to have free capacity. (can push from t to s)
105+
if(flowGraph.residualCapacity(relIdx, isReverse) <= 0.0) {
106+
return true;
107+
}
108+
if(!verticesIsDiscovered.getAndSet(t)) {
109+
localDiscoveredVertices.add(t);
110+
label.set(t, newLabel);
111+
}
112+
return true;
113+
});
114+
}
115+
116+
public void traverse() {
117+
long oldIdx;
118+
while((oldIdx = frontier.getAndAdd(batchSize)) < frontier.size()) {
119+
long toIdx = Math.min(oldIdx + batchSize, frontier.size());
120+
batchTraverse(oldIdx, toIdx);
121+
}
122+
123+
//do some local processing if the localQueue is small enough
124+
while (!localDiscoveredVertices.isEmpty() && localDiscoveredVertices.size() < LOCAL_QUEUE_BOUND) {
125+
long nodeId = localDiscoveredVertices.remove();
126+
singleTraverse(nodeId);
127+
}
128+
phase = Phase.SYNC;
129+
}
130+
131+
public void addToFrontier() {
132+
frontier.batchPush(localDiscoveredVertices);
133+
phase = Phase.TRAVERSE;
134+
}
135+
}
136+
137+
enum Phase {
138+
TRAVERSE,
139+
SYNC
140+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.maxflow;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.neo4j.gds.collections.ha.HugeLongArray;
24+
import org.neo4j.gds.core.concurrency.Concurrency;
25+
import org.neo4j.gds.extension.GdlExtension;
26+
import org.neo4j.gds.extension.GdlGraph;
27+
import org.neo4j.gds.extension.Inject;
28+
import org.neo4j.gds.extension.TestGraph;
29+
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
32+
@GdlExtension
33+
class GlobalRelabelingTest {
34+
@GdlGraph
35+
private static final String GRAPH =
36+
"""
37+
CREATE
38+
(a:Node {id: 0}),
39+
(b:Node {id: 1}),
40+
(c:Node {id: 2}),
41+
(d:Node {id: 3}),
42+
(e:Node {id: 4}),
43+
(a)-[:R {w: 4.0}]->(d),
44+
(b)-[:R {w: 3.0}]->(a),
45+
(c)-[:R {w: 2.0}]->(a),
46+
(c)-[:R {w: 0.0}]->(b),
47+
(d)-[:R {w: 5.0}]->(e)
48+
""";
49+
50+
@Inject
51+
private TestGraph graph;
52+
53+
//d = 0
54+
//a = 1
55+
//b,c = 2,
56+
//c <- nodeCount(source) = 5
57+
//e <- nodeCount(init) = 5
58+
59+
@Test
60+
void test() {
61+
var flowGraph = FlowGraph.create(graph);
62+
63+
var label = HugeLongArray.newArray(flowGraph.nodeCount());
64+
label.setAll((i) -> flowGraph.nodeCount());
65+
GlobalRelabeling.globalRelabeling(
66+
flowGraph,
67+
label,
68+
graph.toMappedNodeId("c"),
69+
graph.toMappedNodeId("d"),
70+
new Concurrency(1)
71+
);
72+
73+
assertThat(label.get(graph.toMappedNodeId("a"))).isEqualTo(1L);
74+
assertThat(label.get(graph.toMappedNodeId("b"))).isEqualTo(2L);
75+
assertThat(label.get(graph.toMappedNodeId("c"))).isEqualTo(5L);
76+
assertThat(label.get(graph.toMappedNodeId("d"))).isEqualTo(0L);
77+
assertThat(label.get(graph.toMappedNodeId("e"))).isEqualTo(5L);
78+
}
79+
80+
@Test
81+
void test2() {
82+
var flowGraph = FlowGraph.create(graph);
83+
84+
var label = HugeLongArray.newArray(flowGraph.nodeCount());
85+
label.setAll((i) -> flowGraph.nodeCount());
86+
GlobalRelabeling.globalRelabeling(
87+
flowGraph,
88+
label,
89+
graph.toMappedNodeId("a"),
90+
graph.toMappedNodeId("e"),
91+
new Concurrency(1)
92+
);
93+
94+
assertThat(label.get(graph.toMappedNodeId("a"))).isEqualTo(5L);
95+
assertThat(label.get(graph.toMappedNodeId("b"))).isEqualTo(5L);
96+
assertThat(label.get(graph.toMappedNodeId("c"))).isEqualTo(5L);
97+
assertThat(label.get(graph.toMappedNodeId("d"))).isEqualTo(1L);
98+
assertThat(label.get(graph.toMappedNodeId("e"))).isEqualTo(0L);
99+
}
100+
}

0 commit comments

Comments
 (0)