
Commit 5f98eee

cuda grid

1 parent 6b18f2d

File tree

aten/cuda/cluster.cpp
aten/cuda/cluster_kernel.cu
aten/cuda/setup.py

3 files changed: +76 -0 lines changed

aten/cuda/cluster.cpp

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
#include <torch/torch.h>

// Implemented in cluster_kernel.cu.
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end);

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);

  return grid_cuda(pos, size, start, end);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
}

aten/cuda/cluster_kernel.cu

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>

template <typename scalar_t>
__global__ void grid_cuda_kernel(
    int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
    const scalar_t *__restrict__ size, const scalar_t *__restrict__ start,
    const scalar_t *__restrict__ end, const size_t n) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  // Grid-stride loop: each thread handles every `stride`-th point.
  for (ptrdiff_t i = index; i < n; i += stride) {
    int64_t c = 0, k = 1;
    scalar_t tmp;
    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      tmp = (pos.data[i * pos.strides[0] + d * pos.strides[1]]) - start[d];
      // Cell coordinate along dimension d, flattened with running stride k.
      c += (int64_t)(tmp / size[d]) * k;
      // Stride grows by the number of cells along dimension d.
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}

at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  size = size.toType(pos.type());
  start = start.toType(pos.type());
  end = end.toType(pos.type());

  const auto num_nodes = pos.size(0);
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});

  const int threads = 1024;
  const dim3 blocks((num_nodes + threads - 1) / threads);

  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda", [&] {
    auto cluster_data = cluster.data<int64_t>();
    auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
    auto size_data = size.data<scalar_t>();
    auto start_data = start.data<scalar_t>();
    auto end_data = end.data<scalar_t>();
    grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
        cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
  });

  return cluster;
}
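To make the kernel's index arithmetic concrete: each point gets the cell coordinate floor((pos[d] - start[d]) / size[d]) along every dimension, and the coordinates are flattened into one id with a running stride k. Below is a minimal CPU sketch of the same computation in plain PyTorch; the helper name grid_reference is hypothetical and not part of this commit.

import torch

def grid_reference(pos, size, start, end):
    # pos: [num_nodes, dim]; size/start/end: [dim] (same dtype as pos).
    cluster = torch.zeros(pos.size(0), dtype=torch.long, device=pos.device)
    k = 1
    for d in range(pos.size(1)):
        cells = int((end[d] - start[d]) / size[d]) + 1   # cells along dim d
        c = ((pos[:, d] - start[d]) / size[d]).long()    # cell coordinate, truncated like the C cast
        cluster += c * k                                 # flatten with running stride k
        k *= cells
    return cluster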

aten/cuda/setup.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='cluster_cuda',
    ext_modules=[
        CUDAExtension('cluster_cuda', ['cluster.cpp', 'cluster_kernel.cu'])
    ],
    cmdclass={'build_ext': BuildExtension},
)
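Assuming the extension builds as configured above, usage from Python might look like the following sketch. The module name cluster_cuda and the function grid come from setup.py and the pybind11 binding; the tensor values and shapes are illustrative assumptions.

# Build and install first, e.g.:  python setup.py install
import torch
import cluster_cuda

pos = torch.rand(100, 3).cuda()          # 100 points in 3-D space
size = torch.full((3,), 0.25).cuda()     # voxel edge length per dimension
start = torch.zeros(3).cuda()            # lower corner of the grid
end = torch.ones(3).cuda()               # upper corner of the grid

cluster = cluster_cuda.grid(pos, size, start, end)  # one int64 cell id per point
print(cluster.size())  # torch.Size([100])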
