Skip to content

Commit 3e3ae3e

Browse files
Add triangles stream in compute facade
1 parent d834082 commit 3e3ae3e

File tree

4 files changed

+80
-17
lines changed

4 files changed

+80
-17
lines changed

algorithms-compute-facade/src/main/java/org/neo4j/gds/community/CommunityComputeFacade.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutResult;
2929
import org.neo4j.gds.async.AsyncAlgorithmCaller;
3030
import org.neo4j.gds.beta.pregel.ImmutablePregelResult;
31+
import org.neo4j.gds.beta.pregel.NodeValue;
3132
import org.neo4j.gds.beta.pregel.PregelResult;
33+
import org.neo4j.gds.beta.pregel.PregelSchema;
3234
import org.neo4j.gds.cliqueCounting.CliqueCounting;
3335
import org.neo4j.gds.cliqueCounting.CliqueCountingResult;
3436
import org.neo4j.gds.cliquecounting.CliqueCountingParameters;
@@ -82,11 +84,14 @@
8284
import org.neo4j.gds.triangle.LocalClusteringCoefficientResult;
8385
import org.neo4j.gds.triangle.TriangleCountParameters;
8486
import org.neo4j.gds.triangle.TriangleCountResult;
87+
import org.neo4j.gds.triangle.TriangleResult;
88+
import org.neo4j.gds.triangle.TriangleStream;
8589
import org.neo4j.gds.wcc.Wcc;
8690
import org.neo4j.gds.wcc.WccParameters;
8791

8892
import java.util.Optional;
8993
import java.util.concurrent.CompletableFuture;
94+
import java.util.stream.Stream;
9095

9196
public class CommunityComputeFacade {
9297

@@ -596,6 +601,29 @@ CompletableFuture<TimedAlgorithmResult<TriangleCountResult>> triangleCount(
596601
);
597602
}
598603

604+
CompletableFuture<TimedAlgorithmResult<Stream<TriangleResult>>> triangles(
605+
Graph graph,
606+
TriangleCountParameters parameters,
607+
JobId jobId) {
608+
609+
if (graph.isEmpty()) {
610+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty( Stream.empty()));
611+
}
612+
613+
var algorithm = TriangleStream.create(
614+
graph,
615+
DefaultPool.INSTANCE,
616+
parameters.concurrency(),
617+
parameters.labelFilter(),
618+
terminationFlag
619+
);
620+
621+
return algorithmCaller.run(
622+
algorithm::compute,
623+
jobId
624+
);
625+
}
626+
599627
CompletableFuture<TimedAlgorithmResult<DisjointSetStruct>> wcc(
600628
Graph graph,
601629
WccParameters parameters,
@@ -628,6 +656,7 @@ CompletableFuture<TimedAlgorithmResult<DisjointSetStruct>> wcc(
628656
jobId
629657
);
630658
}
659+
631660
CompletableFuture<TimedAlgorithmResult<PregelResult>> sllpa(
632661
Graph graph,
633662
SpeakerListenerLPAConfig configuration,
@@ -636,7 +665,8 @@ CompletableFuture<TimedAlgorithmResult<PregelResult>> sllpa(
636665
) {
637666

638667
if (graph.isEmpty()) {
639-
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(ImmutablePregelResult.of(null, 0, false)));
668+
var empty = NodeValue.of(new PregelSchema.Builder().build(),0,new Concurrency(1));
669+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(ImmutablePregelResult.of(empty, 0, false)));
640670
}
641671

642672
var progressTracker = progressTrackerFactory.create(

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeEmptyGraphTest.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@
4141
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
4242
import org.neo4j.gds.kcore.KCoreDecompositionResult;
4343
import org.neo4j.gds.kmeans.KmeansParameters;
44-
import org.neo4j.gds.labelpropagation.LabelPropagationParameters;
4544
import org.neo4j.gds.labelpropagation.LabelPropagationResult;
46-
import org.neo4j.gds.leiden.LeidenParameters;
4745
import org.neo4j.gds.leiden.LeidenResult;
4846
import org.neo4j.gds.louvain.LouvainParameters;
4947
import org.neo4j.gds.louvain.LouvainResult;
@@ -215,7 +213,7 @@ void labelPropagation(){
215213

216214
var future = facade.labelPropagation(
217215
graph,
218-
mock(LabelPropagationParameters.class),
216+
null,
219217
jobIdMock,
220218
false
221219
);
@@ -249,7 +247,7 @@ void leiden(){
249247

250248
var future = facade.leiden(
251249
graph,
252-
mock(LeidenParameters.class),
250+
null,
253251
jobIdMock,
254252
false
255253
);
@@ -345,6 +343,21 @@ void triangleCount(){
345343
verifyNoInteractions(algorithmCallerMock);
346344
}
347345

346+
@Test
347+
void triangles() {
348+
var future = facade.triangles(
349+
graph,
350+
mock(TriangleCountParameters.class),
351+
jobIdMock
352+
);
353+
var result = future.join();
354+
assertThat(result.result()).isEmpty();
355+
356+
verifyNoInteractions(progressTrackerFactoryMock);
357+
verifyNoInteractions(algorithmCallerMock);
358+
}
359+
360+
348361
@Test
349362
void wcc(){
350363

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeTest.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.mockito.junit.jupiter.MockitoExtension;
2727
import org.neo4j.gds.Orientation;
2828
import org.neo4j.gds.ProgressTrackerFactory;
29+
import org.neo4j.gds.api.Graph;
2930
import org.neo4j.gds.approxmaxkcut.ApproxMaxKCutParameters;
3031
import org.neo4j.gds.async.AsyncAlgorithmCaller;
3132
import org.neo4j.gds.cliquecounting.CliqueCountingMode;
@@ -38,7 +39,6 @@
3839
import org.neo4j.gds.extension.GdlGraph;
3940
import org.neo4j.gds.extension.IdFunction;
4041
import org.neo4j.gds.extension.Inject;
41-
import org.neo4j.gds.extension.TestGraph;
4242
import org.neo4j.gds.hdbscan.HDBScanParameters;
4343
import org.neo4j.gds.k1coloring.K1ColoringParameters;
4444
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
@@ -89,7 +89,7 @@ class CommunityComputeFacadeTest {
8989
""";
9090

9191
@Inject
92-
private TestGraph graph;
92+
private Graph graph;
9393

9494
@Inject
9595
private IdFunction idFunction;
@@ -120,7 +120,7 @@ void maxKCut() {
120120
new Concurrency(4),
121121
10_000,
122122
Optional.empty(),
123-
List.of(),
123+
List.of(0L,0L),
124124
false,
125125
false
126126
),
@@ -180,9 +180,9 @@ void hdbscan(){
180180
graph,
181181
new HDBScanParameters(
182182
new Concurrency(4),
183-
1,
183+
2,
184184
3,
185-
1,
185+
2,
186186
"prop2"
187187
),
188188
jobIdMock,
@@ -416,13 +416,33 @@ void triangleCount(){
416416
assertThat(results.computeMillis()).isNotNegative();
417417
}
418418

419+
@Test
420+
void triangles() {
421+
var future = facade.triangles(
422+
graph,
423+
new TriangleCountParameters(new Concurrency(4), 100,List.of()),
424+
jobIdMock
425+
);
426+
427+
var results = future.join();
428+
long a = idFunction.of("a");
429+
long b = idFunction.of("b");
430+
long c = idFunction.of("c");
431+
432+
assertThat(results.result()).isNotEmpty()
433+
.anySatisfy(r -> {
434+
long[] triangleArray = new long[]{r.nodeA,r.nodeB,r.nodeC};
435+
assertThat(triangleArray).containsExactlyInAnyOrder(a,b,c);
436+
});
437+
}
438+
419439
@Test
420440
void wcc(){
421441
var future = facade.wcc(
422442
graph,
423443
new WccParameters(
424444
0,
425-
null,
445+
Optional.empty(),
426446
new Concurrency(4)
427447
),
428448
jobIdMock,

pregel/src/main/java/org/neo4j/gds/beta/pregel/NodeValue.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
import org.jetbrains.annotations.Nullable;
2323
import org.neo4j.gds.api.DefaultValue;
2424
import org.neo4j.gds.api.nodeproperties.ValueType;
25-
import org.neo4j.gds.core.concurrency.Concurrency;
26-
import org.neo4j.gds.core.concurrency.ParallelUtil;
27-
import org.neo4j.gds.termination.TerminationFlag;
28-
import org.neo4j.gds.mem.MemoryEstimation;
29-
import org.neo4j.gds.mem.MemoryEstimations;
3025
import org.neo4j.gds.collections.ha.HugeDoubleArray;
3126
import org.neo4j.gds.collections.ha.HugeLongArray;
3227
import org.neo4j.gds.collections.ha.HugeObjectArray;
28+
import org.neo4j.gds.core.concurrency.Concurrency;
29+
import org.neo4j.gds.core.concurrency.ParallelUtil;
3330
import org.neo4j.gds.mem.Estimate;
31+
import org.neo4j.gds.mem.MemoryEstimation;
32+
import org.neo4j.gds.mem.MemoryEstimations;
33+
import org.neo4j.gds.termination.TerminationFlag;
3434
import org.neo4j.gds.utils.StringFormatting;
3535
import org.neo4j.gds.utils.StringJoining;
3636
import org.neo4j.gds.values.FloatingPointValue;
@@ -57,7 +57,7 @@ public abstract class NodeValue {
5757
.collect(Collectors.toMap(Element::propertyKey, Element::propertyType));
5858
}
5959

60-
static NodeValue of(PregelSchema schema, long nodeCount, Concurrency concurrency) {
60+
public static NodeValue of(PregelSchema schema, long nodeCount, Concurrency concurrency) {
6161
var properties = schema.elements()
6262
.stream()
6363
.collect(Collectors.toMap(

0 commit comments

Comments
 (0)