File tree Expand file tree Collapse file tree 4 files changed +100
-2
lines changed
algorithms-compute-facade/src
main/java/org/neo4j/gds/community
algo/src/main/java/org/neo4j/gds/kmeans Expand file tree Collapse file tree 4 files changed +100
-2
lines changed Original file line number Diff line number Diff line change @@ -30,4 +30,16 @@ public record KmeansResult(
3030 double averageDistanceToCentroid ,
3131 @ Nullable HugeDoubleArray silhouette ,
3232 double averageSilhouette
33- ) {}
33+ ) {
34+ public static KmeansResult empty (int k ) {
35+ return new KmeansResult (
36+ HugeIntArray .newArray (0 ),
37+ HugeDoubleArray .newArray (0 ),
38+ new double [k ][0 ],
39+ 0 ,
40+ null ,
41+ 0
42+ );
43+ }
44+
45+ }
Original file line number Diff line number Diff line change 4343import org .neo4j .gds .kcore .KCoreDecomposition ;
4444import org .neo4j .gds .kcore .KCoreDecompositionParameters ;
4545import org .neo4j .gds .kcore .KCoreDecompositionResult ;
46+ import org .neo4j .gds .kmeans .Kmeans ;
47+ import org .neo4j .gds .kmeans .KmeansContext ;
48+ import org .neo4j .gds .kmeans .KmeansParameters ;
49+ import org .neo4j .gds .kmeans .KmeansResult ;
4650import org .neo4j .gds .result .TimedAlgorithmResult ;
4751import org .neo4j .gds .termination .TerminationFlag ;
4852
@@ -265,4 +269,36 @@ CompletableFuture<TimedAlgorithmResult<KCoreDecompositionResult>> kCore(
265269 jobId
266270 );
267271 }
272+
273+ CompletableFuture <TimedAlgorithmResult <KmeansResult >> kMeans (
274+ Graph graph ,
275+ KmeansParameters parameters ,
276+ JobId jobId ,
277+ boolean logProgress
278+ ) {
279+
280+ if (graph .isEmpty ()) {
281+ return CompletableFuture .completedFuture (TimedAlgorithmResult .empty (KmeansResult .empty (parameters .k ())));
282+ }
283+
284+ var progressTracker = progressTrackerFactory .create (
285+ tasks .kMeans (graph ,parameters ),
286+ jobId ,
287+ parameters .concurrency (),
288+ logProgress
289+ );
290+
291+ var algorithm = Kmeans .createKmeans (
292+ graph ,
293+ parameters ,
294+ new KmeansContext (DefaultPool .INSTANCE , progressTracker ),
295+ terminationFlag
296+ );
297+
298+ return algorithmCaller .run (
299+ algorithm ::compute ,
300+ jobId
301+ );
302+ }
303+
268304}
Original file line number Diff line number Diff line change 4040import org .neo4j .gds .k1coloring .K1ColoringResult ;
4141import org .neo4j .gds .kcore .KCoreDecompositionParameters ;
4242import org .neo4j .gds .kcore .KCoreDecompositionResult ;
43+ import org .neo4j .gds .kmeans .KmeansParameters ;
4344import org .neo4j .gds .termination .TerminationFlag ;
4445
4546import static org .assertj .core .api .Assertions .assertThat ;
@@ -171,4 +172,25 @@ void kCore(){
171172 verifyNoInteractions (algorithmCallerMock );
172173 }
173174
175+ @ Test
176+ void kMeans (){
177+ var params = mock (KmeansParameters .class );
178+ when (params .k ()).thenReturn (3 );
179+
180+ var future = facade .kMeans (
181+ graph ,
182+ params ,
183+ jobIdMock ,
184+ false
185+ );
186+
187+ var results = future .join ();
188+
189+ assertThat (results .result ().communities ().toArray ()).hasSize (0 );
190+ assertThat (results .result ().centers ()).hasDimensions (3 ,0 );
191+
192+ verifyNoInteractions (progressTrackerFactoryMock );
193+ verifyNoInteractions (algorithmCallerMock );
194+ }
195+
174196}
Original file line number Diff line number Diff line change 4242import org .neo4j .gds .hdbscan .HDBScanParameters ;
4343import org .neo4j .gds .k1coloring .K1ColoringParameters ;
4444import org .neo4j .gds .kcore .KCoreDecompositionParameters ;
45+ import org .neo4j .gds .kmeans .KmeansParameters ;
46+ import org .neo4j .gds .kmeans .SamplerType ;
4547import org .neo4j .gds .logging .Log ;
4648import org .neo4j .gds .termination .TerminationFlag ;
4749
@@ -171,7 +173,7 @@ void hdbscan(){
171173 1 ,
172174 3 ,
173175 1 ,
174- "prop "
176+ "prop2 "
175177 ),
176178 jobIdMock ,
177179 false
@@ -219,4 +221,30 @@ void kCore(){
219221 assertThat (results .computeMillis ()).isNotNegative ();
220222 }
221223
224+ @ Test
225+ void kMeans (){
226+ var future = facade .kMeans (
227+ graph ,
228+ new KmeansParameters (
229+ 3 ,
230+ 10 ,
231+ 0.5 ,
232+ 1 ,
233+ false ,
234+ new Concurrency (4 ),
235+ "prop2" ,
236+ SamplerType .UNIFORM ,
237+ List .of (),
238+ Optional .empty ()
239+ ),
240+ jobIdMock ,
241+ false
242+ );
243+
244+ var results = future .join ();
245+
246+ assertThat (results .result ().communities ().toArray ()).hasSize (3 );
247+ assertThat (results .computeMillis ()).isNotNegative ();
248+ }
249+
222250}
You can’t perform that action at this time.
0 commit comments