Skip to content

Commit bda99c6

Browse files
Add label propagation in compute facade
1 parent 871609a commit bda99c6

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

algo/src/main/java/org/neo4j/gds/labelpropagation/LabelPropagationResult.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,7 @@
2121

2222
import org.neo4j.gds.collections.ha.HugeLongArray;
2323

24-
public record LabelPropagationResult(HugeLongArray labels, boolean didConverge, long ranIterations) {}
24+
public record LabelPropagationResult(HugeLongArray labels, boolean didConverge, long ranIterations) {
25+
26+
public static LabelPropagationResult EMPTY = new LabelPropagationResult(HugeLongArray.newArray(0), false, 0);
27+
}

algorithms-compute-facade/src/main/java/org/neo4j/gds/community/CommunityComputeFacade.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
import org.neo4j.gds.kmeans.KmeansContext;
4848
import org.neo4j.gds.kmeans.KmeansParameters;
4949
import org.neo4j.gds.kmeans.KmeansResult;
50+
import org.neo4j.gds.labelpropagation.LabelPropagation;
51+
import org.neo4j.gds.labelpropagation.LabelPropagationParameters;
52+
import org.neo4j.gds.labelpropagation.LabelPropagationResult;
5053
import org.neo4j.gds.result.TimedAlgorithmResult;
5154
import org.neo4j.gds.termination.TerminationFlag;
5255

@@ -301,4 +304,36 @@ CompletableFuture<TimedAlgorithmResult<KmeansResult>> kMeans(
301304
);
302305
}
303306

307+
CompletableFuture<TimedAlgorithmResult<LabelPropagationResult>> labelPropagation(
308+
Graph graph,
309+
LabelPropagationParameters parameters,
310+
JobId jobId,
311+
boolean logProgress
312+
) {
313+
314+
if (graph.isEmpty()) {
315+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(LabelPropagationResult.EMPTY));
316+
}
317+
318+
var progressTracker = progressTrackerFactory.create(
319+
tasks.labelPropagation(graph,parameters),
320+
jobId,
321+
parameters.concurrency(),
322+
logProgress
323+
);
324+
325+
var algorithm = new LabelPropagation(
326+
graph,
327+
parameters,
328+
DefaultPool.INSTANCE,
329+
progressTracker,
330+
terminationFlag
331+
);
332+
333+
return algorithmCaller.run(
334+
algorithm::compute,
335+
jobId
336+
);
337+
}
338+
304339
}

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeEmptyGraphTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
4242
import org.neo4j.gds.kcore.KCoreDecompositionResult;
4343
import org.neo4j.gds.kmeans.KmeansParameters;
44+
import org.neo4j.gds.labelpropagation.LabelPropagationParameters;
45+
import org.neo4j.gds.labelpropagation.LabelPropagationResult;
4446
import org.neo4j.gds.termination.TerminationFlag;
4547

4648
import static org.assertj.core.api.Assertions.assertThat;
@@ -193,4 +195,20 @@ void kMeans(){
193195
verifyNoInteractions(algorithmCallerMock);
194196
}
195197

198+
@Test
199+
void labelPropagation(){
200+
201+
var future = facade.labelPropagation(
202+
graph,
203+
mock(LabelPropagationParameters.class),
204+
jobIdMock,
205+
false
206+
);
207+
208+
var results = future.join();
209+
210+
assertThat(results.result()).isEqualTo(LabelPropagationResult.EMPTY);
211+
verifyNoInteractions(progressTrackerFactoryMock);
212+
verifyNoInteractions(algorithmCallerMock);
213+
}
196214
}

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
4545
import org.neo4j.gds.kmeans.KmeansParameters;
4646
import org.neo4j.gds.kmeans.SamplerType;
47+
import org.neo4j.gds.labelpropagation.LabelPropagationParameters;
4748
import org.neo4j.gds.logging.Log;
4849
import org.neo4j.gds.termination.TerminationFlag;
4950

@@ -247,4 +248,24 @@ void kMeans(){
247248
assertThat(results.computeMillis()).isNotNegative();
248249
}
249250

251+
@Test
252+
void labelPropagation(){
253+
var future = facade.labelPropagation(
254+
graph,
255+
new LabelPropagationParameters(
256+
new Concurrency(4),
257+
10,
258+
null,
259+
null
260+
),
261+
jobIdMock,
262+
false
263+
);
264+
265+
var results = future.join();
266+
267+
assertThat(results.result().ranIterations()).isGreaterThan(0);
268+
assertThat(results.computeMillis()).isNotNegative();
269+
}
270+
250271
}

0 commit comments

Comments
 (0)