2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pyllmq"
version = "0.3.0"
version = "0.3.2"
description = "Python wrappers for LLMQ. LLM pretraining written in CUDA."
readme = "README.md"
requires-python = ">=3.12"
2 changes: 1 addition & 1 deletion scripts/demo.py
@@ -68,7 +68,7 @@ def main():

print("\nmemory consumption:")
for k, v in trainer.get_allocator_info(0).items():
print(f" {k:20}: {v // 1024 // 1024:6} MiB")
print(f" {k:20}: {v['device'] // 1024 // 1024:6} MiB")

train_loader.load_batch(in_tokens, out_tokens)

2 changes: 1 addition & 1 deletion scripts/train.py
@@ -177,7 +177,7 @@ def main():

# Log allocator stats
for idx in range(config.gpus):
logger.log_allocator(trainer.get_allocator_info(idx))
logger.log_allocator(trainer, idx)

# calculate the expected time at peak flops for speed-of-light estimation
logger.set_expected_time_per_token(trainer)
25 changes: 19 additions & 6 deletions src/binding/binding.cpp
@@ -264,6 +264,7 @@ NB_MODULE(_pyllmq, m) {

nb::class_<MultiGPUPyTrainer>(m, "LLMQTrainer")
.def("__init__", [](MultiGPUPyTrainer *t, int ngpu, LLamaConfig config, LLamaOptions options, int batch_size, int seq_len, int grad_accum, bool memcpy_all_gather, bool memcpy_send_recv) {
options.ModelType = config.DType;
new (t) MultiGPUPyTrainer(ngpu, config, options, batch_size, seq_len, grad_accum, memcpy_all_gather, memcpy_send_recv);
}, nb::arg("ngpu"), nb::arg("config"), nb::arg("options"), nb::arg("batch_size"), nb::arg("seq_len"), nb::arg("grad_accum"),
nb::arg("memcpy_all_gather") = true, nb::arg("memcpy_send_recv") = true)
@@ -337,6 +338,13 @@ NB_MODULE(_pyllmq, m) {
res["pageable"] = size.PageableHost;
ret[nb::cast(name)] = res;
}

auto stack = trainer->get_stack_info(gpu_id);
for (const auto& [name, size] : stack) {
nb::dict res;
res["stack"] = size;
ret[nb::cast(name)] = res;
}
return ret;
}, nb::arg("gpu_id") = 0, "Get the current memory allocations for the given GPU")
;
@@ -440,17 +448,22 @@ NB_MODULE(_pyllmq, m) {
"Log GPU utilization state")
.def("log_allocator", [](TrainingRunLogger* logger, const nb::dict& stats) {
std::vector<std::pair<std::string, sSegmentMemory>> cpp_stats;
std::vector<std::pair<std::string, long>> cpp_stack;
cpp_stats.reserve(stats.size());
for (auto item : stats) {
std::string key = nb::cast<std::string>(item.first);
nb::dict value = nb::cast<nb::dict>(item.second);
long device = nb::cast<long>(value["device"]);
long managed = nb::cast<long>(value["managed"]);
long pinned = nb::cast<long>(value["pinned"]);
long pageable = nb::cast<long>(value["pageable"]);
cpp_stats.emplace_back(key, sSegmentMemory{device, managed, pinned, pageable});
if (value.contains("stack")) {
cpp_stack.emplace_back(key, nb::cast<long>(value["stack"]));
} else {
long device = nb::cast<long>(value["device"]);
long managed = nb::cast<long>(value["managed"]);
long pinned = nb::cast<long>(value["pinned"]);
long pageable = nb::cast<long>(value["pageable"]);
cpp_stats.emplace_back(key, sSegmentMemory{device, managed, pinned, pageable});
}
}
logger->log_allocator(cpp_stats);
logger->log_allocator(cpp_stats, cpp_stack);
}, nb::arg("stats"), "Log memory allocator statistics")
.def("set_expected_time_per_token", [](TrainingRunLogger* logger, const MultiGPUPyTrainer* trainer){
auto& config = trainer->config();
9 changes: 9 additions & 0 deletions src/binding/py_train.cpp
@@ -14,6 +14,7 @@
#include "utilities/comm.h"
#include "kernels/kernels.h"
#include "models/llama_gradients.h"
#include "models/llama_run_state.h"

MultiGPUPyTrainer::MultiGPUPyTrainer(int ngpus, LLamaConfig config, LLamaOptions options, int batch_size, int seq_len, int grad_accum, bool memcpy_all_gather, bool memcpy_send_recv) :
mConfig(config), mOptions(options), B(batch_size), T(seq_len), mGradAccumulation(grad_accum)
@@ -249,6 +250,14 @@ std::vector<std::pair<std::string, sSegmentMemory>> MultiGPUPyTrainer::get_alloc
return result;
}

std::vector<std::pair<std::string, long>> MultiGPUPyTrainer::get_stack_info(int gpu_id) {
std::vector<std::pair<std::string, long>> result;
run_work([&result](sThreadContext& ctx) {
result = ctx.Model->run_state().Stack.get_allocation_stats();
}, gpu_id);
return result;
}

std::vector<std::pair<std::string, Tensor>> MultiGPUPyTrainer::get_gradients(int gpu_id) {
std::vector<std::pair<std::string, Tensor>> result;
run_work([&result](sThreadContext& ctx) {
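`get_stack_info` follows the same `run_work` pattern as the surrounding accessors: a closure is handed to the per-GPU worker thread and the caller blocks until it has executed. `run_work` itself is not shown in this diff, so the following is only a rough sketch of how such a dispatch helper is typically structured; every name in it is an assumption, not LLMQ's implementation.

```cuda
// Hypothetical run-closure-on-GPU-thread helper, for illustration only.
// Each GPU owns a worker thread draining a job queue; run_work blocks the
// caller via a promise/future until the job has run on that thread.
#include <cuda_runtime.h>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

struct sThreadContext { int device_id; /* model, streams, ... */ };

class GpuWorker {
public:
    explicit GpuWorker(int device) : mCtx{device}, mThread([this] { loop(); }) {}
    ~GpuWorker() {
        { std::lock_guard<std::mutex> lk(mMutex); mStop = true; }
        mCv.notify_all();
        mThread.join();
    }

    // Run `fn` on the worker thread and wait for it to finish.
    void run_work(const std::function<void(sThreadContext&)>& fn) {
        std::promise<void> done;
        auto fut = done.get_future();
        {
            std::lock_guard<std::mutex> lk(mMutex);
            mJobs.push([&] { fn(mCtx); done.set_value(); });  // caller waits, so refs stay valid
        }
        mCv.notify_one();
        fut.get();
    }

private:
    void loop() {
        cudaSetDevice(mCtx.device_id);  // pin this thread to its GPU
        for (;;) {
            std::function<void()> job;
            {
                std::unique_lock<std::mutex> lk(mMutex);
                mCv.wait(lk, [&] { return mStop || !mJobs.empty(); });
                if (mStop && mJobs.empty()) { return; }
                job = std::move(mJobs.front());
                mJobs.pop();
            }
            job();
        }
    }

    sThreadContext mCtx;
    bool mStop = false;
    std::mutex mMutex;
    std::condition_variable mCv;
    std::queue<std::function<void()>> mJobs;
    std::thread mThread;
};
```

The point of the pattern, as used by the binding, is that the result vector is filled on the GPU's own thread (where the CUDA context and model state live) while the Python-facing call remains a plain blocking function.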
1 change: 1 addition & 0 deletions src/binding/py_train.h
@@ -63,6 +63,7 @@ class MultiGPUPyTrainer
const LLamaOptions& options() const { return mOptions; }

std::vector<std::pair<std::string, sSegmentMemory>> get_allocations(int gpu_id);
std::vector<std::pair<std::string, long>> get_stack_info(int gpu_id);
std::vector<std::pair<std::string, Tensor>> get_gradients(int gpu_id);

private:
1 change: 1 addition & 0 deletions src/binding/python/training.py
@@ -311,6 +311,7 @@ def training_logger_context(config: TrainingConfig):
logger.log_cmd(sys.argv)
log_options = asdict(config)
log_options["matmul_dtype"] = log_options["matmul_dtype"] or config.model_dtype
log_options["gradient_dtype"] = log_options["gradient_dtype"] or log_options["matmul_dtype"]
log_options["verbosity"] = str(config.verbosity)
logger.log_options(log_options)
yield logger
13 changes: 6 additions & 7 deletions src/kernels/adamw.cu
@@ -3,7 +3,6 @@
//
// Based on llm.c https://github.com/karpathy/llm.c

#include <cuda/atomic>
#include <algorithm>

#include "squirrel_noise.cuh"
@@ -37,7 +36,7 @@ __device__ GenericVector<float, VecElems> load_vector(const FloatIn* memory, con
}

template <typename FloatOut, std::size_t VecElems>
__device__ void store_vector(FloatOut* memory, const GenericVector<float, VecElems>& in, float* scales, long idx) {
__device__ void store_vector(FloatOut* memory, const GenericVector<float, VecElems>& in, float* scales, long idx, unsigned active_mask) {
float factor = 1.f;
unsigned int rng = get_noise_2d(idx, blockIdx.y, 51245);
if(scales != nullptr) {
@@ -47,9 +46,8 @@ __device__ void store_vector(FloatOut* memory, const GenericVector<float, VecEle
}
constexpr int Threads = 128 / VecElems;
static_assert(Threads <= 32, "#threads > warp size");
unsigned mask = __activemask();
for(int i = 1; i <= Threads; i *= 2) {
abs_max = std::max(abs_max, __shfl_xor_sync(mask, abs_max, i, Threads));
for(int i = 1; i < Threads; i *= 2) {
abs_max = std::max(abs_max, __shfl_xor_sync(active_mask, abs_max, i, Threads));
}
if(abs_max > 1e-10f) {
factor = 448.f / abs_max;
@@ -77,6 +75,7 @@ __device__ auto adamw_update(floatX* params_memory, const floatX* grads_memory,
using vec_m_t = GenericVector<floatM, VecElems>;
using vec_v_t = GenericVector<floatV, VecElems>;

const unsigned active_mask = __ballot_sync(0xffffffff, idx < num_parameters);
if (idx >= num_parameters) { return vec_x_t::zeros(); } // guard

vec_f_t m = load_vector<VecElems>(m_memory, m_scales, idx);
@@ -108,7 +107,7 @@
}

p_new.store(params_memory + idx);
store_vector(m_memory, m, m_scales, idx);
store_vector(m_memory, m, m_scales, idx, active_mask);
v_new.store(v_memory + idx);

return p_new;
@@ -122,7 +121,7 @@ __global__ void adamw_kernel(floatX* params_memory, const floatX* grads_memory,
using vec_x_t = GenericVector<floatX, VecElems>;
__shared__ float block_abs_max;
if(threadIdx.x == 0) {
block_abs_max = 1e-10f;
block_abs_max = 0.f;
}

float thread_abs_max = 0.0f;
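The adamw.cu change above swaps `__activemask()` for a `__ballot_sync` mask computed before the bounds guard and passes it down to `store_vector`. Below is a minimal standalone sketch of that pattern (a hypothetical kernel, not code from this PR): the ballot is taken while the full warp is still converged, so it records exactly the lanes that survive the guard, whereas `__activemask()` called after divergence only reports whichever lanes happen to be converged at that instruction.

```cuda
// Sketch: per-warp absolute maximum with an early-return guard.
// All names here are illustrative.
#include <cuda_runtime.h>
#include <cmath>

__global__ void warp_abs_max_kernel(const float* __restrict__ data,
                                    float* __restrict__ warp_max, long n) {
    const long idx  = blockIdx.x * (long)blockDim.x + threadIdx.x;
    const int  lane = threadIdx.x & 31;

    // Whole warp is still converged: the ballot captures exactly the lanes
    // that pass the bounds check below.
    const unsigned mask = __ballot_sync(0xffffffffu, idx < n);
    if (idx >= n) { return; }                 // divergence starts here

    float v = fabsf(data[idx]);
    // Butterfly reduction restricted to lanes in `mask`. A shuffled value is
    // only consumed if the partner lane is in the mask, so data from exited
    // lanes is never used; lane 0 ends up holding the warp maximum.
    for (int off = 1; off < 32; off *= 2) {
        const float other = __shfl_xor_sync(mask, v, off, 32);
        if (mask & (1u << (lane ^ off))) { v = fmaxf(v, other); }
    }
    if (lane == 0) { warp_max[idx / 32] = v; }
}

int main() {
    const long n = 1000;                      // deliberately not a multiple of 32
    float *d_data, *d_max;
    cudaMalloc(&d_data, n * sizeof(float));
    cudaMalloc(&d_max, ((n + 31) / 32) * sizeof(float));
    cudaMemset(d_data, 0, n * sizeof(float));
    warp_abs_max_kernel<<<(n + 255) / 256, 256>>>(d_data, d_max, n);
    cudaDeviceSynchronize();
    cudaFree(d_data);
    cudaFree(d_max);
    return 0;
}
```

The tightened loop bound in the diff (`i < Threads` rather than `i <= Threads`) has the same shape: a butterfly over a group of `Threads` lanes needs only `log2(Threads)` steps, i.e. offsets 1, 2, ..., Threads/2.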
32 changes: 16 additions & 16 deletions src/kernels/attention.cu
@@ -9,11 +9,10 @@
#include <cmath>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#include "kernels/kernels.h"
#include "utilities/tensor.h"
#include "utilities/vec.cuh"
#include "kernel_utils.cuh"

namespace cg = cooperative_groups;

@@ -22,10 +21,11 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel(
scalar_t* out, float* stats, float scale,
const scalar_t* qkv,
int B, int T, int Hq, int Hkv) {
constexpr const int SubWarpSize = 16;

auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
auto sub_warp = cg::tiled_partition<16>(block);
auto sub_warp = cg::tiled_partition<SubWarpSize>(block);

extern __shared__ float scratch[];

@@ -43,19 +43,19 @@

using vec_t = GenericVector<scalar_t, 4>;
using fvec_t = GenericVector<float, 4>;
using q_cache_t = GenericVector<float, E / sub_warp.size()>;
using q_cache_t = GenericVector<float, E / SubWarpSize>;
q_cache_t q_cache;

// combine values
using v_cache_t = GenericVector<float, E / sub_warp.size()>;
using v_cache_t = GenericVector<float, E / SubWarpSize>;
v_cache_t v_cache = v_cache_t::zeros();

// determine maximum and online logsumexp
float maximum = std::numeric_limits<float>::lowest();
float lse = 0;

for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) {
int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size;
for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) {
int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size;
vec_t qv = vec_t::load(query + e);
for (int j = 0; j < vec_t::size; ++j) {
q_cache[ee * vec_t::size + j] = (float)qv[j];
@@ -65,14 +65,14 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel(
for (int l = sub_warp.meta_group_rank(); l <= t; l += sub_warp.meta_group_size()) {
ptrdiff_t kv_offset = l * TH * E;
float qk = 0;
for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) {
int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size;
for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) {
int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size;
vec_t kv = vec_t::load(keys + kv_offset + e);
for (int j = 0; j < vec_t::size; ++j) {
qk += q_cache[ee * vec_t::size + j] * (float)kv[j];
}
}
qk = cg::reduce(sub_warp, qk, cg::plus<float>{});
qk = reduce_group_add(sub_warp, qk);
if (qk > maximum) {
float rescale = std::exp(scale * (maximum - qk));
for (int j = 0; j < v_cache_t::size; ++j) {
@@ -84,8 +84,8 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel(
float att = std::exp(scale * (qk - maximum));
lse += std::exp(scale * (qk - maximum));

for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) {
int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size;
for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) {
int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size;
vec_t vv = vec_t::load(values + kv_offset + e);
for (int j = 0; j < vec_t::size; ++j) {
v_cache[ee * vec_t::size + j] += att * (float)vv[j];
@@ -108,9 +108,9 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel(
r_lse = scratch[warp.thread_rank() + sub_warp.meta_group_size()];
}

maximum = cg::reduce(warp, r_max, cg::greater<float>{});
maximum = reduce_group_max(warp, r_max);
r_lse *= std::exp(scale * (r_max - maximum));
lse = cg::reduce(warp, r_lse, cg::plus<float>{});
lse = reduce_group_add(warp, r_lse);
float rescale = std::exp(scale * (l_max - maximum)) / lse;
for (int j = 0; j < v_cache_t::size; ++j) {
v_cache[j] *= rescale;
@@ -120,8 +120,8 @@ __global__ void __launch_bounds__(512) attention_forward_gpu_kernel(
}
__syncthreads();

for (int ee = 0; ee < E / (sub_warp.size() * vec_t::size); ++ee) {
int e = (ee * sub_warp.size() + sub_warp.thread_rank()) * vec_t::size;
for (int ee = 0; ee < E / (SubWarpSize * vec_t::size); ++ee) {
int e = (ee * SubWarpSize + sub_warp.thread_rank()) * vec_t::size;
fvec_t store;
for (int j = 0; j < vec_t::size; ++j) {
store[j] = v_cache[ee * vec_t::size + j];
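attention.cu here (and fused_classifier.cu further down) moves from `cg::reduce(...)` to `reduce_group_add` / `reduce_group_max` from `kernel_utils.cuh`. That header is not part of this diff, so the snippet below is only a guess at the shape of such helpers, written as shuffle-based reductions over a cooperative-groups tile; the demo kernel and every name other than the two helpers are made up for illustration.

```cuda
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Sum across a warp or sub-warp tile; every lane returns the same value.
template <typename Group>
__device__ float reduce_group_add(Group g, float v) {
    for (unsigned off = g.size() / 2; off > 0; off /= 2) {
        v += g.shfl_xor(v, off);
    }
    return v;
}

// Maximum across a warp or sub-warp tile.
template <typename Group>
__device__ float reduce_group_max(Group g, float v) {
    for (unsigned off = g.size() / 2; off > 0; off /= 2) {
        v = fmaxf(v, g.shfl_xor(v, off));
    }
    return v;
}

// Illustrative use: one 32-thread block per row, summing `cols` entries.
__global__ void row_sum_kernel(const float* in, float* out, int cols) {
    auto warp = cg::tiled_partition<32>(cg::this_thread_block());
    float partial = 0.f;
    for (int c = warp.thread_rank(); c < cols; c += warp.size()) {
        partial += in[blockIdx.x * cols + c];
    }
    const float total = reduce_group_add(warp, partial);
    if (warp.thread_rank() == 0) { out[blockIdx.x] = total; }
}
```

One plausible motivation for hand-rolled helpers is dropping the `<cooperative_groups/reduce.h>` dependency, which this diff removes from the includes.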
4 changes: 2 additions & 2 deletions src/kernels/bias.cu
@@ -139,7 +139,7 @@ int get_bias_backward_scratch_size(ETensorDType dtype, int OC, const cudaDeviceP
const int block_size = dp.maxThreadsPerMultiProcessor == 1536 ? 768 : 1024;
const int OC_per_warp = 8 * ( 16 / get_dtype_size(dtype) ); // 64 at BF16
const int grid_size_x = div_ceil(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16
const int grid_size_y = max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!
const int grid_size_y = std::max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!
return grid_size_y * OC * sizeof(float);
}

@@ -155,7 +155,7 @@ void backward_bias_imp(floatX* dbias, const FloatY* dout, const float* scale_a,
dim3 block_dim = {4, 8, (unsigned)block_size/32};
const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16
const int grid_size_x = div_ceil(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16
const int grid_size_y = max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!
const int grid_size_y = std::max(1, block_size * dp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!

if( (scale_a == nullptr) != (scale_b == nullptr) ) {
throw std::logic_error("backward_bias: scale_a and scale_b must be both nullptr or both non-nullptr");
19 changes: 10 additions & 9 deletions src/kernels/encoder.cu
@@ -178,7 +178,7 @@ template<class floatX>
void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
using x128 = GenericVector<floatX, 16/sizeof(floatX)>;

int num_c_groups = div_ceil((size_t)C, x128::size * 32);
@@ -220,12 +220,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
bucket_index++;
}

// Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)
// todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
// Step 3: Copy data from host to device (async on a different stream)
int4* d_bucket_info = (int4*)scratch;
int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * 4);
CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, copy_stream));
CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, copy_stream));
CUDA_CHECK(cudaEventRecord(sync_event, copy_stream));
CUDA_CHECK(cudaStreamWaitEvent(stream, sync_event, 0));

// Launch wte kernel
// todo - profile block sizes on more content (depends on number of buckets and on GPU?)
Expand All @@ -236,13 +237,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
void encoder_backward(float* dwte, int* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const float* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
}

void encoder_backward(nv_bfloat16* dwte, int* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const nv_bfloat16* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
}
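The new `sync_event` / `copy_stream` parameters implement the standard copy-stream-plus-event handshake: the host-to-device uploads run on a dedicated copy stream, an event is recorded behind them, and the compute stream waits on that event instead of the CPU synchronizing. A self-contained sketch of that pattern with illustrative names (not the PR's buffers or streams):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void consume(const int* data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n && data[i] < 0) { printf("unexpected value\n"); }
}

int main() {
    const int n = 1 << 20;
    int *h_buf, *d_buf;
    cudaMallocHost(&h_buf, n * sizeof(int));   // pinned, so the copy is truly async
    cudaMalloc(&d_buf, n * sizeof(int));
    for (int i = 0; i < n; ++i) { h_buf[i] = i; }

    cudaStream_t compute_stream, copy_stream;
    cudaEvent_t  sync_event;
    cudaStreamCreate(&compute_stream);
    cudaStreamCreate(&copy_stream);
    cudaEventCreateWithFlags(&sync_event, cudaEventDisableTiming);

    // 1) enqueue the upload on the copy stream (the CPU does not block)
    cudaMemcpyAsync(d_buf, h_buf, n * sizeof(int), cudaMemcpyHostToDevice, copy_stream);
    // 2) mark the point in the copy stream where the data is ready
    cudaEventRecord(sync_event, copy_stream);
    // 3) the compute stream may not run past this point until the event fires
    cudaStreamWaitEvent(compute_stream, sync_event, 0);
    consume<<<(n + 255) / 256, 256, 0, compute_stream>>>(d_buf, n);

    cudaStreamSynchronize(compute_stream);
    cudaEventDestroy(sync_event);
    cudaStreamDestroy(copy_stream);
    cudaStreamDestroy(compute_stream);
    cudaFreeHost(h_buf);
    cudaFree(d_buf);
    return 0;
}
```

With this arrangement the ordering constraint lives entirely on the GPU timeline, which is exactly what the new `cudaEventRecord` / `cudaStreamWaitEvent` pair in the diff expresses.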
17 changes: 1 addition & 16 deletions src/kernels/fused_classifier.cu
@@ -5,28 +5,13 @@

#include <cassert>

#include "kernel_utils.cuh"
#include "utilities/utils.h"
#include "utilities/vec.cuh"

// ----------------------------------------------------------------------------
// CUDA kernels

// warp-level reduction for finding the maximum value
__device__ inline float warpReduceMax(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));
}
return val;
}

__device__ inline float warpReduceSum(float val) {
for (int offset = 16; offset > 0; offset /= 2) {
val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
}
return val;
}


// requires all 32 threads in the warp to be active, but should work for any block size
// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes
// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end