diff --git a/pyproject.toml b/pyproject.toml index aa54863..c6818b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pyllmq" -version = "0.3.0" +version = "0.3.2" description = "Python wrappers for LLMQ. LLM pretraining written in CUDA." readme = "README.md" requires-python = ">=3.12" diff --git a/scripts/demo.py b/scripts/demo.py index c9a44a5..a6865cd 100755 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -68,7 +68,7 @@ def main(): print("\nmemory consumption:") for k, v in trainer.get_allocator_info(0).items(): - print(f" {k:20}: {v // 1024 // 1024:6} MiB") + print(f" {k:20}: {v['device'] // 1024 // 1024:6} MiB") train_loader.load_batch(in_tokens, out_tokens) diff --git a/scripts/train.py b/scripts/train.py index 4862f07..c43f678 100755 --- a/scripts/train.py +++ b/scripts/train.py @@ -177,7 +177,7 @@ def main(): # Log allocator stats for idx in range(config.gpus): - logger.log_allocator(trainer.get_allocator_info(idx)) + logger.log_allocator(trainer, idx) # calculate the expected time at peak flops for speed-of-light estimation logger.set_expected_time_per_token(trainer) diff --git a/src/binding/binding.cpp b/src/binding/binding.cpp index 35f6360..cd3e5cd 100644 --- a/src/binding/binding.cpp +++ b/src/binding/binding.cpp @@ -264,6 +264,7 @@ NB_MODULE(_pyllmq, m) { nb::class_(m, "LLMQTrainer") .def("__init__", [](MultiGPUPyTrainer *t, int ngpu, LLamaConfig config, LLamaOptions options, int batch_size, int seq_len, int grad_accum, bool memcpy_all_gather, bool memcpy_send_recv) { + options.ModelType = config.DType; new (t) MultiGPUPyTrainer(ngpu, config, options, batch_size, seq_len, grad_accum, memcpy_all_gather, memcpy_send_recv); }, nb::arg("ngpu"), nb::arg("config"), nb::arg("options"), nb::arg("batch_size"), nb::arg("seq_len"), nb::arg("grad_accum"), nb::arg("memcpy_all_gather") = true, nb::arg("memcpy_send_recv") = true) @@ -337,6 +338,13 @@ NB_MODULE(_pyllmq, m) { res["pageable"] = size.PageableHost; ret[nb::cast(name)] = res; } + + auto stack = trainer->get_stack_info(gpu_id); + for (const auto& [name, size] : stack) { + nb::dict res; + res["stack"] = size; + ret[nb::cast(name)] = res; + } return ret; }, nb::arg("gpu_id") = 0, "Get the current memory allocations for the given GPU") ; @@ -440,17 +448,22 @@ NB_MODULE(_pyllmq, m) { "Log GPU utilization state") .def("log_allocator", [](TrainingRunLogger* logger, const nb::dict& stats) { std::vector> cpp_stats; + std::vector> cpp_stack; cpp_stats.reserve(stats.size()); for (auto item : stats) { std::string key = nb::cast(item.first); nb::dict value = nb::cast(item.second); - long device = nb::cast(value["device"]); - long managed = nb::cast(value["managed"]); - long pinned = nb::cast(value["pinned"]); - long pageable = nb::cast(value["pageable"]); - cpp_stats.emplace_back(key, sSegmentMemory{device, managed, pinned, pageable}); + if (value.contains("stack")) { + cpp_stack.emplace_back(key, nb::cast(value["stack"])); + } else { + long device = nb::cast(value["device"]); + long managed = nb::cast(value["managed"]); + long pinned = nb::cast(value["pinned"]); + long pageable = nb::cast(value["pageable"]); + cpp_stats.emplace_back(key, sSegmentMemory{device, managed, pinned, pageable}); + } } - logger->log_allocator(cpp_stats); + logger->log_allocator(cpp_stats, cpp_stack); }, nb::arg("stats"), "Log memory allocator statistics") .def("set_expected_time_per_token", [](TrainingRunLogger* logger, const MultiGPUPyTrainer* trainer){ auto& config = trainer->config(); diff --git a/src/binding/py_train.cpp 
b/src/binding/py_train.cpp index 50149e7..4d6bfa1 100644 --- a/src/binding/py_train.cpp +++ b/src/binding/py_train.cpp @@ -14,6 +14,7 @@ #include "utilities/comm.h" #include "kernels/kernels.h" #include "models/llama_gradients.h" +#include "models/llama_run_state.h" MultiGPUPyTrainer::MultiGPUPyTrainer(int ngpus, LLamaConfig config, LLamaOptions options, int batch_size, int seq_len, int grad_accum, bool memcpy_all_gather, bool memcpy_send_recv) : mConfig(config), mOptions(options), B(batch_size), T(seq_len), mGradAccumulation(grad_accum) @@ -249,6 +250,14 @@ std::vector> MultiGPUPyTrainer::get_alloc return result; } +std::vector> MultiGPUPyTrainer::get_stack_info(int gpu_id) { + std::vector> result; + run_work([&result](sThreadContext& ctx) { + result = ctx.Model->run_state().Stack.get_allocation_stats(); + }, gpu_id); + return result; +} + std::vector> MultiGPUPyTrainer::get_gradients(int gpu_id) { std::vector> result; run_work([&result](sThreadContext& ctx) { diff --git a/src/binding/py_train.h b/src/binding/py_train.h index 9ed6e52..6b685b7 100644 --- a/src/binding/py_train.h +++ b/src/binding/py_train.h @@ -63,6 +63,7 @@ class MultiGPUPyTrainer const LLamaOptions& options() const { return mOptions; } std::vector> get_allocations(int gpu_id); + std::vector> get_stack_info(int gpu_id); std::vector> get_gradients(int gpu_id); private: diff --git a/src/binding/python/training.py b/src/binding/python/training.py index 2b5ae7b..b097251 100644 --- a/src/binding/python/training.py +++ b/src/binding/python/training.py @@ -311,6 +311,7 @@ def training_logger_context(config: TrainingConfig): logger.log_cmd(sys.argv) log_options = asdict(config) log_options["matmul_dtype"] = log_options["matmul_dtype"] or config.model_dtype + log_options["gradient_dtype"] = log_options["gradient_dtype"] or log_options["matmul_dtype"] log_options["verbosity"] = str(config.verbosity) logger.log_options(log_options) yield logger diff --git a/src/kernels/adamw.cu b/src/kernels/adamw.cu index b5bc44a..c485c62 100644 --- a/src/kernels/adamw.cu +++ b/src/kernels/adamw.cu @@ -3,7 +3,6 @@ // // Based on llm.c https://github.com/karpathy/llm.c -#include #include #include "squirrel_noise.cuh" @@ -37,7 +36,7 @@ __device__ GenericVector load_vector(const FloatIn* memory, con } template -__device__ void store_vector(FloatOut* memory, const GenericVector& in, float* scales, long idx) { +__device__ void store_vector(FloatOut* memory, const GenericVector& in, float* scales, long idx, unsigned active_mask) { float factor = 1.f; unsigned int rng = get_noise_2d(idx, blockIdx.y, 51245); if(scales != nullptr) { @@ -47,9 +46,8 @@ __device__ void store_vector(FloatOut* memory, const GenericVector warp size"); - unsigned mask = __activemask(); - for(int i = 1; i <= Threads; i *= 2) { - abs_max = std::max(abs_max, __shfl_xor_sync(mask, abs_max, i, Threads)); + for(int i = 1; i < Threads; i *= 2) { + abs_max = std::max(abs_max, __shfl_xor_sync(active_mask, abs_max, i, Threads)); } if(abs_max > 1e-10f) { factor = 448.f / abs_max; @@ -77,6 +75,7 @@ __device__ auto adamw_update(floatX* params_memory, const floatX* grads_memory, using vec_m_t = GenericVector; using vec_v_t = GenericVector; + const unsigned active_mask = __ballot_sync(0xffffffff, idx < num_parameters); if (idx >= num_parameters) { return vec_x_t::zeros(); } // guard vec_f_t m = load_vector(m_memory, m_scales, idx); @@ -108,7 +107,7 @@ __device__ auto adamw_update(floatX* params_memory, const floatX* grads_memory, } p_new.store(params_memory + idx); - store_vector(m_memory, m, 
m_scales, idx); + store_vector(m_memory, m, m_scales, idx, active_mask); v_new.store(v_memory + idx); return p_new; @@ -122,7 +121,7 @@ __global__ void adamw_kernel(floatX* params_memory, const floatX* grads_memory, using vec_x_t = GenericVector; __shared__ float block_abs_max; if(threadIdx.x == 0) { - block_abs_max = 1e-10f; + block_abs_max = 0.f; } float thread_abs_max = 0.0f; diff --git a/src/kernels/attention.cu b/src/kernels/attention.cu index 21e8de5..30ea9c1 100644 --- a/src/kernels/attention.cu +++ b/src/kernels/attention.cu @@ -9,11 +9,10 @@ #include #include -#include - #include "kernels/kernels.h" #include "utilities/tensor.h" #include "utilities/vec.cuh" +#include "kernel_utils.cuh" namespace cg = cooperative_groups; @@ -22,10 +21,11 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel( scalar_t* out, float* stats, float scale, const scalar_t* qkv, int B, int T, int Hq, int Hkv) { + constexpr const int SubWarpSize = 16; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition<32>(block); - auto sub_warp = cg::tiled_partition<16>(block); + auto sub_warp = cg::tiled_partition(block); extern __shared__ float scratch[]; @@ -43,19 +43,19 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel( using vec_t = GenericVector; using fvec_t = GenericVector; - using q_cache_t = GenericVector; + using q_cache_t = GenericVector; q_cache_t q_cache; // combine values - using v_cache_t = GenericVector; + using v_cache_t = GenericVector; v_cache_t v_cache = v_cache_t::zeros(); // determine maximum and online logsumexp float maximum = std::numeric_limits::lowest(); float lse = 0; - for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) { - int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size; + for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) { + int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size; vec_t qv = vec_t::load(query + e); for (int j = 0; j < vec_t::size; ++j) { q_cache[ee * vec_t::size + j] = (float)qv[j]; @@ -65,14 +65,14 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel( for (int l = sub_warp.meta_group_rank(); l <= t; l += sub_warp.meta_group_size()) { ptrdiff_t kv_offset = l * TH * E; float qk = 0; - for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) { - int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size; + for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) { + int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size; vec_t kv = vec_t::load(keys + kv_offset + e); for (int j = 0; j < vec_t::size; ++j) { qk += q_cache[ee * vec_t::size + j] * (float)kv[j]; } } - qk = cg::reduce(sub_warp, qk, cg::plus{}); + qk = reduce_group_add(sub_warp, qk); if (qk > maximum) { float rescale = std::exp(scale * (maximum - qk)); for (int j = 0; j < v_cache_t::size; ++j) { @@ -84,8 +84,8 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel( float att = std::exp(scale * (qk - maximum)); lse += std::exp(scale * (qk - maximum)); - for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) { - int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size; + for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) { + int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size; vec_t vv = vec_t::load(values + kv_offset + e); for (int j = 0; j < vec_t::size; ++j) { v_cache[ee * vec_t::size + j] += att * (float)vv[j]; @@ -108,9 +108,9 @@ __global__ void __launch_bounds__(512) 
attention_forward_gpu_kernel( r_lse = scratch[warp.thread_rank() + sub_warp.meta_group_size()]; } - maximum = cg::reduce(warp, r_max, cg::greater{}); + maximum = reduce_group_max(warp, r_max); r_lse *= std::exp(scale * (r_max - maximum)); - lse = cg::reduce(warp, r_lse, cg::plus{}); + lse = reduce_group_add(warp, r_lse); float rescale = std::exp(scale * (l_max - maximum)) / lse; for (int j = 0; j < v_cache_t::size; ++j) { v_cache[j] *= rescale; @@ -120,8 +120,8 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel( } __syncthreads(); - for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) { - int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size; + for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) { + int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size; fvec_t store; for (int j = 0; j < vec_t::size; ++j) { store[j] = v_cache[ee * vec_t::size + j]; diff --git a/src/kernels/bias.cu b/src/kernels/bias.cu index ea9bd9d..12257c7 100644 --- a/src/kernels/bias.cu +++ b/src/kernels/bias.cu @@ -139,7 +139,7 @@ int get_bias_backward_scratch_size(ETensorDType dtype, int OC, const cudaDeviceP const int block_size = dp.maxThreadsPerMultiProcessor == 1536 ? 768 : 1024; const int OC_per_warp = 8 * ( 16 / get_dtype_size(dtype) ); // 64 at BF16 const int grid_size_x = div_ceil(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16 - const int grid_size_y = max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU! + const int grid_size_y = std::max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU! return grid_size_y * OC * sizeof(float); } @@ -155,7 +155,7 @@ void backward_bias_imp(floatX* dbias, const FloatY* dout, const float* scale_a, dim3 block_dim = {4, 8, (unsigned)block_size/32}; const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16 const int grid_size_x = div_ceil(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16 - const int grid_size_y = max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU! + const int grid_size_y = std::max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU! 
if( (scale_a == nullptr) != (scale_b == nullptr) ) { throw std::logic_error("backward_bias: scale_a and scale_b must be both nullptr or both non-nullptr"); diff --git a/src/kernels/encoder.cu b/src/kernels/encoder.cu index d59117f..f563caa 100644 --- a/src/kernels/encoder.cu +++ b/src/kernels/encoder.cu @@ -178,7 +178,7 @@ template void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs - int B, int T, int C, unsigned int seed, cudaStream_t stream) { + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) { using x128 = GenericVector; int num_c_groups = div_ceil((size_t)C, x128::size * 32); @@ -220,12 +220,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch bucket_index++; } - // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice) - // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely + // Step 3: Copy data from host to device (async on a different stream) int4* d_bucket_info = (int4*)scratch; int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * 4); - CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, copy_stream)); + CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, copy_stream)); + CUDA_CHECK(cudaEventRecord(sync_event, copy_stream)); + CUDA_CHECK(cudaStreamWaitEvent(stream, sync_event, 0)); // Launch wte kernel // todo - profile block sizes on more content (depends on number of buckets and on GPU?) 
@@ -236,13 +237,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch void encoder_backward(float* dwte, int* scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers const float* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs - int B, int T, int C, unsigned int seed, cudaStream_t stream) { - encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream); + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) { + encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream); } void encoder_backward(nv_bfloat16* dwte, int* scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers const nv_bfloat16* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs - int B, int T, int C, unsigned int seed, cudaStream_t stream) { - encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream); + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) { + encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream); } diff --git a/src/kernels/fused_classifier.cu b/src/kernels/fused_classifier.cu index 31787df..68e9bfe 100644 --- a/src/kernels/fused_classifier.cu +++ b/src/kernels/fused_classifier.cu @@ -5,28 +5,13 @@ #include +#include "kernel_utils.cuh" #include "utilities/utils.h" #include "utilities/vec.cuh" // ---------------------------------------------------------------------------- // CUDA kernels -// warp-level reduction for finding the maximum value -__device__ inline float warpReduceMax(float val) { - for (int offset = 16; offset > 0; offset /= 2) { - val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset)); - } - return val; -} - -__device__ inline float warpReduceSum(float val) { - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_xor_sync(0xFFFFFFFF, val, offset); - } - return val; -} - - // requires all 32 threads in the warp to be active, but should work for any block size // uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes // the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end diff --git a/src/kernels/global_norm.cu b/src/kernels/global_norm.cu index 26f5068..cc4537e 100644 --- a/src/kernels/global_norm.cu +++ b/src/kernels/global_norm.cu @@ -7,11 +7,11 @@ #include #include -#include #include #include "utilities/utils.h" +#include "kernel_utils.cuh" // ---------------------------------------------------------------------------- // CUDA kernels @@ -27,7 +27,7 @@ __device__ float global_norm_squared_for_range(const T* data, size_t count) { cooperative_groups::thread_block block = cooperative_groups::this_thread_block(); auto warp = cooperative_groups::tiled_partition<32>(block); - accumulator = cooperative_groups::reduce(warp, accumulator, cooperative_groups::plus()); + accumulator = reduce_group_add(warp, accumulator); __shared__ float shared_accumulator[32]; if(warp.thread_rank() == 0) { shared_accumulator[warp.meta_group_rank()] = accumulator; @@ -35,7 +35,7 @@ __device__ float global_norm_squared_for_range(const T* data, size_t count) { __syncthreads(); // block-level reduce float total 
= warp.thread_rank() < warp.meta_group_size() ? shared_accumulator[warp.thread_rank()] : 0.f; - total = cooperative_groups::reduce(warp, total, cooperative_groups::plus()); + total = reduce_group_add(warp, total); return total; } @@ -60,7 +60,7 @@ __global__ void deterministic_sum_kernel(float* out, const floatX* data, std::si cooperative_groups::thread_block block = cooperative_groups::this_thread_block(); auto warp = cooperative_groups::tiled_partition<32>(block); - float warp_sum = cooperative_groups::reduce(warp, thread_sum, cooperative_groups::plus()); + float warp_sum = reduce_group_add(warp, thread_sum); __shared__ float shared_accumulator[32]; if(warp.thread_rank() == 0) { shared_accumulator[warp.meta_group_rank()] = warp_sum; @@ -69,7 +69,7 @@ __global__ void deterministic_sum_kernel(float* out, const floatX* data, std::si // block-level reduce if(warp.meta_group_rank() == 0) { float total = warp.thread_rank() < warp.meta_group_size() ? shared_accumulator[warp.thread_rank()] : 0.f; - total = cooperative_groups::reduce(warp, total, cooperative_groups::plus()); + total = reduce_group_add(warp, total); if (threadIdx.x == 0) { *out = total; } diff --git a/src/kernels/kernel_utils.cuh b/src/kernels/kernel_utils.cuh index 502d311..f24223d 100644 --- a/src/kernels/kernel_utils.cuh +++ b/src/kernels/kernel_utils.cuh @@ -6,6 +6,11 @@ #define LLMQ_SRC_KERNELS_KERNEL_UTILS_CUH #include +#include + +#ifndef __HIP__ +#include +#endif static __forceinline__ __device__ void handle_absmax_reduction(float* __restrict__ abs_max_ptr, float* __restrict__ block_max, float thread_max) { if (abs_max_ptr) { @@ -16,6 +21,7 @@ static __forceinline__ __device__ void handle_absmax_reduction(float* __restrict if(threadIdx.x % 32 == 0) { atomicMax_block(reinterpret_cast(block_max), warp_max); } + __syncthreads(); if(threadIdx.x == 0) { atomicMax(reinterpret_cast(abs_max_ptr), __float_as_uint(*block_max)); @@ -23,4 +29,28 @@ static __forceinline__ __device__ void handle_absmax_reduction(float* __restrict } } +template +static __forceinline__ __device__ Element reduce_group_add(Group& group, Element value) { + return cooperative_groups::reduce(group, value, cooperative_groups::plus()); +} + +template +static __forceinline__ __device__ Element reduce_group_max(Group& group, Element value) { + return cooperative_groups::reduce(group, value, cooperative_groups::greater()); +} + +static __forceinline__ __device__ float warpReduceSum(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_xor_sync(0xFFFFFFFFu, val, offset); + } + return val; +} + +__device__ inline float warpReduceMax(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFFu, val, offset)); + } + return val; +} + #endif //LLMQ_SRC_KERNELS_KERNEL_UTILS_CUH diff --git a/src/kernels/kernels.cpp b/src/kernels/kernels.cpp index 5e101d8..bdff01e 100644 --- a/src/kernels/kernels.cpp +++ b/src/kernels/kernels.cpp @@ -116,13 +116,17 @@ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::opt void encoder_backward(Tensor& dwte, Tensor& scratch, Tensor& workload_indices, Tensor& bucket_info, const Tensor& dout, const Tensor& inp, const Tensor& inputs_cpu, - int B, int T, int C, unsigned int seed, cudaStream_t stream) { + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) { assert(workload_indices.Device == -1); assert(bucket_info.Device == -1); if(dwte.DType == ETensorDType::FP32) { - 
encoder_backward(dwte.get(), scratch.get(), workload_indices.get(), (int4*)bucket_info.get(), dout.get(), inp.get(), inputs_cpu.get(), B, T, C, seed, stream); + encoder_backward(dwte.get(), scratch.get(), workload_indices.get(), + (int4*)bucket_info.get(), dout.get(), inp.get(), inputs_cpu.get(), + B, T, C, seed, stream, sync_event, copy_stream); } else if(dwte.DType == ETensorDType::BF16) { - encoder_backward(dwte.get(), scratch.get(), workload_indices.get(), (int4*)bucket_info.get(), dout.get(), inp.get(), inputs_cpu.get(), B, T, C, seed, stream); + encoder_backward(dwte.get(), scratch.get(), workload_indices.get(), + (int4*)bucket_info.get(), dout.get(), inp.get(), inputs_cpu.get(), + B, T, C, seed, stream, sync_event, copy_stream); } else { throw std::logic_error("encoder_backward: unsupported dtype"); } diff --git a/src/kernels/kernels.h b/src/kernels/kernels.h index c121211..7d0f966 100644 --- a/src/kernels/kernels.h +++ b/src/kernels/kernels.h @@ -28,15 +28,18 @@ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::opt void encoder_backward(float* dwte, int* scratch, int* workload_indices, int4* bucket_info, const float* dout, const int* inp, const int* inputs_cpu, - int B, int T, int C, unsigned int seed, cudaStream_t stream); + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream); void encoder_backward(nv_bfloat16* dwte, int* scratch, int* workload_indices, int4* bucket_info, const nv_bfloat16* dout, const int* inp, const int* inputs_cpu, - int B, int T, int C, unsigned int seed, cudaStream_t stream); + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream); + +// The kernel runs on `stream`, but the bucket info that gets generated on CPU to enable efficient determinism +// can be copied using `copy_stream`, so the kernel launch does not have to wait. void encoder_backward(Tensor& dwte, Tensor& scratch, Tensor& workload_indices, Tensor& bucket_info, const Tensor& dout, const Tensor& inp, const Tensor& inputs_cpu, - int B, int T, int C, unsigned int seed, cudaStream_t stream); + int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream); void rmsnorm_forward(float* out, float* rms, const float* inp, const float* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream); void rmsnorm_forward(nv_bfloat16* out, float* rms, const nv_bfloat16* inp, const nv_bfloat16* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream); diff --git a/src/kernels/matmul.cpp b/src/kernels/matmul.cpp index c858679..7d67462 100644 --- a/src/kernels/matmul.cpp +++ b/src/kernels/matmul.cpp @@ -4,7 +4,6 @@ // Based on llm.c https://github.com/karpathy/llm.c #include -#include #include #include "kernels.h" @@ -34,6 +33,10 @@ cublasLtHandle_t create_cublaslt_handle() { return handle; } +void destroy_cublaslt_handle(cublasLtHandle_t handle) { + CUBLAS_CHECK(cublasLtDestroy(handle)); +} + // ---------------------------------------------------------------------------- // kernel launchers @@ -119,10 +122,6 @@ void matmul_cublaslt(FloatC* d, const FloatA* a, const FloatB* b, const FloatBia CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(&scale_b))); } - // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) 
- cublasDataType_t scale_type = CUDA_R_32F; - CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); - // find a suitable algorithm (cached internally so shouldn't take much CPU time in practice) cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, ALayout, BLayout, CLayout, DLayout, preference, 1, &heuristic, &returnedResults); diff --git a/src/kernels/quant.cu b/src/kernels/quant.cu index 1dd45a9..1da6f33 100644 --- a/src/kernels/quant.cu +++ b/src/kernels/quant.cu @@ -13,7 +13,7 @@ __global__ void reduce_abs_max_kernel(float* __restrict__ result, const floatX* __shared__ float block_abs_max; if(threadIdx.x == 0) { - block_abs_max = 1e-10f; + block_abs_max = 0.f; } __syncthreads(); float thread_abs_max = 0.f; @@ -76,7 +76,7 @@ __global__ void quantize_with_abs_max_kernel(FloatOut* __restrict__ out, float* const FloatIn* __restrict__ in, const float* __restrict__ abs_max, long N) { using vec_t = GenericVector; using f8v_t = GenericVector; - float scale = 448.f / *abs_max; + float scale = 448.f / fmaxf(*abs_max, 1e-10f); if(threadIdx.x == 0 && blockIdx.x == 0 && scale_ptr) { *scale_ptr = 1.f / scale; } @@ -150,7 +150,7 @@ __global__ void quantize_and_transpose_with_abs_max_kernel(std::int8_t* out, flo template __global__ void quantize_and_transpose_with_abs_max_kernel(__nv_fp8_e4m3* out, float* scale_ptr, const floatX* in, const float* abs_max, int rows, int cols) { - float scale = 448.f / *abs_max; + float scale = 448.f / fmaxf(*abs_max, 1e-10f); if(threadIdx.x == 0 && blockIdx.x == 0 && scale_ptr) { *scale_ptr = 1.f / scale; } diff --git a/src/kernels/rmsnorm.cu b/src/kernels/rmsnorm.cu index 2e2ae1b..381be75 100644 --- a/src/kernels/rmsnorm.cu +++ b/src/kernels/rmsnorm.cu @@ -14,13 +14,6 @@ constexpr const int WARP_SIZE = 32; -__device__ inline float warpReduceSum(float val) { - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_xor_sync(0xFFFFFFFF, val, offset); - } - return val; -} - template __device__ void rmsnorm_forward_kernel(floatX* __restrict__ out, float* __restrict__ rms, const floatX* __restrict__ inp, const floatX* __restrict__ weight, @@ -44,7 +37,7 @@ __device__ void rmsnorm_forward_kernel(floatX* __restrict__ out, float* __restri s_weight[i/x128::size] = x128::load(weight + i); } if (abs_max_ptr && threadIdx.x == 0) { - block_abs_max = 1e-10f; + block_abs_max = 0.f; } __syncthreads(); @@ -68,7 +61,7 @@ __device__ void rmsnorm_forward_kernel(floatX* __restrict__ out, float* __restri acc = warpReduceSum(acc) / C; float s = rsqrtf(acc + epsilon); - float thread_abs_max = -1.f; + float thread_abs_max = 0.f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 in_data = s_in[c / x128::size]; @@ -80,7 +73,7 @@ __device__ void rmsnorm_forward_kernel(floatX* __restrict__ out, float* __restri // so we try to match out_data[k] = (floatX)n * (floatX)w[k]; // scale if (abs_max_ptr) { - thread_abs_max = std::max(thread_abs_max, fabsf(out_data[k])); + thread_abs_max = fmaxf(thread_abs_max, fabsf(out_data[k])); } } @@ -117,7 +110,7 @@ __device__ void fused_residual_rmsnorm_forward_kernel(floatX* residual, floatX* s_weight[i/x128::size] = x128::load(weight + i); } if (abs_max_ptr && threadIdx.x == 0) { - block_abs_max = 1e-10f; + block_abs_max = 0.f; } __syncthreads(); @@ -146,7 +139,7 @@ __device__ void fused_residual_rmsnorm_forward_kernel(floatX* residual, floatX* sum_squared = warpReduceSum(sum_squared) / C; float s = rsqrtf(sum_squared + epsilon); - float
thread_abs_max = -1.f; + float thread_abs_max = 0.f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 res = s_res[c / x128::size]; @@ -156,7 +149,7 @@ __device__ void fused_residual_rmsnorm_forward_kernel(floatX* residual, floatX* float n = s * (float)res[k]; // normalized output out[k] = (floatX)n * (floatX)w[k]; // scale if (abs_max_ptr) { - thread_abs_max = std::max(thread_abs_max, fabsf(out[k])); + thread_abs_max = fmaxf(thread_abs_max, fabsf(out[k])); } } @@ -269,7 +262,7 @@ rmsnorm_backward_kernel10(floatX* dinp, floatX* dweight, std::byte* scratch, f128::zeros().store(dweight_shared + i); } if (abs_max_ptr && threadIdx.x == 0) { - block_abs_max = 1e-10f; + block_abs_max = 0.f; } __syncthreads(); @@ -361,7 +354,7 @@ rmsnorm_backward_kernel10(floatX* dinp, floatX* dweight, std::byte* scratch, dinp128.store(dinp_bt + global_index); for(int i = 0; i < x128::size; ++i) { - thread_abs_max = std::max(thread_abs_max, fabsf(dinp128[i])); + thread_abs_max = fmaxf(thread_abs_max, fabsf(dinp128[i])); } } } diff --git a/src/kernels/rope.cu b/src/kernels/rope.cu index 38f2d0b..4d88014 100644 --- a/src/kernels/rope.cu +++ b/src/kernels/rope.cu @@ -45,10 +45,10 @@ __global__ void rope_kernel(floatX *out, const floatX *inp, const floatX *freqs_ __shared__ float block_abs_max; if (abs_max_ptr) { if(threadIdx.x == 0) - block_abs_max = 1e-10f; + block_abs_max = 0.f; __syncthreads(); } - float thread_abs_max = 1e-10f; + float thread_abs_max = 0.f; int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x64::size; int head_dim_half = head_dim / 2; diff --git a/src/kernels/swiglu.cu b/src/kernels/swiglu.cu index c0ad622..e34e4cd 100644 --- a/src/kernels/swiglu.cu +++ b/src/kernels/swiglu.cu @@ -21,7 +21,7 @@ __global__ void swiglu_forward_kernel(floatX* out, const floatX* inp, float* abs using x128 = GenericVector; // thread coordinates - long idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; floatX* out_ptr = out + idx; int bt = (idx / C); int c = idx % C; @@ -34,7 +34,7 @@ __global__ void swiglu_forward_kernel(floatX* out, const floatX* inp, float* abs // so they don't cost us in this memory-bound kernel. 
if (abs_max_ptr) { if(threadIdx.x == 0) { - block_max = 1e-10f; + block_max = 0.f; } __syncthreads(); } @@ -56,14 +56,49 @@ __global__ void swiglu_forward_kernel(floatX* out, const floatX* inp, float* abs handle_absmax_reduction(abs_max_ptr, &block_max, thread_max); } +template +__global__ void swiglu_forward_quant_kernel(__nv_fp8_e4m3* out, float* scale_ptr, const floatX* inp, const float* abs_max_ptr, int C) { + using x128 = GenericVector; + using f8v_t = GenericVector<__nv_fp8_e4m3, 16 / sizeof(floatX)>; + + float scale = 448.f / fmaxf(*abs_max_ptr, 1e-10f); + if(threadIdx.x == 0 && blockIdx.x == 0 && scale_ptr) { + *scale_ptr = 1.f / scale; + } + + // thread coordinates + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + __nv_fp8_e4m3* out_ptr = out + idx; + int bt = (idx / C); + int c = idx % C; + + const floatX* up_ptr = inp + (bt * C * 2 + c); + const floatX* gate_ptr = up_ptr + C; + + f8v_t packed_out; + x128 up_inp = x128::load_cs(up_ptr); + x128 gate_inp = x128::load_cs(gate_ptr); + for(int k = 0; k < up_inp.size; ++k) { + float x1 = (float)up_inp[k]; + float x2 = (float)gate_inp[k]; + float result = (x1 * x2) / (1.0f + expf(-x2)); + floatX qr = (floatX)result; + __nv_fp8_e4m3 quant; + quant.__x = __nv_cvt_float_to_fp8(scale * (float)qr, __nv_saturation_t::__NV_SATFINITE, __nv_fp8_interpretation_t::__NV_E4M3); + packed_out[k] = quant; + } + packed_out.store(out_ptr); +} + + //! persistent kernel for swiglu. If the input tensor is large enough, the persistent kernel gives maybe 5-10% speed-up //! over the simple baseline. template __global__ __launch_bounds__(128) void swiglu_forward_persistent_kernel(floatX* out, const floatX* inp, float* abs_max_ptr, int BT, int C) { using x128 = GenericVector; - long start = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; - long stride = gridDim.x * blockDim.x * x128::size; + int start = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + int stride = gridDim.x * blockDim.x * x128::size; __shared__ float block_max; __shared__ alignas(16) floatX up_buffer[2 * 128 * (16/sizeof(floatX))]; @@ -72,7 +107,7 @@ __global__ __launch_bounds__(128) void swiglu_forward_persistent_kernel(floatX* // so they don't cost us in this memory-bound kernel. 
if (abs_max_ptr) { if(threadIdx.x == 0) { - block_max = 1e-10f; + block_max = 0.f; } __syncthreads(); } @@ -95,7 +130,7 @@ __global__ __launch_bounds__(128) void swiglu_forward_persistent_kernel(floatX* __pipeline_commit(); int phase = 0; - for(long idx = start; idx < BT*C; idx += stride) { + for(int idx = start; idx < BT*C; idx += stride) { // note: each thread reads only what it writes itself, so there is no need for further synchronization here __pipeline_wait_prior(1); x128 up_inp = x128::load(up_buffer + lane_base + 128 * x128::size * phase); @@ -124,52 +159,18 @@ __global__ __launch_bounds__(128) void swiglu_forward_persistent_kernel(floatX* handle_absmax_reduction(abs_max_ptr, &block_max, thread_max); } -template -__global__ void swiglu_forward_quant_kernel(__nv_fp8_e4m3* out, float* scale_ptr, const floatX* inp, const float* abs_max_ptr, int C) { - using x128 = GenericVector; - using f8v_t = GenericVector<__nv_fp8_e4m3, 16 / sizeof(floatX)>; - - float scale = 448.f / *abs_max_ptr; - if(threadIdx.x == 0 && blockIdx.x == 0 && scale_ptr) { - *scale_ptr = 1.f / scale; - } - - // thread coordinates - long idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; - __nv_fp8_e4m3* out_ptr = out + idx; - int bt = (idx / C); - int c = idx % C; - - const floatX* up_ptr = inp + (bt * C * 2 + c); - const floatX* gate_ptr = up_ptr + C; - - f8v_t packed_out; - x128 up_inp = x128::load_cs(up_ptr); - x128 gate_inp = x128::load_cs(gate_ptr); - for(int k = 0; k < up_inp.size; ++k) { - float x1 = (float)up_inp[k]; - float x2 = (float)gate_inp[k]; - float result = (x1 * x2) / (1.0f + expf(-x2)); - floatX qr = (floatX)result; - __nv_fp8_e4m3 quant; - quant.__x = __nv_cvt_float_to_fp8(scale * (float)qr, __nv_saturation_t::__NV_SATFINITE, __nv_fp8_interpretation_t::__NV_E4M3); - packed_out[k] = quant; - } - packed_out.store(out_ptr); -} - template __global__ void swiglu_forward_quant_persistent_kernel(__nv_fp8_e4m3* out, float* scale_ptr, const floatX* inp, const float* abs_max_ptr, int BT, int C) { using x128 = GenericVector; using f8v_t = GenericVector<__nv_fp8_e4m3, 16 / sizeof(floatX)>; - float scale = 448.f / *abs_max_ptr; + float scale = 448.f / fmaxf(*abs_max_ptr, 1e-10f); if(threadIdx.x == 0 && blockIdx.x == 0 && scale_ptr) { *scale_ptr = 1.f / scale; } - long start = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; - long stride = gridDim.x * blockDim.x * x128::size; + int start = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + int stride = gridDim.x * blockDim.x * x128::size; __shared__ alignas(16) floatX up_buffer[2 * 128 * (16/sizeof(floatX))]; __shared__ alignas(16) floatX gate_buffer[2 * 128 * (16/sizeof(floatX))]; @@ -192,7 +193,7 @@ __global__ void swiglu_forward_quant_persistent_kernel(__nv_fp8_e4m3* out, float __pipeline_commit(); int phase = 0; - for(long idx = start; idx < BT*C; idx += stride) { + for(int idx = start; idx < BT*C; idx += stride) { // note: each thread reads only what it writes itself, so there is no need for further synchronization here __pipeline_wait_prior(1); x128 up_inp = x128::load(up_buffer + lane_base + 128 * x128::size * phase); @@ -249,7 +250,7 @@ __global__ void swiglu_backward_kernel1(floatX* dinp, const floatX* dout, const // so they don't cost us in this memory-bound kernel. 
if (abs_max_ptr) { if(threadIdx.x == 0) { - block_max = 1e-10f; + block_max = 0.f; } __syncthreads(); } @@ -269,8 +270,8 @@ __global__ void swiglu_backward_kernel1(floatX* dinp, const floatX* dout, const dinp2[k] = (floatX)dx2; if (abs_max_ptr) { - thread_max = std::max(thread_max, fabsf(dinp1[k])); - thread_max = std::max(thread_max, fabsf(dinp2[k])); + thread_max = fmaxf(thread_max, fabsf(dinp1[k])); + thread_max = fmaxf(thread_max, fabsf(dinp2[k])); } } dinp1.store(dinp1_ptr); @@ -286,6 +287,11 @@ template void swiglu_forward_impl(floatX* out, const floatX* inp, float* abs_max_ptr, int B, int T, int C, cudaStream_t stream) { // input is (B, T, 2C), output is (B, T, C) // we have that inp[b, t, :] = [fc1, fc2] (i.e. they are concatenated in each C-fiber) + + if (2ll*B*T*C >= std::numeric_limits::max()) { + throw std::runtime_error("swiglu_forward: input too large"); + } + using x128 = GenericVector; if (abs_max_ptr) CUDA_CHECK(cudaMemsetAsync(abs_max_ptr, 0, sizeof(float), stream)); @@ -317,6 +323,9 @@ void swiglu_forward(float* out, const float* inp, float* abs_max_ptr, int B, int } void swiglu_forward_quant(__nv_fp8_e4m3* out, float* scale_ptr, const nv_bfloat16* inp, const float* abs_max_ptr, int B, int T, int C, cudaStream_t stream) { + if (2ll*B*T*C >= std::numeric_limits::max()) { + throw std::runtime_error("swiglu_forward_quant: input too large"); + } using x128 = GenericVector; const int block_size = 128; assert(C % x128::size == 0); @@ -338,6 +347,10 @@ void swiglu_forward_quant(__nv_fp8_e4m3* out, float* scale_ptr, const nv_bfloat1 template void swiglu_backward_impl(floatX* dinp, const floatX* dout, const floatX* inp, float* abs_max, int B, int T, int C, cudaStream_t stream) { + if (2ll*B*T*C >= std::numeric_limits::max()) { + throw std::runtime_error("swiglu_backward: output too large"); + } + using x128 = GenericVector; // input is (B, T, 2C), output is (B, T, C) // we have that inp[b, t, :] = [fc1, fc2] (i.e. they are concatenated in each C-fiber) diff --git a/src/models/llama_config.cpp b/src/models/llama_config.cpp index 216d23c..8c434bd 100644 --- a/src/models/llama_config.cpp +++ b/src/models/llama_config.cpp @@ -180,7 +180,7 @@ static LLamaConfig create_llama3_config(int hidden_size, int intermediate_size, LLamaConfig create_config_from_name(std::string_view name, ETensorDType dtype) { if(iequals(name, "Qwen2.5-0.5B")) { - return create_qwen25_config(896, 4861, 14, 2, 24, 1e-06f, true, dtype); + return create_qwen25_config(896, 4864, 14, 2, 24, 1e-06f, true, dtype); } else if(iequals(name, "Qwen2.5-1.5B")) { return create_qwen25_config(1536, 8960, 12, 2, 28, 1e-06f, true, dtype); } else if(iequals(name, "Qwen2.5-3B")) { diff --git a/src/models/llama_model.cpp b/src/models/llama_model.cpp index 54abd94..a5f0330 100644 --- a/src/models/llama_model.cpp +++ b/src/models/llama_model.cpp @@ -94,12 +94,13 @@ void LLamaModel::forward(Tensor inputs, NCCLCommunicator& comm, int micro_step) Parameters->invalidate(); } - assert(rs->Inputs.Sizes[0] >= B); assert(rs->Inputs.Sizes[1] >= T); assert(inputs.Device == -1); { NvtxRange r{"copy-input"}; + // no point running this copy on side stream: input is needed by embedding gradients, which is + // the last op in backward. 
CUDA_CHECK(cudaMemcpyAsync(rs->Inputs.Data, inputs.Data, inputs.bytes(), cudaMemcpyHostToDevice, main_stream)); CUDA_CHECK(cudaEventRecord(rs->TransferDone, main_stream)); } @@ -346,9 +347,14 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, long T = inputs.Sizes[1]; const size_t C = Config.HiddenSize; const size_t L = Config.NumLayers; - - CUDA_CHECK(cudaMemcpyAsync(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice, main_stream)); - CUDA_CHECK(cudaEventRecord(rs->TransferDone, main_stream)); + { + NvtxRange r{"copy-targets"}; + // make sure rs->Targets is no longer needed by the previous step. + CUDA_CHECK(cudaStreamWaitEvent(rs->SideStream, rs->BackwardDone, 0)); + CUDA_CHECK(cudaMemcpyAsync(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice, rs->SideStream)); + CUDA_CHECK(cudaEventRecord(rs->TransferDone, rs->SideStream)); + CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->TransferDone, 0)); + } bool last_step = micro_step == grad_accum_steps - 1; // on the first micro-step zero the gradients, as we're about to += accumulate into them @@ -426,7 +432,7 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm, auto& d_emb = Grads->get_embeddings_full(main_stream, comm, accumulate); encoder_backward(d_emb, rs->EncoderBwdScratch, rs->EncoderBwdIndices, rs->EncoderBwdInfo, - rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream); + rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream, rs->SideStreamEvent, rs->SideStream); Grads->notify_embeddings(main_stream, comm); // make sure all gradients are communicated before we go to the update step. @@ -483,11 +489,6 @@ void LLamaModel::_backward_lmhead(long B, long T, int micro_step, int grad_accum CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->SideStreamEvent, 0)); } - if (nano_step == 0) { - // BackwardDone ensures that zero-2 gradient accumulation of the previous step has finished, so we can safely write to d_lmhead again. - CUDA_CHECK(cudaEventSynchronize(rs->BackwardDone)); - } - // handle the LM-head. We run the d_lmhead matmul first, so that the gradient reduction can overlap with the DLNF matmul. 
bool accumulate; auto& d_lmhead = Grads->get_lmhead_full(main_stream, comm, accumulate); @@ -811,6 +812,8 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicator& comm, int B, int T) { NVTX_RANGE_FN(); + std::vector> stack_watermark; + // create a dummy stack and simulate the way we're going to use temporaries later, to determine how much we need to allocate int dev; CUDA_CHECK(cudaGetDevice(&dev)); @@ -833,6 +836,7 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato auto ctx = Allocator->with_context("Stack"); long required_size = stack.max_utilization(); acts.Stack = DeviceMemoryStack{Allocator->allocate(ETensorDType::BYTE, "stack", {required_size}).Data, (std::size_t)required_size, dev}; + acts.Stack.set_high_mark(stack.get_high_mark()); } { diff --git a/src/models/llama_run_state.cpp b/src/models/llama_run_state.cpp index 8938442..dcbc500 100644 --- a/src/models/llama_run_state.cpp +++ b/src/models/llama_run_state.cpp @@ -378,30 +378,30 @@ void LLamaRunState::init(LLamaConfig config, long B, long T, DeviceMemoryStack& bool use_fp8 = Options.grad_dtype() == ETensorDType::FP8_E4M3 || Options.grad_dtype() == ETensorDType::FP8_E5M2; auto bw_qmm = [&](int B, int T, int C, int OC) { if(use_fp8) { - auto wgt_tp = stack.allocate(ETensorDType::FP8_E4M3, {C, OC}); + auto wgt_tp = stack.allocate(ETensorDType::FP8_E4M3, {C, OC}, "wgt_tp"); stack.free(wgt_tp.Data); - auto act_tp = stack.allocate(ETensorDType::FP8_E4M3, {C, B * T}); - auto grd_tp = stack.allocate(Options.grad_dtype(), {OC, B * T}); + auto act_tp = stack.allocate(ETensorDType::FP8_E4M3, {C, B * T}, "act_tp"); + auto grd_tp = stack.allocate(Options.grad_dtype(), {OC, B * T}, "grd_tp"); stack.free(grd_tp.Data); stack.free(act_tp.Data); } }; // simulate to determine required stack size - auto ws = stack.allocate(CuDNNWorkspace.bytes()); - stack.free(stack.allocate(DActs[0].DQKV.Value.bytes())); // attention + auto ws = stack.allocate(CuDNNWorkspace.bytes(), "workspace"); + stack.free(stack.allocate(DActs[0].DQKV.Value.bytes(), "dqkv")); // attention stack.free(ws); // attention - auto dswi = stack.allocate(DActs[0].DSwiGLU.bytes()); + auto dswi = stack.allocate(DActs[0].DSwiGLU.bytes(), "dswiglu"); bw_qmm(B, T, H, C); // backward qmm swiglu stack.free(dswi); if(use_fp8) { - auto dupq = stack.allocate(DActs[0].DMlpUp.Quant->bytes()); + auto dupq = stack.allocate(DActs[0].DMlpUp.Quant->bytes(), "dup.q"); bw_qmm(B, T, C, 2 * H); // backward qmm up stack.free(dupq); } - stack.free(stack.allocate(Output.bytes())); // lm-head + stack.free(stack.allocate(Output.bytes(), "output")); // lm-head MatmulScales = alloc->allocate(ETensorDType::FP32, "mm_scales", {2}); diff --git a/src/models/llama_weights.cpp b/src/models/llama_weights.cpp index f9bb776..d1d618a 100644 --- a/src/models/llama_weights.cpp +++ b/src/models/llama_weights.cpp @@ -43,33 +43,33 @@ void matrix_params_from_stack(sLLamaBlockWeights& target, const LLa long C = config.HiddenSize; long H = config.IntermediateSize; - auto create_matrix_shard = [&](long rows, long cols) { - Tensor raw = memory.allocate(dtype, {div_exact(rows, (long)num_shards), cols}); + auto create_matrix_shard = [&](long rows, long cols, const char* name) { + Tensor raw = memory.allocate(dtype, {div_exact(rows, (long)num_shards), cols}, name); return TensorShard{raw, shard_idx, num_shards, std::vector{rows, cols}}; }; long head_size = C / config.NumQueryHeads; long 
attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * head_size; - target.Attn_QKV_w = create_matrix_shard(attn_intermediate_size, C); - target.Attn_Out_w = create_matrix_shard(C, C); - target.MLP_Up_w = create_matrix_shard(2 * H, C); - target.MLP_Down_w = create_matrix_shard(C, H); + target.Attn_QKV_w = create_matrix_shard(attn_intermediate_size, C, "Attn_QKV_w"); + target.Attn_Out_w = create_matrix_shard(C, C, "Attn_Out_w"); + target.MLP_Up_w = create_matrix_shard(2 * H, C, "MLP_Up_w"); + target.MLP_Down_w = create_matrix_shard(C, H, "MLP_Down_w"); } void non_matrix_params_from_stack(sLLamaBlockWeights& target, const LLamaConfig& config, ETensorDType dtype, int shard_idx, int num_shards, DeviceMemoryStack& memory) { long C = config.HiddenSize; long HS = config.head_size(); - auto create_vector_shard = [&](long elems) { - Tensor raw = memory.allocate(dtype, {div_exact(elems, (long)num_shards)}); + auto create_vector_shard = [&](long elems, const char* name) { + Tensor raw = memory.allocate(dtype, {div_exact(elems, (long)num_shards)}, name); return TensorShard{raw, shard_idx, num_shards, std::vector{elems}}; }; - target.LN1_w = create_vector_shard(C); - target.LN2_w = create_vector_shard(C); + target.LN1_w = create_vector_shard(C, "LN1_w"); + target.LN2_w = create_vector_shard(C, "LN2_w"); long attn_intermediate_size = (config.NumQueryHeads + 2 * config.NumKeyValHeads) * HS; if(config.UseQKVBias) { - target.Attn_QKV_b = create_vector_shard(attn_intermediate_size); + target.Attn_QKV_b = create_vector_shard(attn_intermediate_size, "Attn_QKV_b"); } else { target.Attn_QKV_b = std::nullopt; } diff --git a/src/training/logging.cpp b/src/training/logging.cpp index 070b595..41fa972 100644 --- a/src/training/logging.cpp +++ b/src/training/logging.cpp @@ -15,6 +15,7 @@ #include "utilities/gpu_info.h" #include "utilities/utils.h" #include "utilities/allocator.h" +#include "utilities/stack.h" #include "utilities/sol.h" #include @@ -320,7 +321,10 @@ void TrainingRunLogger::log_line(std::string_view line) { mFirst = false; } -void TrainingRunLogger::log_allocator(const std::vector>& stats) { +void TrainingRunLogger::log_allocator( + const std::vector>& stats, + const std::vector>& stack_info) +{ if (mRank != 0) return; std::string stat_str = "["; bool first = true; @@ -340,6 +344,15 @@ void TrainingRunLogger::log_allocator(const std::vector(amount / 1024 / 1024); + if(mib > 0) { + printf(" %16s: %6d \n", stack_name.c_str(), mib); + } + } + printf("\n"); } } diff --git a/src/training/logging.h b/src/training/logging.h index 26651f6..684e3dc 100644 --- a/src/training/logging.h +++ b/src/training/logging.h @@ -17,6 +17,7 @@ struct GPUUtilInfo; struct sSegmentMemory; class NCCLCommunicator; class DataLoader; +class DeviceMemoryStack; enum class ETensorDType : int; class TrainingRunLogger @@ -42,7 +43,10 @@ class TrainingRunLogger void log_step(int step, float epoch, int step_tokens, int duration_ms, float norm, float loss, float lr); void log_eval(int step, float epoch, int eval_tokens, int duration_ms, float loss); void log_gpu_state(int step, int gpu_id, const GPUUtilInfo& gpu_util); - void log_allocator(const std::vector>& stats); + void log_allocator( + const std::vector>& stats, + const std::vector>& stack_info + ); // call at the beginning and end of a section of processing. 
// will record the time between the two calls diff --git a/src/utilities/comm.cpp b/src/utilities/comm.cpp index 5660b8d..ed47755 100644 --- a/src/utilities/comm.cpp +++ b/src/utilities/comm.cpp @@ -339,7 +339,6 @@ std::unique_ptr NCCLCommunicator::make_mpi_communicator() { #include #include -#include class NCCLCommunicatorThreads : public NCCLCommunicator { public: @@ -476,7 +475,6 @@ std::unique_ptr NCCLCommunicator::launch_threads_commun fprintf(stderr, "WARNING: Failed to set CPU affinity for rank %d\n", i); } NCCLCommunicatorThreads comm(i, ngpus, memcpy_allgather, memcpy_send_recv, &nccl_id, bar); - nvtxNameOsThread(pthread_self(), "worker"); work(comm); bar->Barrier->arrive_and_wait(); } catch(...) { @@ -594,7 +592,7 @@ void NCCLCommunicatorThreads::on_finish_transaction(cudaEvent_t signal) { barrier(); // assumes _all_ workers have the same number of receives! for (int j = 0; j < world_size(); ++j) { if (j != rank()) { - cudaStreamWaitEvent(stream(), sync_events[j], 0); + CUDA_CHECK(cudaStreamWaitEvent(stream(), sync_events[j], 0)); } } } diff --git a/src/utilities/sol.cpp b/src/utilities/sol.cpp index dbbc872..8be33e1 100644 --- a/src/utilities/sol.cpp +++ b/src/utilities/sol.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include "utilities/dtype.h" @@ -319,6 +318,7 @@ std::vector> get_transformer_ops(long non_embeddin } cublasLtHandle_t create_cublaslt_handle(); +void destroy_cublaslt_handle(cublasLtHandle_t handle); double measure_real_peak() { nv_bfloat16* a; @@ -369,7 +369,7 @@ double measure_real_peak() { CUDA_CHECK(cudaFree(b)); CUDA_CHECK(cudaFree(c)); CUDA_CHECK(cudaFree(workspace)); - cublasLtDestroy(handle); + destroy_cublaslt_handle(handle); double ops_per_sec = ops_total / ms_total * 1000; return ops_per_sec; diff --git a/src/utilities/stack.cpp b/src/utilities/stack.cpp index 6838495..40772fc 100644 --- a/src/utilities/stack.cpp +++ b/src/utilities/stack.cpp @@ -10,7 +10,7 @@ DeviceMemoryStack::DeviceMemoryStack(std::byte* memory, std::size_t amount, int } -std::byte* DeviceMemoryStack::allocate(std::size_t amount) { +std::byte* DeviceMemoryStack::allocate(std::size_t amount, const char* name) { constexpr size_t alignment = 4096; std::size_t aligned_amount = div_ceil(amount, alignment) * alignment; std::byte* new_top = mTop + aligned_amount; @@ -18,28 +18,43 @@ std::byte* DeviceMemoryStack::allocate(std::size_t amount) { throw std::bad_alloc(); } - mAlloc.emplace_back(mTop, aligned_amount); + mAlloc.emplace_back(mTop, aligned_amount, name); mTop = new_top; - mMaxUtilization = std::max(mMaxUtilization, bytes_used()); - return mAlloc.back().first; + _track_max(); + return mAlloc.back().Pointer; } -Tensor DeviceMemoryStack::allocate(ETensorDType dtype, const std::vector& shape) { +Tensor DeviceMemoryStack::allocate(ETensorDType dtype, const std::vector& shape, const char* name) { std::size_t total = std::accumulate(std::begin(shape), std::end(shape), (long)get_dtype_size(dtype), std::multiplies<>()); - return Tensor::from_pointer(allocate(total), mDeviceID, dtype, shape); + return Tensor::from_pointer(allocate(total, name), mDeviceID, dtype, shape); } void DeviceMemoryStack::free(std::byte* ptr) { if(mAlloc.empty()) { throw std::logic_error("DeviceMemoryStack::free_left called with empty allocation list"); } - if(mAlloc.back().first != ptr) { + if(mAlloc.back().Pointer != ptr) { throw std::logic_error("DeviceMemoryStack::free_left called with wrong pointer"); } - mTop = mAlloc.back().first; + mTop = mAlloc.back().Pointer; mAlloc.pop_back(); } 
+std::vector> DeviceMemoryStack::get_allocation_stats() const { + std::vector> result; + for (auto& [ptr, amount, name]: get_high_mark()) { + result.emplace_back(name, amount); + } + return result; +} + +void DeviceMemoryStack::_track_max() { + if(bytes_used() > mMaxUtilization) { + mMaxUtilization = bytes_used(); + mHighMark = mAlloc; + } +} + std::size_t DeviceMemoryStack::unused_capacity() const { return mCapacity - (mTop - mBackingMemory); } diff --git a/src/utilities/stack.h b/src/utilities/stack.h index 7b59291..b6f2418 100644 --- a/src/utilities/stack.h +++ b/src/utilities/stack.h @@ -14,8 +14,8 @@ class DeviceMemoryStack { DeviceMemoryStack() = default; DeviceMemoryStack(std::byte* memory, std::size_t amount, int device_id); - std::byte* allocate(std::size_t amount); - Tensor allocate(ETensorDType dtype, const std::vector& shape); + std::byte* allocate(std::size_t amount, const char* name=""); + Tensor allocate(ETensorDType dtype, const std::vector& shape, const char* name=""); void free(std::byte* ptr); void free(Tensor& tensor); @@ -24,16 +24,31 @@ class DeviceMemoryStack { std::size_t bytes_used() const; std::size_t max_utilization() const; int device_id() const; + + struct sAllocRecord { + std::byte* Pointer; + std::size_t Amount; + const char* Name; + }; + using AllocationList = std::vector; + + const AllocationList& get_high_mark() const { return mHighMark; } + void set_high_mark(const AllocationList& list) { mHighMark = list; } + + std::vector> get_allocation_stats() const; + private: int mDeviceID; std::byte* mBackingMemory; std::byte* mTop; std::size_t mCapacity; - using AllocationList = std::vector>; + void _track_max(); + AllocationList mAlloc; std::size_t mMaxUtilization = 0; + std::vector mHighMark; }; #endif //LLMQ_SRC_UTILITIES_STACK_H diff --git a/src/utilities/tensor.cpp b/src/utilities/tensor.cpp index faa8c15..9927d89 100644 --- a/src/utilities/tensor.cpp +++ b/src/utilities/tensor.cpp @@ -9,7 +9,7 @@ #include -Tensor HOST_DEVICE slice(const Tensor& src, int dim, long start, long end) { +Tensor slice(const Tensor& src, int dim, long start, long end) { if (dim != 0) throw std::logic_error("Slices must be contiguous, so only the first dimension can be sliced."); diff --git a/src/utilities/utils.cpp b/src/utilities/utils.cpp index 0fe5da5..3983e81 100644 --- a/src/utilities/utils.cpp +++ b/src/utilities/utils.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -50,3 +51,7 @@ bool iequals(std::string_view lhs, std::string_view rhs) { return std::tolower(a) == std::tolower(b); }); } + +[[noreturn]] void throw_not_divisible(long long dividend, long long divisor) { + throw std::runtime_error(fmt::format("Cannot divide {} by {}", dividend, divisor)); +} diff --git a/src/utilities/utils.h b/src/utilities/utils.h index 638993f..f194a95 100644 --- a/src/utilities/utils.h +++ b/src/utilities/utils.h @@ -40,10 +40,13 @@ template constexpr T HOST_DEVICE div_ceil(T dividend, T divisor) { return (dividend + divisor - 1) / divisor; } + +[[noreturn]] void throw_not_divisible(long long dividend, long long divisor); + template constexpr T div_exact(T dividend, T divisor) { if(dividend % divisor != 0) { - throw std::runtime_error("Not divisible"); + throw_not_divisible(dividend, divisor); } return dividend / divisor; } diff --git a/train.cpp b/train.cpp index abd072b..8f5345e 100644 --- a/train.cpp +++ b/train.cpp @@ -60,6 +60,7 @@ struct TrainingRunner { float Beta2 = 0.95f; float GradClip = 1.0f; float WeightDecay = 0.1f; + float Epsilon = 1e-8f; int GradAccSteps = 
4; bool FromScratch = false; @@ -136,6 +137,7 @@ void TrainingRunner::load_training_config(int argc, const char** argv) { app.add_option("--grad-accumulation", GradAccSteps, "number of micro-batches per optimizer step"); app.add_option("--grad-clip", GradClip, "Gradient clipping"); app.add_option("--weight-decay", WeightDecay, "Weight decay for matrix parameters"); + app.add_option("--adam-epsilon", Epsilon, "Epsilon to use for AdamW"); app.add_option("--steps", MaxSteps, "Number of training steps"); app.add_option("--log-gpu-util", LogGPUEvery, "Log the gpu utilization every n steps. Set to 0 to disable."); @@ -427,7 +429,7 @@ void TrainingRunner::run_training(int argc, const char** argv, NCCLCommunicator& logger.log_dataset(train_loader, test_loader); - logger.log_allocator(model.get_allocator().get_allocation_segments()); + logger.log_allocator(model.get_allocator().get_allocation_segments(), model.run_state().Stack.get_allocation_stats()); Tensor inputs = model.get_input_buffer(); Tensor targets = model.get_target_buffer(); @@ -478,7 +480,7 @@ void TrainingRunner::run_training(int argc, const char** argv, NCCLCommunicator& } float lr = lr_schedule->eval(step); - model.update(comm, lr, Beta1, Beta2, step + 1, 1e-8f, WeightDecay, GradClip); + model.update(comm, lr, Beta1, Beta2, step + 1, Epsilon, WeightDecay, GradClip); CUDA_CHECK(cudaDeviceSynchronize()); std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); long ms = std::chrono::duration_cast(end - start).count(); diff --git a/uv.lock b/uv.lock index e82fdf5..b8d0142 100644 --- a/uv.lock +++ b/uv.lock @@ -1816,7 +1816,7 @@ wheels = [ [[package]] name = "pyllmq" -version = "0.2.3" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "numpy" },
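// --- Illustrative sketch (not part of the patch) --------------------------------------------
// The encoder_backward signature change and the targets upload in llama_model.cpp both use the
// same pattern: stage the host-to-device copy on a dedicated copy stream, record an event, and
// make the compute stream wait on that event, so the kernel launch is not serialized behind the
// copy submission. A minimal, self-contained version of that pattern is sketched below under
// those assumptions; the names (upload_and_consume, consume_kernel, d_buf, h_buf) are
// hypothetical and only the CUDA runtime calls themselves are real API.
#include <cuda_runtime.h>

__global__ void consume_kernel(const int* data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // ... consume data[i] on the compute stream ...
    }
}

void upload_and_consume(int* d_buf, const int* h_buf, int n,
                        cudaStream_t main_stream, cudaStream_t copy_stream, cudaEvent_t sync_event) {
    // copy runs on copy_stream, so work already queued on main_stream keeps executing
    // (for the copy to be truly asynchronous, h_buf should be pinned host memory)
    cudaMemcpyAsync(d_buf, h_buf, n * sizeof(int), cudaMemcpyHostToDevice, copy_stream);
    cudaEventRecord(sync_event, copy_stream);
    // only main_stream waits at this point in its queue; the host thread does not block
    cudaStreamWaitEvent(main_stream, sync_event, 0);
    consume_kernel<<<(n + 255) / 256, 256, 0, main_stream>>>(d_buf, n);
}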