Skip to content

Commit 2ee0301

Browse files
Add max-k-cut in compute facade
1 parent fb0a8d1 commit 2ee0301

File tree

5 files changed

+301
-1
lines changed

5 files changed

+301
-1
lines changed

algo/src/main/java/org/neo4j/gds/approxmaxkcut/ApproxMaxKCutResult.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@
2424
public record ApproxMaxKCutResult(
2525
HugeByteArray candidateSolution,
2626
double cutCost
27-
) {}
27+
) {
28+
29+
public static ApproxMaxKCutResult EMPTY = new ApproxMaxKCutResult(HugeByteArray.newArray(0),0d);
30+
}

algorithms-compute-facade/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies {
99
implementation project(':algo')
1010
implementation project(':algo-common')
1111
implementation project(':collections')
12+
implementation project(':community-params')
1213
implementation project(':core')
1314
implementation project(':core-api')
1415
implementation project(':logging')
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.community;
21+
22+
import org.neo4j.gds.CommunityAlgorithmTasks;
23+
import org.neo4j.gds.ProgressTrackerFactory;
24+
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCut;
26+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutParameters;
27+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutResult;
28+
import org.neo4j.gds.async.AsyncAlgorithmCaller;
29+
import org.neo4j.gds.core.concurrency.DefaultPool;
30+
import org.neo4j.gds.core.utils.progress.JobId;
31+
import org.neo4j.gds.result.TimedAlgorithmResult;
32+
import org.neo4j.gds.termination.TerminationFlag;
33+
34+
import java.util.concurrent.CompletableFuture;
35+
36+
public class CommunityComputeFacade {
37+
38+
// Global dependencies
39+
// This is created with its own ExecutorService workerPool,
40+
// which determines how many algorithms can run in parallel.
41+
private final AsyncAlgorithmCaller algorithmCaller;
42+
private final ProgressTrackerFactory progressTrackerFactory;
43+
44+
// Request scope dependencies
45+
private final TerminationFlag terminationFlag;
46+
47+
// Local dependencies
48+
private final CommunityAlgorithmTasks tasks = new CommunityAlgorithmTasks();
49+
50+
51+
public CommunityComputeFacade(
52+
AsyncAlgorithmCaller algorithmCaller,
53+
ProgressTrackerFactory progressTrackerFactory,
54+
TerminationFlag terminationFlag
55+
) {
56+
this.algorithmCaller = algorithmCaller;
57+
this.progressTrackerFactory = progressTrackerFactory;
58+
this.terminationFlag = terminationFlag;
59+
}
60+
61+
62+
public CompletableFuture<TimedAlgorithmResult<ApproxMaxKCutResult>> approxMaxKCut(
63+
Graph graph,
64+
ApproxMaxKCutParameters parameters,
65+
JobId jobId,
66+
boolean logProgress
67+
) {
68+
if (graph.isEmpty()) {
69+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(ApproxMaxKCutResult.EMPTY));
70+
}
71+
72+
var progressTracker = progressTrackerFactory.create(
73+
tasks.approximateMaximumKCut(graph, parameters),
74+
jobId,
75+
parameters.concurrency(),
76+
logProgress
77+
);
78+
79+
var approxMaxKCut = ApproxMaxKCut.create(
80+
graph,
81+
parameters,
82+
DefaultPool.INSTANCE,
83+
progressTracker,
84+
terminationFlag
85+
);
86+
87+
return algorithmCaller.run(
88+
approxMaxKCut::compute,
89+
jobId
90+
);
91+
}
92+
93+
94+
95+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.community;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
import org.neo4j.gds.ProgressTrackerFactory;
28+
import org.neo4j.gds.api.Graph;
29+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutParameters;
30+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutResult;
31+
import org.neo4j.gds.async.AsyncAlgorithmCaller;
32+
import org.neo4j.gds.core.utils.progress.JobId;
33+
import org.neo4j.gds.termination.TerminationFlag;
34+
35+
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.verifyNoInteractions;
38+
import static org.mockito.Mockito.when;
39+
40+
@ExtendWith(MockitoExtension.class)
41+
class CommunityComputeFacadeEmptyGraphTest {
42+
@Mock
43+
private Graph graph;
44+
45+
@Mock
46+
private ProgressTrackerFactory progressTrackerFactoryMock;
47+
48+
@Mock
49+
private AsyncAlgorithmCaller algorithmCallerMock;
50+
51+
@Mock
52+
private JobId jobIdMock;
53+
54+
private CommunityComputeFacade facade;
55+
56+
@BeforeEach
57+
void setUp() {
58+
when(graph.isEmpty()).thenReturn(true);
59+
facade = new CommunityComputeFacade(
60+
algorithmCallerMock,
61+
progressTrackerFactoryMock,
62+
TerminationFlag.RUNNING_TRUE
63+
);
64+
}
65+
66+
@Test
67+
void maxKCut() {
68+
var future = facade.approxMaxKCut(
69+
graph,
70+
mock(ApproxMaxKCutParameters.class),
71+
jobIdMock,
72+
true
73+
);
74+
var result = future.join();
75+
assertThat(result.result()).isEqualTo(ApproxMaxKCutResult.EMPTY);
76+
77+
verifyNoInteractions(progressTrackerFactoryMock);
78+
verifyNoInteractions(algorithmCallerMock);
79+
}
80+
81+
82+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.community;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.extension.ExtendWith;
25+
import org.mockito.Mock;
26+
import org.mockito.junit.jupiter.MockitoExtension;
27+
import org.neo4j.gds.Orientation;
28+
import org.neo4j.gds.ProgressTrackerFactory;
29+
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutParameters;
30+
import org.neo4j.gds.async.AsyncAlgorithmCaller;
31+
import org.neo4j.gds.core.concurrency.Concurrency;
32+
import org.neo4j.gds.core.utils.progress.JobId;
33+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
34+
import org.neo4j.gds.extension.GdlExtension;
35+
import org.neo4j.gds.extension.GdlGraph;
36+
import org.neo4j.gds.extension.IdFunction;
37+
import org.neo4j.gds.extension.Inject;
38+
import org.neo4j.gds.extension.TestGraph;
39+
import org.neo4j.gds.logging.Log;
40+
import org.neo4j.gds.termination.TerminationFlag;
41+
42+
import java.util.List;
43+
import java.util.Optional;
44+
import java.util.concurrent.Executors;
45+
46+
import static org.assertj.core.api.Assertions.assertThat;
47+
import static org.mockito.ArgumentMatchers.any;
48+
import static org.mockito.ArgumentMatchers.anyBoolean;
49+
import static org.mockito.Mockito.when;
50+
51+
@ExtendWith(MockitoExtension.class)
52+
@GdlExtension
53+
class CommunityComputeFacadeTest {
54+
55+
@Mock(strictness = Mock.Strictness.LENIENT)
56+
private ProgressTrackerFactory progressTrackerFactoryMock;
57+
@Mock
58+
private ProgressTracker progressTrackerMock;
59+
60+
@Mock
61+
private JobId jobIdMock;
62+
63+
@Mock
64+
private Log logMock;
65+
66+
@GdlGraph(orientation = Orientation.UNDIRECTED)
67+
private static final String GDL = """
68+
(a:Node { prize: 1.0 })-[:REL]->(b:Node { prize: 2.0 }),
69+
(b)-[:REL]->(c:Node { prize: 3.0 }),
70+
(a)-[:REL]->(c)
71+
""";
72+
73+
@Inject
74+
private TestGraph graph;
75+
76+
@Inject
77+
private IdFunction idFunction;
78+
private CommunityComputeFacade facade;
79+
80+
@BeforeEach
81+
void setUp() {
82+
when(progressTrackerFactoryMock.nullTracker())
83+
.thenReturn(ProgressTracker.NULL_TRACKER);
84+
when(progressTrackerFactoryMock.create(any(), any(), any(), anyBoolean()))
85+
.thenReturn(progressTrackerMock);
86+
87+
facade = new CommunityComputeFacade(
88+
new AsyncAlgorithmCaller(Executors.newSingleThreadExecutor(), logMock),
89+
progressTrackerFactoryMock,
90+
TerminationFlag.RUNNING_TRUE
91+
);
92+
}
93+
94+
@Test
95+
void maxKCut() {
96+
var future = facade.approxMaxKCut(
97+
graph,
98+
new ApproxMaxKCutParameters(
99+
(byte) 2,
100+
5,
101+
1,
102+
new Concurrency(4),
103+
10_000,
104+
Optional.empty(),
105+
List.of(),
106+
false,
107+
false
108+
),
109+
jobIdMock,
110+
false
111+
);
112+
113+
var results = future.join();
114+
115+
assertThat(results.result().candidateSolution().toArray()).containsOnly(0,1);
116+
assertThat(results.computeMillis()).isNotNegative();
117+
118+
}
119+
}

0 commit comments

Comments
 (0)