
Commit 0d735d7

improve radius performance
1 parent 0adaf7f commit 0d735d7

File tree

5 files changed: +88 -71 lines changed

csrc/cpu/utils.h

Lines changed: 2 additions & 1 deletion

@@ -4,4 +4,5 @@
 
 #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
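Note: the macro change here is formatting only, but `CHECK_CONTIGUOUS` matters because the kernels in this extension index raw `data_ptr()` memory as `x[n_x * dim + d]`, which is only valid for row-major contiguous storage. A minimal sketch of what the check guards against; this caller code is hypothetical and not part of the commit:

```cpp
// Hypothetical illustration: a transposed tensor is a view with swapped
// strides, so flat pointer arithmetic over data_ptr() would read the wrong
// elements. CHECK_CONTIGUOUS rejects such inputs before a kernel ever runs.
#include <torch/torch.h>

int main() {
  auto x = torch::randn({100, 3}); // row-major: passes the contiguity check
  auto xt = x.t();                 // non-contiguous view: would fail it
  TORCH_CHECK(x.is_contiguous());
  TORCH_CHECK(!xt.is_contiguous());
  auto ok = xt.contiguous();       // explicit copy restores row-major layout
  TORCH_CHECK(ok.is_contiguous());
}
```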

csrc/cuda/knn_cuda.cu

Lines changed: 26 additions & 25 deletions

@@ -27,33 +27,28 @@ template <typename scalar_t> struct Cosine {
   }
 };
 
-__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
-                                   const int64_t num_examples) {
-  for (int64_t i = 0; i < num_examples; i++) {
-    if (ptr[i + 1] > idx)
-      return i;
-  }
-  return num_examples - 1;
-}
-
 template <typename scalar_t>
 __global__ void
 knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
            const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
-           scalar_t *__restrict__ dist, int64_t *__restrict__ row,
-           int64_t *__restrict__ col, const int64_t k, const int64_t n,
-           const int64_t m, const int64_t dim, const int64_t num_examples,
-           const bool cosine) {
+           int64_t *__restrict__ row, int64_t *__restrict__ col,
+           const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
+           const int64_t num_examples, const bool cosine) {
 
   const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
   if (n_y >= m)
     return;
 
-  for (int64_t e = 0; e < k; e++)
-    row[n_y * k + e] = n_y;
-
   const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
 
+  scalar_t best_dist[100];
+  int64_t best_idx[100];
+
+  for (int e = 0; e < k; e++) {
+    best_dist[e] = 1e10;
+    best_idx[e] = -1;
+  }
+
   for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
     scalar_t tmp_dist = 0;
 
@@ -70,17 +65,22 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
     }
 
     for (int64_t e1 = 0; e1 < k; e1++) {
-      if (dist[n_y * k + e1] > tmp_dist) {
+      if (best_dist[e1] > tmp_dist) {
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
-          dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
-          col[n_y * k + e2] = col[n_y * k + e2 - 1];
+          best_dist[e2] = best_dist[e2 - 1];
+          best_idx[e2] = best_idx[e2 - 1];
        }
-        dist[n_y * k + e1] = tmp_dist;
-        col[n_y * k + e1] = n_x;
+        best_dist[e1] = tmp_dist;
+        best_idx[e1] = n_x;
         break;
       }
     }
   }
+
+  for (int64_t e = 0; e < k; e++) {
+    row[n_y * k + e] = n_y;
+    col[n_y * k + e] = best_idx[e];
+  }
 }
 
 torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
@@ -89,10 +89,13 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                        const bool cosine) {
 
   CHECK_CUDA(x);
+  CHECK_CONTIGUOUS(x);
   CHECK_INPUT(x.dim() == 2);
   CHECK_CUDA(y);
+  CHECK_CONTIGUOUS(y);
   CHECK_INPUT(y.dim() == 2);
   CHECK_INPUT(x.size(1) == y.size(1));
+  AT_ASSERTM(k <= 100, "`k` needs to be smaller than or equal to 100");
 
   if (ptr_x.has_value()) {
     CHECK_CUDA(ptr_x.value());
@@ -112,7 +115,6 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
 
   cudaSetDevice(x.get_device());
 
-  auto dist = torch::full(y.size(0) * k, 1e10, y.options());
   auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
   auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
 
@@ -123,9 +125,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
     knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
-        col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
-        ptr_x.value().numel() - 1, cosine);
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
+        y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
   });
 
   auto mask = col != -1;
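Note: the key change in this file is that the running top-k state moves from a global `dist` tensor into per-thread local arrays, so the insertion sort no longer round-trips through global memory, and results are written back once at the end. The fixed 100-element buffers are also why the host now asserts `k <= 100`. A standalone sketch of that pattern, with illustrative names and a single flat candidate list rather than the batched kernel above:

```cuda
// Sketch only (launch e.g. with <<<1, 1>>>): insertion sort of candidate
// distances into fixed-size per-thread buffers. best_dist/best_idx live in
// registers or local memory, so no global-memory traffic occurs until the
// final write-back loop.
#include <cstdint>

template <typename scalar_t>
__global__ void topk_sketch(const scalar_t *__restrict__ dist,
                            int64_t *__restrict__ out,
                            const int64_t num_candidates, const int64_t k) {
  scalar_t best_dist[100]; // fixed capacity, hence the host-side k <= 100 check
  int64_t best_idx[100];

  for (int64_t e = 0; e < k; e++) {
    best_dist[e] = 1e10;
    best_idx[e] = -1;
  }

  for (int64_t c = 0; c < num_candidates; c++) {
    const scalar_t d = dist[c];
    for (int64_t e1 = 0; e1 < k; e1++) {
      if (best_dist[e1] > d) { // shift worse entries down, insert at e1
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
          best_dist[e2] = best_dist[e2 - 1];
          best_idx[e2] = best_idx[e2 - 1];
        }
        best_dist[e1] = d;
        best_idx[e1] = c;
        break;
      }
    }
  }

  for (int64_t e = 0; e < k; e++) // single write-back to global memory
    out[e] = best_idx[e];
}
```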

csrc/cuda/radius_cuda.cu

Lines changed: 48 additions & 44 deletions

@@ -4,84 +4,88 @@
 
 #include "utils.cuh"
 
-#define THREADS 1024
+#define THREADS 256
 
 template <typename scalar_t>
-__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
-                              const int64_t *ptr_x, const int64_t *ptr_y,
-                              int64_t *row, int64_t *col, scalar_t radius,
-                              int64_t max_num_neighbors, int64_t dim) {
-
-  const int64_t batch_idx = blockIdx.x;
-
-  const int64_t x_start_idx = ptr_x[batch_idx];
-  const int64_t x_end_idx = ptr_x[batch_idx + 1];
-
-  const int64_t y_start_idx = ptr_y[batch_idx];
-  const int64_t y_end_idx = ptr_y[batch_idx + 1];
-
-  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
-       n_y += THREADS) {
-    int64_t count = 0;
-    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
-      scalar_t dist = 0;
-      for (int64_t d = 0; d < dim; d++) {
-        dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
-                (x[n_x * dim + d] - y[n_y * dim + d]);
-      }
-      dist = sqrt(dist);
-
-      if (dist < radius) {
-        row[n_y * max_num_neighbors + count] = n_y;
-        col[n_y * max_num_neighbors + count] = n_x;
-        count++;
-      }
-
-      if (count >= max_num_neighbors) {
-        break;
-      }
+__global__ void
+radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
+              const int64_t *__restrict__ ptr_x,
+              const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
+              int64_t *__restrict__ col, const scalar_t r, const int64_t n,
+              const int64_t m, const int64_t dim, const int64_t num_examples,
+              const int64_t max_num_neighbors) {
+
+  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_y >= m)
+    return;
+
+  int64_t count = 0;
+  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
+
+  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
+    scalar_t dist = 0;
+    for (int64_t d = 0; d < dim; d++) {
+      dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
+              (x[n_x * dim + d] - y[n_y * dim + d]);
     }
+
+    if (dist < r) {
+      row[n_y * max_num_neighbors + count] = n_y;
+      col[n_y * max_num_neighbors + count] = n_x;
+      count++;
+    }
+
+    if (count >= max_num_neighbors)
+      break;
   }
 }
 
-torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
+torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
                           torch::optional<torch::Tensor> ptr_x,
-                          torch::optional<torch::Tensor> ptr_y, double r,
-                          int64_t max_num_neighbors) {
+                          torch::optional<torch::Tensor> ptr_y, const double r,
+                          const int64_t max_num_neighbors) {
   CHECK_CUDA(x);
+  CHECK_CONTIGUOUS(x);
   CHECK_INPUT(x.dim() == 2);
   CHECK_CUDA(y);
+  CHECK_CONTIGUOUS(y);
   CHECK_INPUT(y.dim() == 2);
+  CHECK_INPUT(x.size(1) == y.size(1));
+
   cudaSetDevice(x.get_device());
 
   if (ptr_x.has_value()) {
     CHECK_CUDA(ptr_x.value());
     CHECK_INPUT(ptr_x.value().dim() == 1);
-  } else {
+  } else
     ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                           x.options().dtype(torch::kLong));
-  }
+
   if (ptr_y.has_value()) {
     CHECK_CUDA(ptr_y.value());
     CHECK_INPUT(ptr_y.value().dim() == 1);
-  } else {
+  } else
     ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                           y.options().dtype(torch::kLong));
-  }
+
   CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
 
+  cudaSetDevice(x.get_device());
+
   auto row =
       torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
   auto col =
       torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options());
 
+  dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
+
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
-    radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
+    radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
-        x.size(1));
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r * r, x.size(0),
+        y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors);
  });
 
  auto mask = row != -1;
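Note: two changes drive the speedup here. The launch moves from one block per batch element to one thread per query point with a ceil-division grid, and the host now passes `r * r` so the kernel compares squared distances and the per-candidate `sqrt` disappears. A small self-contained check of both ideas, with illustrative values not taken from the commit:

```cpp
// 1) ceil-division launch config: one thread per query point;
// 2) for non-negative values, dist^2 < r^2 is equivalent to dist < r,
//    so the kernel can drop sqrt entirely.
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const int64_t THREADS = 256;
  const int64_t m = 1000;                             // query points
  const int64_t BLOCKS = (m + THREADS - 1) / THREADS; // ceil(m / THREADS)
  assert(BLOCKS == 4);                                // 4 * 256 covers all 1000

  const double r = 0.5;
  for (double dist_sq : {0.0, 0.2, 0.25, 0.3}) {
    assert((dist_sq < r * r) == (std::sqrt(dist_sq) < r));
  }
  return 0;
}
```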

csrc/cuda/utils.cuh

Lines changed: 11 additions & 0 deletions

@@ -5,3 +5,14 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+
+__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
+                                   const int64_t num_examples) {
+  for (int64_t i = 0; i < num_examples; i++) {
+    if (ptr[i + 1] > idx)
+      return i;
+  }
+  return num_examples - 1;
+}
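Note: `get_example_idx` moves into this shared header so both the knn and radius kernels can map a flat point index back to its batch element. `ptr` is a CSR-style offset array: example `i` owns indices `[ptr[i], ptr[i+1])`. A host-side replica of the device function with a tiny worked example:

```cpp
// Host-side copy of get_example_idx, for illustration only: a linear scan
// over the offset array finds which example a flat row index belongs to.
#include <cassert>
#include <cstdint>

int64_t get_example_idx(int64_t idx, const int64_t *ptr,
                        const int64_t num_examples) {
  for (int64_t i = 0; i < num_examples; i++) {
    if (ptr[i + 1] > idx)
      return i;
  }
  return num_examples - 1;
}

int main() {
  // Two examples: rows 0..2 belong to example 0, rows 3..6 to example 1.
  const int64_t ptr[] = {0, 3, 7};
  assert(get_example_idx(0, ptr, 2) == 0);
  assert(get_example_idx(2, ptr, 2) == 0);
  assert(get_example_idx(3, ptr, 2) == 1);
  assert(get_example_idx(6, ptr, 2) == 1);
}
```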

test/test_radius.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def test_radius_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
 
     edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
-                              max_num_neighbors=2000, num_workers=6)
+                              max_num_neighbors=2000)
 
     tree = scipy.spatial.cKDTree(x.cpu().numpy())
     col = tree.query_ball_point(x.cpu(), r=0.5)
