@@ -338,6 +338,13 @@ NB_MODULE(_pyllmq, m) {
338338 res[" pageable" ] = size.PageableHost ;
339339 ret[nb::cast (name)] = res;
340340 }
341+
342+ auto stack = trainer->get_stack (gpu_id);
343+ for (const auto & [name, size] : stack) {
344+ nb::dict res;
345+ res[" stack" ] = size;
346+ ret[nb::cast (name)] = res;
347+ }
341348 return ret;
342349 }, nb::arg (" gpu_id" ) = 0 , " Get the current memory allocations for the given GPU" )
343350 ;
@@ -441,17 +448,22 @@ NB_MODULE(_pyllmq, m) {
441448 " Log GPU utilization state" )
442449 .def (" log_allocator" , [](TrainingRunLogger* logger, const nb::dict& stats) {
443450 std::vector<std::pair<std::string, sSegmentMemory >> cpp_stats;
451+ std::vector<std::pair<std::string, long >> cpp_stack;
444452 cpp_stats.reserve (stats.size ());
445453 for (auto item : stats) {
446454 std::string key = nb::cast<std::string>(item.first );
447455 nb::dict value = nb::cast<nb::dict>(item.second );
448- long device = nb::cast<long >(value[" device" ]);
449- long managed = nb::cast<long >(value[" managed" ]);
450- long pinned = nb::cast<long >(value[" pinned" ]);
451- long pageable = nb::cast<long >(value[" pageable" ]);
452- cpp_stats.emplace_back (key, sSegmentMemory {device, managed, pinned, pageable});
456+ if (value.contains (" stack" )) {
457+ cpp_stack.emplace_back (key, nb::cast<long >(value[" stack" ]));
458+ } else {
459+ long device = nb::cast<long >(value[" device" ]);
460+ long managed = nb::cast<long >(value[" managed" ]);
461+ long pinned = nb::cast<long >(value[" pinned" ]);
462+ long pageable = nb::cast<long >(value[" pageable" ]);
463+ cpp_stats.emplace_back (key, sSegmentMemory {device, managed, pinned, pageable});
464+ }
453465 }
454- logger->log_allocator (cpp_stats);
466+ logger->log_allocator (cpp_stats, cpp_stack );
455467 }, nb::arg (" stats" ), " Log memory allocator statistics" )
456468 .def (" set_expected_time_per_token" , [](TrainingRunLogger* logger, const MultiGPUPyTrainer* trainer){
457469 auto & config = trainer->config ();
0 commit comments