#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>

+#define THREADS 1024
+#define BLOCKS(N) (N + THREADS - 1) / THREADS
+
template <typename scalar_t>
__global__ void grid_cuda_kernel(
    int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
@@ -29,16 +32,13 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
  const auto num_nodes = pos.size(0);
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});

-  const int threads = 1024;
-  const dim3 blocks((num_nodes + threads - 1) / threads);
-
  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
    auto cluster_data = cluster.data<int64_t>();
    auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
    auto size_data = size.data<scalar_t>();
    auto start_data = start.data<scalar_t>();
    auto end_data = end.data<scalar_t>();
-    grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
+    grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
        cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
  });

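For context, the two macros introduced above implement the usual ceil-division launch geometry: BLOCKS(N) rounds N / THREADS up so that every element is covered by a thread, and the kernel itself must guard against the overshoot. Below is a minimal standalone sketch of that pattern under the same macro definitions; fill_kernel, its arguments, and the element count are hypothetical illustrations, not code from this repository.

#include <cuda_runtime.h>
#include <cstdint>

#define THREADS 1024
// Fully parenthesized variant of the commit's BLOCKS(N); the extra
// parentheses only matter when N is a compound expression.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Hypothetical kernel: writes `value` into every element of `out`.
__global__ void fill_kernel(int64_t *out, int64_t value, int64_t numel) {
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // The grid is rounded up, so up to THREADS - 1 threads land past the
  // end of the buffer and must return early.
  if (idx < numel) {
    out[idx] = value;
  }
}

int main() {
  const int64_t numel = 5000;
  int64_t *out;
  cudaMalloc(&out, numel * sizeof(int64_t));

  // 5000 elements at 1024 threads per block -> BLOCKS(5000) = 5 blocks.
  fill_kernel<<<BLOCKS(numel), THREADS>>>(out, -1, numel);
  cudaDeviceSynchronize();

  cudaFree(out);
  return 0;
}

The in-kernel bounds check is the only behavioral requirement the macros impose; the refactor itself simply moves the two constants out of grid_cuda so that every launch in the file can share them.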