Skip to content

Commit d3fd6fa

Browse files
authored
Merge pull request #10890 from IoannisPanagiotas/mem-issues-my-fam
Remove memory from memoryTracker for incorrect project config
2 parents ce29a28 + f121165 commit d3fd6fa

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

applications/graph-store-catalog/src/main/java/org/neo4j/gds/applications/graphstorecatalog/MemoryUsageValidator.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,19 @@ public class MemoryUsageValidator {
4242
private final MemoryTracker memoryTracker;
4343
private final String username;
4444

45-
public MemoryUsageValidator(String username,MemoryTracker memoryTracker, boolean useMaxMemoryEstimation, Log log) {
45+
public MemoryUsageValidator(
46+
String username,
47+
MemoryTracker memoryTracker,
48+
boolean useMaxMemoryEstimation,
49+
Log log
50+
) {
4651
this.log = log;
4752
this.useMaxMemoryEstimation = useMaxMemoryEstimation;
4853
this.memoryTracker = memoryTracker;
4954
this.username = username;
5055
}
5156

52-
public <C extends BaseConfig & JobIdConfig> MemoryRange tryValidateMemoryUsage(
57+
public synchronized <C extends BaseConfig & JobIdConfig> MemoryRange tryValidateMemoryUsage(
5358
String taskName,
5459
C config,
5560
Function<C, MemoryTreeWithDimensions> runEstimation

native-projection/src/main/java/org/neo4j/gds/projection/NativeFactory.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141

4242
import java.util.Optional;
4343

44-
import static org.neo4j.gds.projection.GraphDimensionsValidation.validate;
45-
4644
final class NativeFactory extends CSRGraphStoreFactory<GraphProjectFromStoreConfig> {
4745

4846
private final GraphProjectFromStoreConfig storeConfig;
@@ -64,18 +62,21 @@ static NativeFactory nativeFactory(
6462
return new NativeFactory(
6563
graphProjectFromStoreConfig,
6664
loadingContext,
67-
dimensions
65+
dimensions,
66+
initProgressTracker(graphProjectFromStoreConfig, loadingContext, dimensions)
6867
);
6968
}
7069

71-
private NativeFactory(
70+
// Package-private constructor used in tests
71+
NativeFactory(
7272
GraphProjectFromStoreConfig graphProjectConfig,
7373
GraphLoaderContext loadingContext,
74-
GraphDimensions graphDimensions
74+
GraphDimensions graphDimensions,
75+
ProgressTracker progressTracker
7576
) {
7677
super(graphProjectConfig, ImmutableStaticCapabilities.of(WriteMode.LOCAL), loadingContext, graphDimensions);
7778
this.storeConfig = graphProjectConfig;
78-
this.progressTracker = initProgressTracker();
79+
this.progressTracker = progressTracker;
7980
}
8081

8182
@Override
@@ -88,7 +89,10 @@ public MemoryEstimation estimateMemoryUsageAfterLoading() {
8889
return getMemoryEstimation(storeConfig.nodeProjections(), storeConfig.relationshipProjections(), false);
8990
}
9091

91-
private ProgressTracker initProgressTracker() {
92+
private static ProgressTracker initProgressTracker(
93+
GraphProjectFromStoreConfig graphProjectConfig,
94+
GraphLoaderContext loadingContext, GraphDimensions dimensions
95+
) {
9296
long relationshipCount = graphProjectConfig
9397
.relationshipProjections()
9498
.projections()
@@ -133,11 +137,12 @@ private ProgressTracker initProgressTracker() {
133137

134138
@Override
135139
public CSRGraphStore build() {
136-
validate(dimensions, storeConfig);
137-
138-
var concurrency = graphProjectConfig.readConcurrency();
139140
try {
141+
// Start the sub-task so it will be registered in the task store
140142
progressTracker.beginSubTask();
143+
// `validate` can raise an exception which has to be handled in order to end the sub-task with failure
144+
validate();
145+
var concurrency = graphProjectConfig.readConcurrency();
141146
Nodes nodes = loadNodes(concurrency);
142147
RelationshipImportResult relationships = loadRelationships(nodes.idMap(), concurrency);
143148
CSRGraphStore graphStore = createGraphStore(nodes, relationships);
@@ -152,6 +157,11 @@ public CSRGraphStore build() {
152157
}
153158
}
154159

160+
// Allows for mocking the validation behaviour.
161+
void validate() throws IllegalArgumentException {
162+
GraphDimensionsValidation.validate(dimensions, storeConfig);
163+
}
164+
155165
private Nodes loadNodes(Concurrency concurrency) {
156166
var scanningNodesImporter = new ScanningNodesImporterBuilder()
157167
.concurrency(concurrency)

native-projection/src/test/java/org/neo4j/gds/projection/NativeFactoryTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,24 @@
2727
import org.neo4j.gds.RelationshipProjections;
2828
import org.neo4j.gds.RelationshipType;
2929
import org.neo4j.gds.api.CSRGraphStoreFactory;
30+
import org.neo4j.gds.api.GraphLoaderContext;
3031
import org.neo4j.gds.core.GraphDimensions;
3132
import org.neo4j.gds.core.ImmutableGraphDimensions;
3233
import org.neo4j.gds.core.concurrency.Concurrency;
34+
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3335
import org.neo4j.gds.mem.MemoryEstimation;
3436
import org.neo4j.gds.mem.MemoryTree;
3537

3638
import java.util.concurrent.atomic.AtomicReference;
3739

40+
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
3841
import static org.junit.jupiter.api.Assertions.assertEquals;
42+
import static org.mockito.Mockito.doThrow;
43+
import static org.mockito.Mockito.mock;
44+
import static org.mockito.Mockito.spy;
45+
import static org.mockito.Mockito.times;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.verifyNoMoreInteractions;
3948

4049
class NativeFactoryTest {
4150

@@ -142,4 +151,26 @@ void memoryEstimationForMultipleProjections() {
142151
assertEquals(12_056_534_400L, estimate.memoryUsage().min);
143152
assertEquals(13_667_147_136L, estimate.memoryUsage().max);
144153
}
154+
155+
@Test
156+
void shouldCleanTaskOnValidateFailure() {
157+
var progressTrackerMock = mock(ProgressTracker.class);
158+
159+
var nativeFactorySpy = spy(new NativeFactory(
160+
mock(GraphProjectFromStoreConfig.class),
161+
mock(GraphLoaderContext.class),
162+
mock(GraphDimensions.class),
163+
progressTrackerMock
164+
));
165+
166+
doThrow(new IllegalArgumentException("Intentionally failing validation")).when(nativeFactorySpy).validate();
167+
168+
assertThatIllegalArgumentException()
169+
.isThrownBy(nativeFactorySpy::build)
170+
.withMessageContaining("Intentionally failing validation");
171+
172+
verify(progressTrackerMock, times(1)).beginSubTask();
173+
verify(progressTrackerMock, times(1)).endSubTaskWithFailure();
174+
verifyNoMoreInteractions(progressTrackerMock);
175+
}
145176
}

0 commit comments

Comments
 (0)