Skip to content

Commit 7b7f00b

Browse files
committed
improved asynchrony for input/target transfers
1 parent 17b0575 commit 7b7f00b

File tree

4 files changed

+32
-20
lines changed

4 files changed

+32
-20
lines changed

src/kernels/encoder.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ template<class floatX>
178178
void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
179179
int* workload_indices, int4* bucket_info, // cpu scratch buffers
180180
const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
181-
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
181+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
182182
using x128 = GenericVector<floatX, 16/sizeof(floatX)>;
183183

184184
int num_c_groups = div_ceil((size_t)C, x128::size * 32);
@@ -220,12 +220,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
220220
bucket_index++;
221221
}
222222

223-
// Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)
224-
// todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
223+
// Step 3: Copy data from host to device (async on a different stream)
225224
int4* d_bucket_info = (int4*)scratch;
226225
int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * 4);
227-
CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream));
228-
CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream));
226+
CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, copy_stream));
227+
CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, copy_stream));
228+
CUDA_CHECK(cudaEventRecord(sync_event, copy_stream));
229+
CUDA_CHECK(cudaStreamWaitEvent(stream, sync_event, 0));
229230

230231
// Launch wte kernel
231232
// todo - profile block sizes on more content (depends on number of buckets and on GPU?)
@@ -236,13 +237,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
236237
void encoder_backward(float* dwte, int* scratch, // gpu outputs & scratch
237238
int* workload_indices, int4* bucket_info, // cpu scratch buffers
238239
const float* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
239-
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
240-
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
240+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
241+
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
241242
}
242243

243244
void encoder_backward(nv_bfloat16* dwte, int* scratch, // gpu outputs & scratch
244245
int* workload_indices, int4* bucket_info, // cpu scratch buffers
245246
const nv_bfloat16* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
246-
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
247-
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
247+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
248+
encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
248249
}

src/kernels/kernels.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,17 @@ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::opt
116116
void encoder_backward(Tensor& dwte, Tensor& scratch,
117117
Tensor& workload_indices, Tensor& bucket_info,
118118
const Tensor& dout, const Tensor& inp, const Tensor& inputs_cpu,
119-
int B, int T, int C, unsigned int seed, cudaStream_t stream) {
119+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
120120
assert(workload_indices.Device == -1);
121121
assert(bucket_info.Device == -1);
122122
if(dwte.DType == ETensorDType::FP32) {
123-
encoder_backward(dwte.get<float>(), scratch.get<int>(), workload_indices.get<int>(), (int4*)bucket_info.get<int>(), dout.get<float>(), inp.get<std::int32_t>(), inputs_cpu.get<std::int32_t>(), B, T, C, seed, stream);
123+
encoder_backward(dwte.get<float>(), scratch.get<int>(), workload_indices.get<int>(),
124+
(int4*)bucket_info.get<int>(), dout.get<float>(), inp.get<std::int32_t>(), inputs_cpu.get<std::int32_t>(),
125+
B, T, C, seed, stream, sync_event, copy_stream);
124126
} else if(dwte.DType == ETensorDType::BF16) {
125-
encoder_backward(dwte.get<nv_bfloat16>(), scratch.get<int>(), workload_indices.get<int>(), (int4*)bucket_info.get<int>(), dout.get<nv_bfloat16>(), inp.get<std::int32_t>(), inputs_cpu.get<std::int32_t>(), B, T, C, seed, stream);
127+
encoder_backward(dwte.get<nv_bfloat16>(), scratch.get<int>(), workload_indices.get<int>(),
128+
(int4*)bucket_info.get<int>(), dout.get<nv_bfloat16>(), inp.get<std::int32_t>(), inputs_cpu.get<std::int32_t>(),
129+
B, T, C, seed, stream, sync_event, copy_stream);
126130
} else {
127131
throw std::logic_error("encoder_backward: unsupported dtype");
128132
}

src/kernels/kernels.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,18 @@ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::opt
2828
void encoder_backward(float* dwte, int* scratch,
2929
int* workload_indices, int4* bucket_info,
3030
const float* dout, const int* inp, const int* inputs_cpu,
31-
int B, int T, int C, unsigned int seed, cudaStream_t stream);
31+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
3232
void encoder_backward(nv_bfloat16* dwte, int* scratch,
3333
int* workload_indices, int4* bucket_info,
3434
const nv_bfloat16* dout, const int* inp, const int* inputs_cpu,
35-
int B, int T, int C, unsigned int seed, cudaStream_t stream);
35+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
36+
37+
// The kernel itself runs on `stream`, but the bucket info that is generated on the CPU (to enable efficient determinism)
38
// can be copied on `copy_stream`, so the kernel launch does not have to wait for the transfer.
3639
void encoder_backward(Tensor& dwte, Tensor& scratch,
3740
Tensor& workload_indices, Tensor& bucket_info,
3841
const Tensor& dout, const Tensor& inp, const Tensor& inputs_cpu,
39-
int B, int T, int C, unsigned int seed, cudaStream_t stream);
42+
int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream);
4043

4144
void rmsnorm_forward(float* out, float* rms, const float* inp, const float* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream);
4245
void rmsnorm_forward(nv_bfloat16* out, float* rms, const nv_bfloat16* inp, const nv_bfloat16* weight, float* abs_max_ptr, float epsilon, int B, int T, int C, cudaStream_t stream);

src/models/llama_model.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@ void LLamaModel::forward(Tensor inputs, NCCLCommunicator& comm, int micro_step)
100100
assert(inputs.Device == -1);
101101
{
102102
NvtxRange r{"copy-input"};
103-
CUDA_CHECK(cudaMemcpyAsync(rs->Inputs.Data, inputs.Data, inputs.bytes(), cudaMemcpyHostToDevice, main_stream));
104-
CUDA_CHECK(cudaEventRecord(rs->TransferDone, main_stream));
103+
// By running copy-input on the side stream, it can overlap with the previous backward pass.
104+
CUDA_CHECK(cudaMemcpyAsync(rs->Inputs.Data, inputs.Data, inputs.bytes(), cudaMemcpyHostToDevice, rs->SideStream));
105+
CUDA_CHECK(cudaEventRecord(rs->TransferDone, rs->SideStream));
106+
CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->TransferDone, 0));
105107
}
106108

107109
{
@@ -347,8 +349,10 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm,
347349
const size_t C = Config.HiddenSize;
348350
const size_t L = Config.NumLayers;
349351

350-
CUDA_CHECK(cudaMemcpyAsync(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice, main_stream));
351-
CUDA_CHECK(cudaEventRecord(rs->TransferDone, main_stream));
352+
// Copy targets on the side stream so the transfer can start earlier.
353+
CUDA_CHECK(cudaMemcpyAsync(rs->Targets.Data, targets.Data, targets.bytes(), cudaMemcpyHostToDevice, rs->SideStream));
354+
CUDA_CHECK(cudaEventRecord(rs->TransferDone, rs->SideStream));
355+
CUDA_CHECK(cudaStreamWaitEvent(main_stream, rs->TransferDone, 0));
352356

353357
bool last_step = micro_step == grad_accum_steps - 1;
354358
// on the first micro-step zero the gradients, as we're about to += accumulate into them
@@ -426,7 +430,7 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm,
426430

427431
auto& d_emb = Grads->get_embeddings_full(main_stream, comm, accumulate);
428432
encoder_backward(d_emb, rs->EncoderBwdScratch, rs->EncoderBwdIndices, rs->EncoderBwdInfo,
429-
rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream);
433+
rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream, rs->SideStreamEvent, rs->SideStream);
430434
Grads->notify_embeddings(main_stream, comm);
431435

432436
// make sure all gradients are communicated before we go to the update step.

0 commit comments

Comments
 (0)