Skip to content

Commit c449600

Browse files
Never fail when no localMoves are ever done
Co-authored-by: Veselin Nikolov <veselin.nikolov@neotechnology.com>
1 parent 80016c1 commit c449600

File tree

2 files changed

+40
-10
lines changed

2 files changed

+40
-10
lines changed

algo/src/main/java/org/neo4j/gds/leiden/Leiden.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ public LeidenResult compute() {
133133
gamma,
134134
concurrency
135135
);
136-
var communitiesCount = localMovePhase.run();
137-
boolean localPhaseConverged = communitiesCount == workingGraph.nodeCount() || localMovePhase.swaps == 0;
138136

137+
localMovePhase.run();
138+
//if you do swaps, no convergence
139+
boolean localPhaseConverged = localMovePhase.swaps == 0;
139140
progressTracker.endSubTask("Local Move");
140141

141142
progressTracker.beginSubTask("Modularity Computation");
@@ -244,11 +245,11 @@ public LeidenResult compute() {
244245

245246
@NotNull
246247
private LeidenResult getLeidenResult(boolean didConverge, int iteration) {
247-
boolean seedIsOptimal = didConverge && seedValues.isPresent() && iteration == 0;
248-
if (seedIsOptimal) {
248+
boolean stoppedAtFirstIteration = didConverge && iteration == 0;
249+
if (stoppedAtFirstIteration) {
249250
var modularity = modularities[0];
250251
return LeidenResult.of(
251-
LeidenUtils.createSeedCommunities(rootGraph.nodeCount(), seedValues.orElse(null)),
252+
LeidenUtils.createStartingCommunities(rootGraph.nodeCount(), seedValues.orElse(null)),
252253
1,
253254
didConverge,
254255
null,
@@ -266,8 +267,8 @@ private LeidenResult getLeidenResult(boolean didConverge, int iteration) {
266267
);
267268
}
268269
}
269-
270-
private boolean updateModularity(
270+
271+
private void updateModularity(
271272
Graph workingGraph,
272273
HugeLongArray localMoveCommunities,
273274
HugeDoubleArray localMoveCommunityVolumes,
@@ -276,8 +277,10 @@ private boolean updateModularity(
276277
boolean localPhaseConverged,
277278
int iteration
278279
) {
279-
boolean seedIsOptimal = localPhaseConverged && seedValues.isPresent() && iteration == 0;
280-
boolean shouldCalculateModularity = !localPhaseConverged || seedIsOptimal;
280+
//will calculate modularity only if:
281+
// - the local phase has not converged (i.e., no swaps done)
282+
//- or we terminate in the first iteration (i.e., given seeding is optimal, graph is empty etc)
283+
boolean shouldCalculateModularity = !localPhaseConverged || iteration == 0;
281284

282285
if (shouldCalculateModularity) {
283286
modularities[iteration] = ModularityComputer.compute(
@@ -291,7 +294,6 @@ private boolean updateModularity(
291294
progressTracker
292295
);
293296
}
294-
return seedIsOptimal;
295297
}
296298

297299
private double initVolumes(

algo/src/test/java/org/neo4j/gds/leiden/LeidenTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.neo4j.gds.extension.IdFunction;
4242
import org.neo4j.gds.extension.Inject;
4343
import org.neo4j.gds.extension.TestGraph;
44+
import org.neo4j.gds.gdl.GdlFactory;
4445

4546
import java.util.stream.Collectors;
4647
import java.util.stream.LongStream;
@@ -154,6 +155,8 @@ void shouldWorkWithBestSeed() {
154155
community -> assertThat(community).containsExactlyInAnyOrder("a1", "a5", "a6", "a7")
155156
);
156157
assertThat(communitiesMap.keySet()).containsExactly(4000L, 5000L);
158+
159+
assertThat(leidenResult.modularity()).isGreaterThan(0);
157160
}
158161

159162
@Test
@@ -364,4 +367,29 @@ void shouldLogProgress() {
364367
"Leiden :: Finished"
365368
);
366369
}
370+
371+
@Test
372+
void shouldWorkIfStopAtFirstIteration() {
373+
374+
var query = " CREATE\n" +
375+
" (c:C {id: 0}) \n" +
376+
" , (d:D {id: 1}) ";
377+
var empty = GdlFactory.of(query).build().getUnion();
378+
var result = new Leiden(
379+
empty,
380+
5,
381+
1,
382+
0.01,
383+
false,
384+
42,
385+
null,
386+
0.1,
387+
1,
388+
ProgressTracker.NULL_TRACKER
389+
).compute();
390+
var communities = result.communities();
391+
assertThat(communities.toArray()).isEqualTo(new long[]{0, 1});
392+
393+
394+
}
367395
}

0 commit comments

Comments
 (0)