@@ -178,7 +178,7 @@ template<class floatX>
 void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
                           int* workload_indices, int4* bucket_info, // cpu scratch buffers
                           const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
-                          int B, int T, int C, unsigned int seed, cudaStream_t stream) {
+                          int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
     using x128 = GenericVector<floatX, 16 / sizeof(floatX)>;
 
     int num_c_groups = div_ceil((size_t)C, x128::size * 32);
@@ -220,12 +220,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
         bucket_index++;
     }
 
-    // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)
-    // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
+    // Step 3: Copy data from host to device (async on a different stream)
     int4* d_bucket_info = (int4*)scratch;
     int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * 4);
-    CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, copy_stream));
+    CUDA_CHECK(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, copy_stream));
+    CUDA_CHECK(cudaEventRecord(sync_event, copy_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(stream, sync_event, 0));
 
     // Launch wte kernel
     // todo - profile block sizes on more content (depends on number of buckets and on GPU?)
@@ -236,13 +237,13 @@ void encoder_backward_imp(floatX* dwte, int* scratch, // gpu outputs & scratch
 void encoder_backward(float* dwte, int* scratch, // gpu outputs & scratch
                       int* workload_indices, int4* bucket_info, // cpu scratch buffers
                       const float* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
-                      int B, int T, int C, unsigned int seed, cudaStream_t stream) {
-    encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
+                      int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
+    encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
 }
 
 void encoder_backward(nv_bfloat16* dwte, int* scratch, // gpu outputs & scratch
                       int* workload_indices, int4* bucket_info, // cpu scratch buffers
                       const nv_bfloat16* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
-                      int B, int T, int C, unsigned int seed, cudaStream_t stream) {
-    encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream);
+                      int B, int T, int C, unsigned int seed, cudaStream_t stream, cudaEvent_t sync_event, cudaStream_t copy_stream) {
+    encoder_backward_imp(dwte, scratch, workload_indices, bucket_info, dout, inp, inputs_cpu, B, T, C, seed, stream, sync_event, copy_stream);
 }
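
For context: the patch implements the old "todo" comment by using standard CUDA event-based inter-stream ordering. The host-to-device copies are enqueued on a dedicated copy stream, an event is recorded on that stream after them, and the compute stream is made to wait on the event, so the kernel launch is ordered after the copies without ever blocking the CPU. Below is a minimal standalone sketch of this pattern, assuming caller-side setup; only the cudaEventRecord/cudaStreamWaitEvent sequence is taken from the patch, while the stream names, buffers, and sizes are illustrative.

#include <cuda_runtime.h>
#include <stdio.h>

int main(void) {
    cudaStream_t main_stream, copy_stream;
    cudaEvent_t sync_event;
    cudaStreamCreate(&main_stream);
    cudaStreamCreate(&copy_stream);
    // The event is only used for ordering, not timing, so disable timing
    // to make recording it cheaper.
    cudaEventCreateWithFlags(&sync_event, cudaEventDisableTiming);

    // cudaMemcpyAsync only runs truly asynchronously when the host buffer
    // is pinned, hence cudaMallocHost rather than malloc.
    int *h_buf, *d_buf;
    cudaMallocHost(&h_buf, 1024 * sizeof(int));
    cudaMalloc(&d_buf, 1024 * sizeof(int));

    // Enqueue the transfer on the copy stream.
    cudaMemcpyAsync(d_buf, h_buf, 1024 * sizeof(int), cudaMemcpyHostToDevice, copy_stream);

    // Record the event on the copy stream after the transfer, then make the
    // compute stream wait on it. Any kernel launched on main_stream from
    // this point on is ordered after the copy, with no CPU-side sync.
    cudaEventRecord(sync_event, copy_stream);
    cudaStreamWaitEvent(main_stream, sync_event, 0);

    // ... launch kernels on main_stream here ...

    cudaStreamSynchronize(main_stream);
    cudaFree(d_buf);
    cudaFreeHost(h_buf);
    cudaEventDestroy(sync_event);
    cudaStreamDestroy(copy_stream);
    cudaStreamDestroy(main_stream);
    printf("done\n");
    return 0;
}

This is also why the old comment about keeping "the last one" synchronous could be dropped: the event makes the dependency explicit on the device timeline, so neither copy needs to synchronize the host.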