Skip to content

Commit 883ee09

Browse files
Make sure that memory is cleaned following bad validation
Co-authored-by: Veselin Nikolov <nickolov.vesselin@gmail.com>
1 parent 4edb267 commit 883ee09

File tree

3 files changed

+90
-4
lines changed

3 files changed

+90
-4
lines changed

applications/graph-store-catalog/src/main/java/org/neo4j/gds/applications/graphstorecatalog/MemoryUsageValidator.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ public class MemoryUsageValidator {
4242
private final MemoryTracker memoryTracker;
4343
private final String username;
4444

45-
public MemoryUsageValidator(String username,MemoryTracker memoryTracker, boolean useMaxMemoryEstimation, Log log) {
45+
public MemoryUsageValidator(
46+
String username,
47+
MemoryTracker memoryTracker,
48+
boolean useMaxMemoryEstimation,
49+
Log log
50+
) {
4651
this.log = log;
4752
this.useMaxMemoryEstimation = useMaxMemoryEstimation;
4853
this.memoryTracker = memoryTracker;

native-projection/src/main/java/org/neo4j/gds/projection/NativeFactory.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,12 @@ private ProgressTracker initProgressTracker() {
133133

134134
@Override
135135
public CSRGraphStore build() {
136-
validate(dimensions, storeConfig);
137-
138-
var concurrency = graphProjectConfig.readConcurrency();
139136
try {
137+
// Start the sub-task so it will be registered in the task store
140138
progressTracker.beginSubTask();
139+
// `validate` can raise an exception which has to be handled in order to end the sub-task with failure
140+
validate(dimensions, storeConfig);
141+
var concurrency = graphProjectConfig.readConcurrency();
141142
Nodes nodes = loadNodes(concurrency);
142143
RelationshipImportResult relationships = loadRelationships(nodes.idMap(), concurrency);
143144
CSRGraphStore graphStore = createGraphStore(nodes, relationships);

native-projection/src/test/java/org/neo4j/gds/projection/NativeFactoryTest.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,40 @@
1919
*/
2020
package org.neo4j.gds.projection;
2121

22+
import com.carrotsearch.hppc.LongHashSet;
23+
import org.agrona.collections.MutableBoolean;
2224
import org.junit.jupiter.api.Test;
2325
import org.neo4j.gds.ImmutableRelationshipProjections;
26+
import org.neo4j.gds.NodeLabel;
27+
import org.neo4j.gds.NodeProjection;
2428
import org.neo4j.gds.NodeProjections;
2529
import org.neo4j.gds.Orientation;
2630
import org.neo4j.gds.RelationshipProjection;
2731
import org.neo4j.gds.RelationshipProjections;
2832
import org.neo4j.gds.RelationshipType;
2933
import org.neo4j.gds.api.CSRGraphStoreFactory;
34+
import org.neo4j.gds.api.GraphLoaderContext;
3035
import org.neo4j.gds.core.GraphDimensions;
3136
import org.neo4j.gds.core.ImmutableGraphDimensions;
3237
import org.neo4j.gds.core.concurrency.Concurrency;
38+
import org.neo4j.gds.core.utils.progress.JobId;
39+
import org.neo4j.gds.core.utils.progress.TaskRegistry;
40+
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
41+
import org.neo4j.gds.logging.Log;
3342
import org.neo4j.gds.mem.MemoryEstimation;
3443
import org.neo4j.gds.mem.MemoryTree;
3544

45+
import java.util.Map;
46+
import java.util.Optional;
3647
import java.util.concurrent.atomic.AtomicReference;
3748

49+
import static org.assertj.core.api.Assertions.assertThat;
3850
import static org.junit.jupiter.api.Assertions.assertEquals;
51+
import static org.mockito.ArgumentMatchers.any;
52+
import static org.mockito.Mockito.mock;
53+
import static org.mockito.Mockito.times;
54+
import static org.mockito.Mockito.verify;
55+
import static org.mockito.Mockito.when;
3956

4057
class NativeFactoryTest {
4158

@@ -142,4 +159,67 @@ void memoryEstimationForMultipleProjections() {
142159
assertEquals(12_056_534_400L, estimate.memoryUsage().min);
143160
assertEquals(13_667_147_136L, estimate.memoryUsage().max);
144161
}
162+
163+
@Test
164+
void shouldCleanTaskOnValidateFailure(){
165+
166+
var mockRegistry = mock(TaskRegistry.class);
167+
168+
var mockRegistryFactory = mock(TaskRegistryFactory.class);
169+
when(mockRegistryFactory.newInstance(any())).thenReturn(mockRegistry);
170+
var mockgraphLoaderContext = mock(GraphLoaderContext.class);
171+
when(mockgraphLoaderContext.taskRegistryFactory()).thenReturn(mockRegistryFactory);
172+
when(mockgraphLoaderContext.log()).thenReturn(Log.noOpLog());
173+
174+
var labelSet = new LongHashSet();
175+
labelSet.add(GraphDimensions.NO_SUCH_LABEL);
176+
var graphDimensions =ImmutableGraphDimensions.builder()
177+
.nodeCount(100)
178+
.nodeLabelTokens(labelSet)
179+
.build();
180+
181+
var config = mock(GraphProjectFromStoreConfig.class);
182+
when(config.relationshipProjections()).thenReturn(new RelationshipProjections() {
183+
@Override
184+
public Map<RelationshipType, RelationshipProjection> projections() {
185+
return Map.of();
186+
}
187+
});
188+
189+
when(config.nodeProjections()).thenReturn(new NodeProjections() {
190+
@Override
191+
public Map<NodeLabel, NodeProjection> projections() {
192+
return Map.of(
193+
NodeLabel.of("foo"), new NodeProjection() {
194+
@Override
195+
public String label() {
196+
return "bar";
197+
}
198+
}
199+
);
200+
}
201+
});
202+
203+
when(config.logProgress()).thenReturn(true);
204+
when(config.jobId()).thenReturn(new JobId("foo"));
205+
when(config.readConcurrency()).thenReturn(new Concurrency(1));
206+
207+
var nativeFactory = NativeFactory.nativeFactory(
208+
config,
209+
mockgraphLoaderContext,
210+
Optional.of(graphDimensions)
211+
);
212+
213+
MutableBoolean failed = new MutableBoolean(false);
214+
try {
215+
nativeFactory.build();
216+
} catch (Exception ignored){
217+
failed.set(true);
218+
}
219+
220+
assertThat(failed.get()).isTrue();
221+
verify(mockRegistry,times(1)).registerTask(any());
222+
verify(mockRegistry,times(1)).markCompleted();
223+
224+
}
145225
}

0 commit comments

Comments
 (0)