Skip to content

Commit b992389

Browse files
committed
cleaner
1 parent b1d9a36 commit b992389

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

aten/cuda/cluster_kernel.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include <ATen/ATen.h>
22
#include <ATen/cuda/detail/IndexUtils.cuh>
33

4+
// Threads per block; passed as the block dimension in the kernel launch
// configuration and used by BLOCKS(N) as the divisor when sizing the grid.
#define THREADS 1024
5+
// Number of blocks required to cover N elements with THREADS threads per
// block, i.e. ceil(N / THREADS) in integer arithmetic.
// Both the argument and the whole expansion are parenthesized so the macro
// composes correctly inside larger expressions (e.g. `2 * BLOCKS(a + b)` or
// `BLOCKS(flag ? x : y)`), which the unparenthesized form would mis-evaluate.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)
6+
47
template <typename scalar_t>
58
__global__ void grid_cuda_kernel(
69
int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
@@ -29,16 +32,13 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
2932
const auto num_nodes = pos.size(0);
3033
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
3134

32-
const int threads = 1024;
33-
const dim3 blocks((num_nodes + threads - 1) / threads);
34-
3535
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
3636
auto cluster_data = cluster.data<int64_t>();
3737
auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
3838
auto size_data = size.data<scalar_t>();
3939
auto start_data = start.data<scalar_t>();
4040
auto end_data = end.data<scalar_t>();
41-
grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
41+
grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
4242
cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
4343
});
4444

0 commit comments

Comments
 (0)