Skip to content

Commit aa9a388

Browse files
committed
knn cpu and multithreading support with testcases; positions of arguments
1 parent 3d682e5 commit aa9a388

File tree

10 files changed

+316
-73
lines changed

10 files changed

+316
-73
lines changed

csrc/cpu/knn_cpu.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#include "radius_cpu.h"
2+
#include <algorithm>
3+
#include "utils.h"
4+
#include <cstdint>
5+
6+
7+
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
8+
int64_t k, int64_t n_threads){
9+
10+
CHECK_CPU(query);
11+
CHECK_CPU(support);
12+
13+
torch::Tensor out;
14+
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
15+
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
16+
int max_count = 0;
17+
18+
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
19+
20+
auto data_q = query.data_ptr<scalar_t>();
21+
auto data_s = support.data_ptr<scalar_t>();
22+
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
23+
data_q + query.size(0)*query.size(1));
24+
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
25+
data_s + support.size(0)*support.size(1));
26+
27+
int dim = torch::size(query, 1);
28+
29+
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, 0, dim, 0, n_threads, k, 0);
30+
31+
});
32+
33+
size_t* neighbors_indices_ptr = neighbors_indices->data();
34+
35+
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
36+
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
37+
out = out.t();
38+
39+
auto result = torch::zeros_like(out);
40+
41+
auto index = torch::tensor({1,0});
42+
43+
result.index_copy_(0, index, out);
44+
45+
return result;
46+
}
47+
48+
49+
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
50+
51+
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
52+
long ind = batch[0];
53+
long incr = 1;
54+
for(unsigned long i=1; i < batch.size(); i++){
55+
56+
if(batch[i] == ind)
57+
incr++;
58+
else{
59+
res[ind-batch[0]] = incr;
60+
incr =1;
61+
ind = batch[i];
62+
}
63+
}
64+
res[ind-batch[0]] = incr;
65+
}
66+
67+
torch::Tensor batch_knn_cpu(torch::Tensor support,
68+
torch::Tensor query,
69+
torch::Tensor support_batch,
70+
torch::Tensor query_batch,
71+
int64_t k) {
72+
73+
CHECK_CPU(query);
74+
CHECK_CPU(support);
75+
CHECK_CPU(query_batch);
76+
CHECK_CPU(support_batch);
77+
78+
torch::Tensor out;
79+
auto data_qb = query_batch.data_ptr<int64_t>();
80+
auto data_sb = support_batch.data_ptr<int64_t>();
81+
82+
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
83+
std::vector<long> size_query_batch_stl;
84+
CHECK_INPUT(std::is_sorted(query_batch_stl.begin(),query_batch_stl.end()));
85+
get_size_batch(query_batch_stl, size_query_batch_stl);
86+
87+
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
88+
std::vector<long> size_support_batch_stl;
89+
CHECK_INPUT(std::is_sorted(support_batch_stl.begin(),support_batch_stl.end()));
90+
get_size_batch(support_batch_stl, size_support_batch_stl);
91+
92+
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
93+
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
94+
int max_count = 0;
95+
96+
AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
97+
auto data_q = query.data_ptr<scalar_t>();
98+
auto data_s = support.data_ptr<scalar_t>();
99+
std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
100+
data_q + query.size(0)*query.size(1));
101+
std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
102+
data_s + support.size(0)*support.size(1));
103+
104+
int dim = torch::size(query, 1);
105+
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
106+
supports_stl,
107+
size_query_batch_stl,
108+
size_support_batch_stl,
109+
neighbors_indices,
110+
0,
111+
dim,
112+
0,
113+
k, 0);
114+
});
115+
116+
size_t* neighbors_indices_ptr = neighbors_indices->data();
117+
118+
119+
const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
120+
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
121+
out = out.t();
122+
123+
auto result = torch::zeros_like(out);
124+
125+
auto index = torch::tensor({1,0});
126+
127+
result.index_copy_(0, index, out);
128+
129+
return result;
130+
}

csrc/cpu/knn_cpu.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
#include "utils/neighbors.cpp"
5+
#include <iostream>
6+
#include "compat.h"
7+
8+
torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
9+
int64_t k, int64_t n_threads);
10+
11+
torch::Tensor batch_knn_cpu(torch::Tensor support,
12+
torch::Tensor query,
13+
torch::Tensor support_batch,
14+
torch::Tensor query_batch,
15+
int64_t k);

csrc/cpu/radius_cpu.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cstdint>
55

66

7-
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
7+
torch::Tensor radius_cpu(torch::Tensor support, torch::Tensor query,
88
double radius, int64_t max_num, int64_t n_threads){
99

1010
CHECK_CPU(query);
@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
2626

2727
int dim = torch::size(query, 1);
2828

29-
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads);
29+
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads, 0, 1);
3030

3131
});
3232

@@ -36,7 +36,13 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
3636
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
3737
out = out.t();
3838

39-
return out.clone();
39+
auto result = torch::zeros_like(out);
40+
41+
auto index = torch::tensor({1,0});
42+
43+
result.index_copy_(0, index, out);
44+
45+
return result;
4046
}
4147

4248

@@ -58,10 +64,10 @@ void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
5864
res[ind-batch[0]] = incr;
5965
}
6066

61-
torch::Tensor batch_radius_cpu(torch::Tensor query,
62-
torch::Tensor support,
63-
torch::Tensor query_batch,
67+
torch::Tensor batch_radius_cpu(torch::Tensor support,
68+
torch::Tensor query,
6469
torch::Tensor support_batch,
70+
torch::Tensor query_batch,
6571
double radius, int64_t max_num) {
6672

6773
CHECK_CPU(query);
@@ -103,8 +109,8 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
103109
neighbors_indices,
104110
radius,
105111
dim,
106-
max_num
107-
);
112+
max_num,
113+
0, 1);
108114
});
109115

110116
size_t* neighbors_indices_ptr = neighbors_indices->data();
@@ -114,5 +120,11 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
114120
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
115121
out = out.t();
116122

117-
return out.clone();
123+
auto result = torch::zeros_like(out);
124+
125+
auto index = torch::tensor({1,0});
126+
127+
result.index_copy_(0, index, out);
128+
129+
return result;
118130
}

csrc/cpu/radius_cpu.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

33
#include <torch/extension.h>
4-
//#include "utils/neighbors.h"
54
#include "utils/neighbors.cpp"
65
#include <iostream>
76
#include "compat.h"

csrc/cpu/utils/neighbors.cpp

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <set>
44
#include <cstdint>
55
#include <thread>
6+
#include <iostream>
67

78
typedef struct thread_struct {
89
void* kd_tree;
@@ -15,6 +16,8 @@ typedef struct thread_struct {
1516
size_t end;
1617
double search_radius;
1718
bool small;
19+
bool option;
20+
size_t k;
1821
} thread_args;
1922

2023
template<typename scalar_t>
@@ -37,7 +40,7 @@ void thread_routine(thread_args* targs) {
3740
double search_radius = (double) targs->search_radius;
3841
size_t start = targs->start;
3942
size_t end = targs->end;
40-
43+
auto k = targs->k;
4144
for (size_t i = start; i < end; i++) {
4245

4346
std::vector<scalar_t> p0 = *(((*pcd_query).pts)[i]);
@@ -46,11 +49,23 @@ void thread_routine(thread_args* targs) {
4649
std::copy(p0.begin(), p0.end(), query_pt);
4750
(*matches)[i].reserve(*max_count);
4851
std::vector<std::pair<size_t, scalar_t> > ret_matches;
52+
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
53+
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
4954

5055
tree_m->lock();
5156

52-
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
53-
57+
size_t nMatches;
58+
if (targs->option){
59+
nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
60+
}
61+
else {
62+
nMatches = index->knnSearch(query_pt, k, &(*knn_ret_matches)[0],&(* knn_dist_matches)[0]);
63+
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
64+
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
65+
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
66+
}
67+
ret_matches = *temp;
68+
}
5469
tree_m->unlock();
5570

5671
(*matches)[i] = ret_matches;
@@ -67,7 +82,8 @@ void thread_routine(thread_args* targs) {
6782

6883
template<typename scalar_t>
6984
size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
70-
std::vector<size_t>*& neighbors_indices, double radius, int dim, int64_t max_num, int64_t n_threads){
85+
std::vector<size_t>*& neighbors_indices, double radius, int dim,
86+
int64_t max_num, int64_t n_threads, int64_t k, int option){
7187

7288
const scalar_t search_radius = static_cast<scalar_t>(radius*radius);
7389

@@ -120,9 +136,21 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
120136

121137
(*list_matches)[i0].reserve(*max_count);
122138
std::vector<std::pair<size_t, scalar_t> > ret_matches;
139+
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
140+
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
123141

124-
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
125-
142+
size_t nMatches;
143+
if (!!(option)){
144+
nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
145+
}
146+
else {
147+
nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0],&(* knn_dist_matches)[0]);
148+
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
149+
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
150+
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
151+
}
152+
ret_matches = *temp;
153+
}
126154
(*list_matches)[i0] = ret_matches;
127155
if(*max_count < nMatches) *max_count = nMatches;
128156
i0++;
@@ -171,6 +199,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
171199
else {
172200
targs->small = false;
173201
}
202+
targs->option = !!(option);
203+
targs->k = k;
174204
std::thread* temp = new std::thread(thread_routine<scalar_t>, targs);
175205
tid[t] = temp;
176206
}
@@ -220,7 +250,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
220250
std::vector<long>& q_batches,
221251
std::vector<long>& s_batches,
222252
std::vector<size_t>*& neighbors_indices,
223-
double radius, int dim, int64_t max_num){
253+
double radius, int dim, int64_t max_num, int64_t k, int option){
224254

225255

226256
// indices
@@ -292,14 +322,22 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
292322
// Initial guess of neighbors size
293323
all_inds_dists[i0].reserve(max_count);
294324
// Find neighbors
295-
size_t nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
296-
// Update max count
297-
298-
std::vector<std::pair<size_t, float> > indices_dists;
299-
nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
300-
301-
index->findNeighbors(resultSet, query_pt, search_params);
302325

326+
size_t nMatches;
327+
if (!!option) {
328+
nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
329+
// Update max count
330+
}
331+
else {
332+
std::vector<size_t>* knn_ret_matches = new std::vector<size_t>(k);
333+
std::vector<scalar_t>* knn_dist_matches = new std::vector<scalar_t>(k);
334+
nMatches = index->knnSearch(query_pt, (size_t)k, &(*knn_ret_matches)[0],&(*knn_dist_matches)[0]);
335+
auto temp = new std::vector<std::pair<size_t, scalar_t> >((*knn_dist_matches).size());
336+
for (size_t j = 0; j < (*knn_ret_matches).size(); j++){
337+
(*temp)[j] = std::make_pair( (*knn_ret_matches)[j],(*knn_dist_matches)[j] );
338+
}
339+
all_inds_dists[i0] = *temp;
340+
}
303341
if (nMatches > max_count)
304342
max_count = nMatches;
305343
// Increment query idx

0 commit comments

Comments
 (0)