|
4 | 4 |
|
5 | 5 | #include "utils.cuh" |
6 | 6 |
|
7 | | -#define THREADS 1024 |
| 7 | +#define THREADS 256 |
8 | 8 |
|
// Brute-force fixed-radius neighbor search over a batched point set.
//
// Launch layout: 1-D grid, one thread per query point in `y`
// (grid = ceil(m / blockDim.x), see the host wrapper); no shared memory.
//
// For query point `n_y`, scans every candidate point of `x` that belongs to
// the same example (batch element, delimited by the CSR-style `ptr_x`
// offsets) and records up to `max_num_neighbors` candidates whose *squared*
// Euclidean distance is below `r`.  NOTE: the host wrapper passes `r * r`,
// so `r` here is the squared radius and no sqrt is needed per candidate.
//
// `row`/`col` are pre-filled with -1 by the host; untouched slots are masked
// out afterwards.  Writes are race-free because each thread owns the disjoint
// output slice [n_y * max_num_neighbors, (n_y + 1) * max_num_neighbors).
//
// Parameters:
//   x, y         - flattened (num_points, dim) coordinate arrays
//   ptr_x, ptr_y - per-example offset arrays (length num_examples + 1)
//   row, col     - output edge buffers, each of length m * max_num_neighbors
//   r            - squared search radius
//   n, m         - number of points in x and y (n is unused here but kept
//                  for signature stability with the host launch)
//   dim          - point dimensionality
//   num_examples - number of batch examples (== ptr numel - 1)
template <typename scalar_t>
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
              const int64_t *__restrict__ ptr_x,
              const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
              int64_t *__restrict__ col, const scalar_t r, const int64_t n,
              const int64_t m, const int64_t dim, const int64_t num_examples,
              const int64_t max_num_neighbors) {

  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
  if (n_y >= m) // grid tail guard: grid rarely divides m evenly
    return;

  // Batch example this query point belongs to (helper from utils.cuh).
  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);

  // Hoist the loop-invariant candidate range so ptr_x is not re-read from
  // global memory on every iteration of the scan.
  const int64_t x_begin = ptr_x[example_idx];
  const int64_t x_end = ptr_x[example_idx + 1];

  int64_t count = 0;
  for (int64_t n_x = x_begin; n_x < x_end && count < max_num_neighbors;
       n_x++) {
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff;
    }

    // Both sides are squared, so this is equivalent to comparing true
    // Euclidean distances against the caller's radius.
    if (dist < r) {
      row[n_y * max_num_neighbors + count] = n_y;
      col[n_y * max_num_neighbors + count] = n_x;
      count++;
    }
  }
}
46 | 42 |
|
47 | | -torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, |
| 43 | +torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, |
48 | 44 | torch::optional<torch::Tensor> ptr_x, |
49 | | - torch::optional<torch::Tensor> ptr_y, double r, |
50 | | - int64_t max_num_neighbors) { |
| 45 | + torch::optional<torch::Tensor> ptr_y, const double r, |
| 46 | + const int64_t max_num_neighbors) { |
51 | 47 | CHECK_CUDA(x); |
| 48 | + CHECK_CONTIGUOUS(x); |
52 | 49 | CHECK_INPUT(x.dim() == 2); |
53 | 50 | CHECK_CUDA(y); |
| 51 | + CHECK_CONTIGUOUS(y); |
54 | 52 | CHECK_INPUT(y.dim() == 2); |
| 53 | + CHECK_INPUT(x.size(1) == y.size(1)); |
| 54 | + |
55 | 55 | cudaSetDevice(x.get_device()); |
56 | 56 |
|
57 | 57 | if (ptr_x.has_value()) { |
58 | 58 | CHECK_CUDA(ptr_x.value()); |
59 | 59 | CHECK_INPUT(ptr_x.value().dim() == 1); |
60 | | - } else { |
| 60 | + } else |
61 | 61 | ptr_x = torch::arange(0, x.size(0) + 1, x.size(0), |
62 | 62 | x.options().dtype(torch::kLong)); |
63 | | - } |
| 63 | + |
64 | 64 | if (ptr_y.has_value()) { |
65 | 65 | CHECK_CUDA(ptr_y.value()); |
66 | 66 | CHECK_INPUT(ptr_y.value().dim() == 1); |
67 | | - } else { |
| 67 | + } else |
68 | 68 | ptr_y = torch::arange(0, y.size(0) + 1, y.size(0), |
69 | 69 | y.options().dtype(torch::kLong)); |
70 | | - } |
| 70 | + |
71 | 71 | CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); |
72 | 72 |
|
| 73 | + cudaSetDevice(x.get_device()); |
| 74 | + |
73 | 75 | auto row = |
74 | 76 | torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options()); |
75 | 77 | auto col = |
76 | 78 | torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options()); |
77 | 79 |
|
| 80 | + dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS); |
| 81 | + |
78 | 82 | auto stream = at::cuda::getCurrentCUDAStream(); |
79 | 83 | AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] { |
80 | | - radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>( |
| 84 | + radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>( |
81 | 85 | x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), |
82 | 86 | ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(), |
83 | | - row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors, |
84 | | - x.size(1)); |
| 87 | + row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0), |
| 88 | + y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors); |
85 | 89 | }); |
86 | 90 |
|
87 | 91 | auto mask = row != -1; |
|
0 commit comments