1919 */
2020package org .neo4j .gds .mem ;
2121
22- import com .carrotsearch .hppc .ObjectLongHashMap ;
23- import com .carrotsearch .hppc .ObjectLongMap ;
24- import com .carrotsearch .hppc .procedures .LongProcedure ;
2522import org .neo4j .gds .api .graph .store .catalog .GraphStoreAddedEvent ;
2623import org .neo4j .gds .api .graph .store .catalog .GraphStoreAddedEventListener ;
2724import org .neo4j .gds .api .graph .store .catalog .GraphStoreRemovedEvent ;
3128import org .neo4j .gds .core .utils .progress .UserTask ;
3229import org .neo4j .gds .logging .Log ;
3330
34- import java .util .concurrent .atomic .LongAdder ;
35-
3631import static org .neo4j .gds .utils .StringFormatting .formatWithLocale ;
3732
3833public class MemoryTracker implements TaskStoreListener , GraphStoreAddedEventListener , GraphStoreRemovedEventListener {
3934 private final long initialMemory ;
4035 private final GraphStoreMemoryContainer graphStoreMemoryContainer = new GraphStoreMemoryContainer ();
41- private final ObjectLongMap < JobId > memoryInUse = new ObjectLongHashMap <> ();
36+ private final TaskMemoryContainer taskMemoryContainer = new TaskMemoryContainer ();
4237 private final Log log ;
4338
4439 public MemoryTracker (long initialMemory , Log log ) {
@@ -53,7 +48,7 @@ public long initialMemory() {
5348
5449 public synchronized void track (JobId jobId , long memoryEstimate ) {
5550 log .debug ("Tracking %s: %s bytes" , jobId .asString (), memoryEstimate );
56- memoryInUse . put (jobId , memoryEstimate );
51+ taskMemoryContainer . reserve (jobId , memoryEstimate );
5752 log .debug ("Available memory after tracking task: %s bytes" , availableMemory ());
5853 }
5954
@@ -66,22 +61,21 @@ public synchronized void tryToTrack(JobId jobId, long memoryEstimate) throws Mem
6661 }
6762
6863 public synchronized long availableMemory () {
69- var reservedMemory = new LongAdder ();
70- memoryInUse .values ().forEach ((LongProcedure ) reservedMemory ::add );
71- return initialMemory - (reservedMemory .longValue () + graphStoreMemoryContainer .graphStoreReservedMemory ());
64+ return initialMemory - graphStoreMemoryContainer .graphStoreReservedMemory () - taskMemoryContainer .taskReservedMemory ();
7265 }
7366
7467 @ Override
7568 public void onTaskAdded (UserTask userTask ) {
7669 // do nothing, we add the memory explicitly prior to execution
70+ taskMemoryContainer .addTask (userTask );
7771 }
7872
7973 @ Override
8074 public synchronized void onTaskRemoved (UserTask userTask ) {
8175 var taskDescription = userTask .task ().description ();
8276 log .debug ("Removing task: %s" , taskDescription );
8377 var jobId = userTask .jobId ();
84- var removed = memoryInUse . remove ( jobId );
78+ var removed = taskMemoryContainer . removeTask ( userTask );
8579 log .debug ("Removed task %s (%s): %s bytes" , taskDescription , jobId .asString (), removed );
8680 log .debug ("Available memory after removing task: %s bytes" , availableMemory ());
8781 log .debug ("Done removing task: %s" , taskDescription );
0 commit comments