Skip to content

Commit 871609a

Browse files
Add kmeans in compute facade
1 parent 3b6af25 commit 871609a

File tree

4 files changed

+100
-2
lines changed

4 files changed

+100
-2
lines changed

algo/src/main/java/org/neo4j/gds/kmeans/KmeansResult.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,16 @@ public record KmeansResult(
3030
double averageDistanceToCentroid,
3131
@Nullable HugeDoubleArray silhouette,
3232
double averageSilhouette
33-
) {}
33+
) {
34+
public static KmeansResult empty(int k) {
35+
return new KmeansResult(
36+
HugeIntArray.newArray(0),
37+
HugeDoubleArray.newArray(0),
38+
new double[k][0],
39+
0,
40+
null,
41+
0
42+
);
43+
}
44+
45+
}

algorithms-compute-facade/src/main/java/org/neo4j/gds/community/CommunityComputeFacade.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
import org.neo4j.gds.kcore.KCoreDecomposition;
4444
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
4545
import org.neo4j.gds.kcore.KCoreDecompositionResult;
46+
import org.neo4j.gds.kmeans.Kmeans;
47+
import org.neo4j.gds.kmeans.KmeansContext;
48+
import org.neo4j.gds.kmeans.KmeansParameters;
49+
import org.neo4j.gds.kmeans.KmeansResult;
4650
import org.neo4j.gds.result.TimedAlgorithmResult;
4751
import org.neo4j.gds.termination.TerminationFlag;
4852

@@ -265,4 +269,36 @@ CompletableFuture<TimedAlgorithmResult<KCoreDecompositionResult>> kCore(
265269
jobId
266270
);
267271
}
272+
273+
CompletableFuture<TimedAlgorithmResult<KmeansResult>> kMeans(
274+
Graph graph,
275+
KmeansParameters parameters,
276+
JobId jobId,
277+
boolean logProgress
278+
) {
279+
280+
if (graph.isEmpty()) {
281+
return CompletableFuture.completedFuture(TimedAlgorithmResult.empty(KmeansResult.empty(parameters.k())));
282+
}
283+
284+
var progressTracker = progressTrackerFactory.create(
285+
tasks.kMeans(graph,parameters),
286+
jobId,
287+
parameters.concurrency(),
288+
logProgress
289+
);
290+
291+
var algorithm = Kmeans.createKmeans(
292+
graph,
293+
parameters,
294+
new KmeansContext(DefaultPool.INSTANCE, progressTracker),
295+
terminationFlag
296+
);
297+
298+
return algorithmCaller.run(
299+
algorithm::compute,
300+
jobId
301+
);
302+
}
303+
268304
}

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeEmptyGraphTest.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.neo4j.gds.k1coloring.K1ColoringResult;
4141
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
4242
import org.neo4j.gds.kcore.KCoreDecompositionResult;
43+
import org.neo4j.gds.kmeans.KmeansParameters;
4344
import org.neo4j.gds.termination.TerminationFlag;
4445

4546
import static org.assertj.core.api.Assertions.assertThat;
@@ -171,4 +172,25 @@ void kCore(){
171172
verifyNoInteractions(algorithmCallerMock);
172173
}
173174

175+
@Test
176+
void kMeans(){
177+
var params = mock(KmeansParameters.class);
178+
when(params.k()).thenReturn(3);
179+
180+
var future = facade.kMeans(
181+
graph,
182+
params,
183+
jobIdMock,
184+
false
185+
);
186+
187+
var results = future.join();
188+
189+
assertThat(results.result().communities().toArray()).hasSize(0);
190+
assertThat(results.result().centers()).hasDimensions(3,0);
191+
192+
verifyNoInteractions(progressTrackerFactoryMock);
193+
verifyNoInteractions(algorithmCallerMock);
194+
}
195+
174196
}

algorithms-compute-facade/src/test/java/org/neo4j/gds/community/CommunityComputeFacadeTest.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
import org.neo4j.gds.hdbscan.HDBScanParameters;
4343
import org.neo4j.gds.k1coloring.K1ColoringParameters;
4444
import org.neo4j.gds.kcore.KCoreDecompositionParameters;
45+
import org.neo4j.gds.kmeans.KmeansParameters;
46+
import org.neo4j.gds.kmeans.SamplerType;
4547
import org.neo4j.gds.logging.Log;
4648
import org.neo4j.gds.termination.TerminationFlag;
4749

@@ -171,7 +173,7 @@ void hdbscan(){
171173
1,
172174
3,
173175
1,
174-
"prop"
176+
"prop2"
175177
),
176178
jobIdMock,
177179
false
@@ -219,4 +221,30 @@ void kCore(){
219221
assertThat(results.computeMillis()).isNotNegative();
220222
}
221223

224+
@Test
225+
void kMeans(){
226+
var future = facade.kMeans(
227+
graph,
228+
new KmeansParameters(
229+
3,
230+
10,
231+
0.5,
232+
1,
233+
false,
234+
new Concurrency(4),
235+
"prop2",
236+
SamplerType.UNIFORM,
237+
List.of(),
238+
Optional.empty()
239+
),
240+
jobIdMock,
241+
false
242+
);
243+
244+
var results = future.join();
245+
246+
assertThat(results.result().communities().toArray()).hasSize(3);
247+
assertThat(results.computeMillis()).isNotNegative();
248+
}
249+
222250
}

0 commit comments

Comments
 (0)