Skip to content

Commit 805f358

Browse files
Transfer smallest update to correct place :see_no_evil
1 parent d26b50e commit 805f358

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaStepping.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ private boolean ensureShortest(
198198
.min()
199199
.orElseThrow();
200200
//return true if the closet terminal is at least as close as the closest next node
201+
System.out.println("hi" + distance + " " + currentMinDistance + " " + (distance <= currentMinDistance) + " " + oldBin + " " + currentBin);
201202
return distance <= currentMinDistance;
202203
} else {
203204
return (distance < currentBin * delta);

algo/src/main/java/org/neo4j/gds/steiner/SteinerBasedDeltaTask.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ private void relaxNode(long nodeId) {
138138
graph.forEachRelationship(nodeId, 1.0, (sourceNodeId, targetNodeId, weight) -> {
139139
if (!mergedToSource.get(targetNodeId)) { //ignore merged vertices
140140
tryToUpdate(sourceNodeId, targetNodeId, weight);
141-
smallestConsideredDistance = Math.min(weight, smallestConsideredDistance);
142141
}
143142
return true;
144143
});
@@ -165,6 +164,8 @@ private void tryToUpdate(long sourceNodeId, long targetNodeId,double weight) {
165164
// CAX failed, retry
166165
oldDist = witness;
167166
}
167+
smallestConsideredDistance = Math.min(newDist, smallestConsideredDistance);
168+
168169
if (uninvisitedTerminal.get(targetNodeId)) {
169170

170171
terminalQueueLock.lock();

algo/src/test/java/org/neo4j/gds/steiner/ShortestPathSteinerAlgorithmExtendedTest.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import com.carrotsearch.hppc.BitSet;
2323
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.ValueSource;
2426
import org.neo4j.gds.Orientation;
2527
import org.neo4j.gds.core.concurrency.Pools;
2628
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -138,15 +140,19 @@ class ShortestPathSteinerAlgorithmExtendedTest {
138140
@Inject
139141
private IdFunction triangleIdFunction;
140142

141-
@Test
142-
void shouldWorkCorrectly() {
143+
@ParameterizedTest
144+
@ValueSource(doubles = {2.0, 5.0, 100.0})
145+
//2.0 is the standard one we use everything
146+
//5.0 is one such that we can deduce a2 is best without changing bucket
147+
//100.0 is a big one where everything takes place inside a single bucket
148+
void shouldWorkCorrectly(double delta) {
143149

144150
var a = SteinerTestUtils.getNodes(idFunction, 6);
145151
var steinerTreeResult = new ShortestPathsSteinerAlgorithm(
146152
graph,
147153
a[0],
148154
List.of(a[2], a[5]),
149-
2.0,
155+
delta,
150156
1,
151157
false,
152158
Pools.DEFAULT

0 commit comments

Comments
 (0)