Skip to content

Commit 265a3db

Browse files
committed
migrate estimation cli kmeans to application layer
1 parent c2161ea commit 265a3db

File tree

9 files changed

+18
-356
lines changed

9 files changed

+18
-356
lines changed

algo/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ dependencies {
7979
testImplementation openGds.mockito.junit.jupiter
8080

8181
testImplementation project(':centrality-algorithms')
82+
testImplementation project(':community-algorithms')
8283
testImplementation project(':node-embedding-algorithms')
8384
testImplementation project(':path-finding-algorithms')
8485
}

algo/src/main/java/org/neo4j/gds/kmeans/Kmeans.java

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,6 @@
3939
import java.util.concurrent.ExecutorService;
4040

4141
public final class Kmeans extends Algorithm<KmeansResult> {
42-
43-
static final String KMEANS_DESCRIPTION =
44-
"The Kmeans algorithm clusters nodes into different communities based on Euclidean distance";
4542
private static final int UNASSIGNED = -1;
4643
private HugeIntArray bestCommunities;
4744
private final Graph graph;
@@ -166,7 +163,7 @@ private void kMeans(
166163
int restartIteration
167164
) {
168165

169-
//note: currentDistanceFromCentroid is not reset to a [0,...,0] distance array but it does not have to
166+
//note: currentDistanceFromCentroid is not reset to a [0,...,0] distance array, but it does not have to
170167
// it's used only in K-Means++ (where it is essentially reset; see func distanceFromLastSampledCentroid in KmeansTask)
171168
// or during final distance calculation where it is reset as well (see calculateFinalDistance in KmeansTask)
172169

@@ -242,7 +239,7 @@ private void kMeans(
242239
}
243240
progressTracker.endSubTask(); // Main - end
244241

245-
double averageDistanceFromCentroid = calculatedistancePhase(tasks);
242+
double averageDistanceFromCentroid = calculateDistancePhase(tasks);
246243
updateBestSolution(
247244
restartIteration,
248245
clusterManager,
@@ -262,7 +259,7 @@ private void initializeCentroids(ClusterManager clusterManager, KmeansSampler sa
262259
progressTracker.endSubTask(); // Initialization - end
263260
}
264261

265-
private void recomputeCentroids(ClusterManager clusterManager, List<KmeansTask> tasks) {
262+
private void recomputeCentroids(ClusterManager clusterManager, Iterable<KmeansTask> tasks) {
266263
clusterManager.reset();
267264

268265
for (KmeansTask task : tasks) {
@@ -367,7 +364,7 @@ private void calculateSilhouette() {
367364

368365
}
369366

370-
private double calculatedistancePhase(List<KmeansTask> tasks) {
367+
private double calculateDistancePhase(Iterable<KmeansTask> tasks) {
371368
for (KmeansTask task : tasks) {
372369
task.switchToPhase(TaskPhase.DISTANCE);
373370
}

algo/src/main/java/org/neo4j/gds/kmeans/KmeansAlgorithmFactory.java

Lines changed: 0 additions & 101 deletions
This file was deleted.

algo/src/test/java/org/neo4j/gds/kmeans/KmeansTest.java

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
import org.junit.jupiter.params.ParameterizedTest;
2525
import org.junit.jupiter.params.provider.ValueSource;
2626
import org.neo4j.gds.api.Graph;
27-
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
27+
import org.neo4j.gds.applications.algorithms.community.CommunityAlgorithms;
2828
import org.neo4j.gds.extension.GdlExtension;
2929
import org.neo4j.gds.extension.GdlGraph;
3030
import org.neo4j.gds.extension.IdFunction;
@@ -265,6 +265,8 @@ void shouldComputeSilhouetteCorrectly() {
265265

266266
@Test
267267
void shouldNotWorkForRestartsAndSeeds() {
268+
var communityAlgorithms = new CommunityAlgorithms(null, null);
269+
268270
var kmeansConfig = KmeansStreamConfigImpl.builder()
269271
.nodeProperty("kmeans")
270272
.concurrency(1)
@@ -273,31 +275,23 @@ void shouldNotWorkForRestartsAndSeeds() {
273275
.k(2)
274276
.numberOfRestarts(10)
275277
.build();
276-
277-
var kmeansAlgorithmFactory = new KmeansAlgorithmFactory<>();
278-
assertThatThrownBy(() -> kmeansAlgorithmFactory.build(
279-
lineGraph,
280-
kmeansConfig,
281-
ProgressTracker.NULL_TRACKER
282-
)).hasMessageContaining("cannot be run");
278+
assertThatThrownBy(() -> communityAlgorithms.kMeans(lineGraph, kmeansConfig))
279+
.hasMessageContaining("cannot be run");
283280
}
284281

285282
@Test
286283
void shouldNotWorkForDifferentSeedAndK() {
284+
var communityAlgorithms = new CommunityAlgorithms(null, null);
285+
287286
var kmeansConfig = KmeansStreamConfigImpl.builder()
288287
.nodeProperty("kmeans")
289288
.concurrency(1)
290289
.randomSeed(19L)
291290
.seedCentroids(List.of(List.of(1d)))
292291
.k(2)
293292
.build();
294-
295-
var kmeansAlgorithmFactory = new KmeansAlgorithmFactory<>();
296-
assertThatThrownBy(() -> kmeansAlgorithmFactory.build(
297-
lineGraph,
298-
kmeansConfig,
299-
ProgressTracker.NULL_TRACKER
300-
)).hasMessageContaining("Incorrect");
293+
assertThatThrownBy(() -> communityAlgorithms.kMeans(lineGraph, kmeansConfig))
294+
.hasMessageContaining("Incorrect");
301295
}
302296

303297
@Test

algorithm-specifications/src/main/java/org/neo4j/gds/kmeans/KmeansMutateSpec.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

algorithm-specifications/src/main/java/org/neo4j/gds/kmeans/KmeansStatsSpec.java

Lines changed: 0 additions & 58 deletions
This file was deleted.

algorithm-specifications/src/main/java/org/neo4j/gds/kmeans/KmeansStreamSpec.java

Lines changed: 0 additions & 56 deletions
This file was deleted.

0 commit comments

Comments
 (0)