
#include "utils.cuh"

-#define THREADS 1024
+#define THREADS 256

template <typename scalar_t> struct Cosine {
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
@@ -27,95 +27,105 @@ template <typename scalar_t> struct Cosine {
  }
};

-template <typename scalar_t>
-__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
-                           const int64_t *ptr_x, const int64_t *ptr_y,
-                           scalar_t *dist, int64_t *row, int64_t *col,
-                           int64_t K, int64_t dim, bool cosine) {
-
-  const int64_t batch_idx = blockIdx.x;
-
-  const int64_t x_start_idx = ptr_x[batch_idx];
-  const int64_t x_end_idx = ptr_x[batch_idx + 1];
-
-  const int64_t y_start_idx = ptr_y[batch_idx];
-  const int64_t y_end_idx = ptr_y[batch_idx + 1];
-
-  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
-       n_y += THREADS) {
-
-    for (int64_t k = 0; k < K; k++) {
-      row[n_y * K + k] = n_y;
-    }
+__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
+                                   const int64_t num_examples) {
+  for (int64_t i = 0; i < num_examples; i++) {
+    if (ptr[i + 1] > idx)
+      return i;
+  }
+  return num_examples - 1;
+}

-    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
-
-      scalar_t tmp_dist = 0;
-      if (cosine) {
-        tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
-                   (Cosine<scalar_t>::norm(x, n_x, dim) *
-                    Cosine<scalar_t>::norm(y, n_y, dim));
-        tmp_dist = 1. - tmp_dist;
-      } else {
-        for (int64_t d = 0; d < dim; d++) {
-          tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
-                      (x[n_x * dim + d] - y[n_y * dim + d]);
-        }
+template <typename scalar_t>
+__global__ void
+knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
+           const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
+           scalar_t *__restrict__ dist, int64_t *__restrict__ row,
+           int64_t *__restrict__ col, const int64_t k, const int64_t n,
+           const int64_t m, const int64_t dim, const int64_t num_examples,
+           const bool cosine) {
+
+  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_y >= m)
+    return;
+
+  for (int64_t e = 0; e < k; e++)
+    row[n_y * k + e] = n_y;
+
+  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
+
+  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
+    scalar_t tmp_dist = 0;
+
+    if (cosine) {
+      tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
+                 (Cosine<scalar_t>::norm(x, n_x, dim) *
+                  Cosine<scalar_t>::norm(y, n_y, dim));
+      tmp_dist = 1. - tmp_dist;
+    } else {
+      for (int64_t d = 0; d < dim; d++) {
+        tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
+                    (x[n_x * dim + d] - y[n_y * dim + d]);
      }
+    }

-      for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
-        if (dist[n_y * K + k_idx_1] > tmp_dist) {
-          for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
-            dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
-            col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
-          }
-          dist[n_y * K + k_idx_1] = tmp_dist;
-          col[n_y * K + k_idx_1] = n_x;
-          break;
+    for (int64_t e1 = 0; e1 < k; e1++) {
+      if (dist[n_y * k + e1] > tmp_dist) {
+        for (int64_t e2 = k - 1; e2 > e1; e2--) {
+          dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
+          col[n_y * k + e2] = col[n_y * k + e2 - 1];
        }
+        dist[n_y * k + e1] = tmp_dist;
+        col[n_y * k + e1] = n_x;
+        break;
      }
    }
  }
}

-torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
+torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
-                       torch::optional<torch::Tensor> ptr_y, int64_t k,
-                       bool cosine) {
+                       torch::optional<torch::Tensor> ptr_y, const int64_t k,
+                       const bool cosine) {

  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_INPUT(y.dim() == 2);
-  cudaSetDevice(x.get_device());
+  CHECK_INPUT(x.size(1) == y.size(1));

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
-  } else {
+  } else
    ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                          x.options().dtype(torch::kLong));
-  }
+
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
-  } else {
+  } else
    ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                          y.options().dtype(torch::kLong));
-  }
+
  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

-  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
+  cudaSetDevice(x.get_device());
+
+  auto dist = torch::full(y.size(0) * k, 1e10, y.options());
  auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
  auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());

+  dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
+
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
-    knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
+    knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
-        col.data_ptr<int64_t>(), k, x.size(1), cosine);
+        col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
+        ptr_x.value().numel() - 1, cosine);
  });

  auto mask = col != -1;
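
For readers unfamiliar with the CSR-style ptr_x/ptr_y vectors used above: ptr[i] and ptr[i + 1] delimit the points that belong to example i, and the new get_example_idx helper recovers the example of a flattened point index by a linear scan over these offsets. The standalone host-side sketch below is not part of the commit; the helper name example_of and the toy sizes are made up for illustration, and it simply mirrors that lookup together with the one-thread-per-query launch geometry ((m + THREADS - 1) / THREADS blocks of 256 threads).

// Host-side sketch of the CSR-style `ptr` convention (illustration only).
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the device helper get_example_idx: linear scan over the offsets,
// returning the example whose [ptr[i], ptr[i + 1]) range contains idx.
int64_t example_of(int64_t idx, const std::vector<int64_t> &ptr) {
  const int64_t num_examples = static_cast<int64_t>(ptr.size()) - 1;
  for (int64_t i = 0; i < num_examples; i++) {
    if (ptr[i + 1] > idx)
      return i;
  }
  return num_examples - 1;
}

int main() {
  // Two examples: query points 0..2 belong to example 0, points 3..6 to example 1.
  const std::vector<int64_t> ptr_y = {0, 3, 7};
  const int64_t m = ptr_y.back(); // total number of query points

  // Same launch geometry as the commit: one thread per query point.
  const int64_t threads = 256;
  const int64_t blocks = (m + threads - 1) / threads;
  std::printf("launch %lld block(s) of %lld threads for %lld points\n",
              (long long)blocks, (long long)threads, (long long)m);

  for (int64_t n_y = 0; n_y < m; n_y++)
    std::printf("query %lld -> example %lld\n", (long long)n_y,
                (long long)example_of(n_y, ptr_y));
  return 0;
}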