Skip to content

Commit 0adaf7f

Browse files
committed
improve knn performance
1 parent 442e8d9 commit 0adaf7f

File tree

3 files changed

+66
-56
lines changed

3 files changed

+66
-56
lines changed

csrc/cpu/utils.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,3 +4,4 @@
44

55
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
66
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
7+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

csrc/cuda/knn_cuda.cu

Lines changed: 64 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
#include "utils.cuh"
66

7-
#define THREADS 1024
7+
#define THREADS 256
88

99
template <typename scalar_t> struct Cosine {
1010
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
@@ -27,95 +27,105 @@ template <typename scalar_t> struct Cosine {
2727
}
2828
};
2929

30-
template <typename scalar_t>
31-
__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
32-
const int64_t *ptr_x, const int64_t *ptr_y,
33-
scalar_t *dist, int64_t *row, int64_t *col,
34-
int64_t K, int64_t dim, bool cosine) {
35-
36-
const int64_t batch_idx = blockIdx.x;
37-
38-
const int64_t x_start_idx = ptr_x[batch_idx];
39-
const int64_t x_end_idx = ptr_x[batch_idx + 1];
40-
41-
const int64_t y_start_idx = ptr_y[batch_idx];
42-
const int64_t y_end_idx = ptr_y[batch_idx + 1];
43-
44-
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
45-
n_y += THREADS) {
46-
47-
for (int64_t k = 0; k < K; k++) {
48-
row[n_y * K + k] = n_y;
49-
}
30+
// Maps a flat element index to the segment (example) it falls into.
//
// Scans segments 0 .. num_examples - 1 in order and returns the first segment
// `s` whose upper boundary `ptr[s + 1]` is strictly greater than `idx`.
// If no boundary exceeds `idx`, the last segment index (`num_examples - 1`)
// is returned; note that this is -1 when `num_examples` is 0.
//
// idx           element index to locate
// ptr           boundary array; entries ptr[1] .. ptr[num_examples] are read
//               (presumably CSR-style cumulative offsets -- confirm at caller)
// num_examples  number of segments described by `ptr`
__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
                                   const int64_t num_examples) {
  int64_t segment = 0;

  // Advance while the current segment ends at or before `idx`. The bound
  // check comes first so `ptr` is never read past entry `num_examples`.
  while (segment < num_examples && ptr[segment + 1] <= idx)
    ++segment;

  // Running off the end means no boundary exceeded `idx`: clamp to the last
  // segment, exactly as a plain linear search with a fallback would.
  return segment < num_examples ? segment : num_examples - 1;
}
5038

51-
for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
52-
53-
scalar_t tmp_dist = 0;
54-
if (cosine) {
55-
tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
56-
(Cosine<scalar_t>::norm(x, n_x, dim) *
57-
Cosine<scalar_t>::norm(y, n_y, dim));
58-
tmp_dist = 1. - tmp_dist;
59-
} else {
60-
for (int64_t d = 0; d < dim; d++) {
61-
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
62-
(x[n_x * dim + d] - y[n_y * dim + d]);
63-
}
39+
template <typename scalar_t>
40+
__global__ void
41+
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
42+
const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
43+
scalar_t *__restrict__ dist, int64_t *__restrict__ row,
44+
int64_t *__restrict__ col, const int64_t k, const int64_t n,
45+
const int64_t m, const int64_t dim, const int64_t num_examples,
46+
const bool cosine) {
47+
48+
const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
49+
if (n_y >= m)
50+
return;
51+
52+
for (int64_t e = 0; e < k; e++)
53+
row[n_y * k + e] = n_y;
54+
55+
const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
56+
57+
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
58+
scalar_t tmp_dist = 0;
59+
60+
if (cosine) {
61+
tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
62+
(Cosine<scalar_t>::norm(x, n_x, dim) *
63+
Cosine<scalar_t>::norm(y, n_y, dim));
64+
tmp_dist = 1. - tmp_dist;
65+
} else {
66+
for (int64_t d = 0; d < dim; d++) {
67+
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
68+
(x[n_x * dim + d] - y[n_y * dim + d]);
6469
}
70+
}
6571

66-
for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
67-
if (dist[n_y * K + k_idx_1] > tmp_dist) {
68-
for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
69-
dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
70-
col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
71-
}
72-
dist[n_y * K + k_idx_1] = tmp_dist;
73-
col[n_y * K + k_idx_1] = n_x;
74-
break;
72+
for (int64_t e1 = 0; e1 < k; e1++) {
73+
if (dist[n_y * k + e1] > tmp_dist) {
74+
for (int64_t e2 = k - 1; e2 > e1; e2--) {
75+
dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
76+
col[n_y * k + e2] = col[n_y * k + e2 - 1];
7577
}
78+
dist[n_y * k + e1] = tmp_dist;
79+
col[n_y * k + e1] = n_x;
80+
break;
7681
}
7782
}
7883
}
7984
}
8085

81-
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
86+
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
8287
torch::optional<torch::Tensor> ptr_x,
83-
torch::optional<torch::Tensor> ptr_y, int64_t k,
84-
bool cosine) {
88+
torch::optional<torch::Tensor> ptr_y, const int64_t k,
89+
const bool cosine) {
8590

8691
CHECK_CUDA(x);
8792
CHECK_INPUT(x.dim() == 2);
8893
CHECK_CUDA(y);
8994
CHECK_INPUT(y.dim() == 2);
90-
cudaSetDevice(x.get_device());
95+
CHECK_INPUT(x.size(1) == y.size(1));
9196

9297
if (ptr_x.has_value()) {
9398
CHECK_CUDA(ptr_x.value());
9499
CHECK_INPUT(ptr_x.value().dim() == 1);
95-
} else {
100+
} else
96101
ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
97102
x.options().dtype(torch::kLong));
98-
}
103+
99104
if (ptr_y.has_value()) {
100105
CHECK_CUDA(ptr_y.value());
101106
CHECK_INPUT(ptr_y.value().dim() == 1);
102-
} else {
107+
} else
103108
ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
104109
y.options().dtype(torch::kLong));
105-
}
110+
106111
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
107112

108-
auto dist = torch::full(y.size(0) * k, 1e38, y.options());
113+
cudaSetDevice(x.get_device());
114+
115+
auto dist = torch::full(y.size(0) * k, 1e10, y.options());
109116
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
110117
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
111118

119+
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
120+
112121
auto stream = at::cuda::getCurrentCUDAStream();
113122
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
114-
knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
123+
knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
115124
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
116125
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
117126
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
118-
col.data_ptr<int64_t>(), k, x.size(1), cosine);
127+
col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
128+
ptr_x.value().numel() - 1, cosine);
119129
});
120130

121131
auto mask = col != -1;

test/test_knn.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,7 @@ def test_knn_graph(dtype, device):
7171
def test_knn_graph_large(dtype, device):
7272
x = torch.randn(1000, 3, dtype=dtype, device=device)
7373

74-
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
75-
num_workers=6)
74+
edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)
7675

7776
tree = scipy.spatial.cKDTree(x.cpu().numpy())
7877
_, col = tree.query(x.cpu(), k=5)

0 commit comments

Comments
 (0)