Skip to content

Commit 55f68ad

Browse files
committed
multhread support for CPU; correctness for large samples
1 parent 962fc02 commit 55f68ad

File tree

8 files changed

+231
-75
lines changed

8 files changed

+231
-75
lines changed

csrc/cpu/radius_cpu.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
8-
double radius, int64_t max_num){
8+
double radius, int64_t max_num, int64_t n_threads){
99

1010
CHECK_CPU(query);
1111
CHECK_CPU(support);
@@ -26,7 +26,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
2626

2727
int dim = torch::size(query, 1);
2828

29-
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num);
29+
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads);
3030

3131
});
3232

@@ -40,7 +40,7 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
4040
}
4141

4242

43-
void get_size_batch(const vector<long>& batch, vector<long>& res){
43+
void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
4444

4545
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
4646
long ind = batch[0];

csrc/cpu/radius_cpu.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#pragma once
22

33
#include <torch/extension.h>
4-
#include "utils/neighbors.h"
4+
//#include "utils/neighbors.h"
55
#include "utils/neighbors.cpp"
66
#include <iostream>
77
#include "compat.h"
88

99
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
10-
double radius, int64_t max_num);
10+
double radius, int64_t max_num, int64_t n_threads);
1111

1212
torch::Tensor batch_radius_cpu(torch::Tensor query,
1313
torch::Tensor support,

csrc/cpu/utils/neighbors.cpp

Lines changed: 175 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,94 @@
1+
#include "cloud.h"
2+
#include "nanoflann.hpp"
3+
#include <set>
4+
#include <cstdint>
5+
#include <thread>
6+
7+
typedef struct thread_struct {
8+
void* kd_tree;
9+
void* matches;
10+
void* queries;
11+
size_t* max_count;
12+
std::mutex* ct_m;
13+
std::mutex* tree_m;
14+
size_t start;
15+
size_t end;
16+
double search_radius;
17+
bool small;
18+
} thread_args;
119

2-
// 3D Version https://github.com/HuguesTHOMAS/KPConv
20+
template<typename scalar_t>
21+
void thread_routine(thread_args* targs) {
22+
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
23+
typedef std::vector< std::vector<std::pair<size_t, scalar_t> > > kd_pair;
24+
my_kd_tree_t* index = (my_kd_tree_t*) targs->kd_tree;
25+
kd_pair* matches = (kd_pair*)targs->matches;
26+
PointCloud<scalar_t>* pcd_query = (PointCloud<scalar_t>*)targs->queries;
27+
size_t* max_count = targs->max_count;
28+
std::mutex* ct_m = targs->ct_m;
29+
std::mutex* tree_m = targs->tree_m;
30+
double eps;
31+
if (targs->small) {
32+
eps = 0.000001;
33+
}
34+
else {
35+
eps = 0;
36+
}
37+
double search_radius = (double) targs->search_radius;
38+
size_t start = targs->start;
39+
size_t end = targs->end;
40+
41+
for (size_t i = start; i < end; i++) {
42+
43+
std::vector<scalar_t> p0 = *(((*pcd_query).pts)[i]);
44+
45+
scalar_t* query_pt = new scalar_t[p0.size()];
46+
std::copy(p0.begin(), p0.end(), query_pt);
47+
(*matches)[i].reserve(*max_count);
48+
std::vector<std::pair<size_t, scalar_t> > ret_matches;
49+
50+
tree_m->lock();
51+
52+
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, nanoflann::SearchParams());
53+
54+
tree_m->unlock();
55+
56+
(*matches)[i] = ret_matches;
57+
58+
ct_m->lock();
59+
if(*max_count < nMatches) {
60+
*max_count = nMatches;
61+
}
62+
ct_m->unlock();
63+
64+
}
365

4-
#include "neighbors.h"
66+
}
567

668
template<typename scalar_t>
7-
size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
8-
vector<size_t>*& neighbors_indices, double radius, int dim, int64_t max_num){
69+
size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
70+
std::vector<size_t>*& neighbors_indices, double radius, int dim, int64_t max_num, int64_t n_threads){
971

1072
const scalar_t search_radius = static_cast<scalar_t>(radius*radius);
1173

1274
// Counting vector
13-
size_t max_count = 1;
75+
size_t* max_count = new size_t();
76+
*max_count = 1;
1477

78+
size_t ssize = supports.size();
1579
// CLoud variable
1680
PointCloud<scalar_t> pcd;
1781
pcd.set(supports, dim);
1882
//Cloud query
19-
PointCloud<scalar_t> pcd_query;
20-
pcd_query.set(queries, dim);
83+
PointCloud<scalar_t>* pcd_query = new PointCloud<scalar_t>();
84+
(*pcd_query).set(queries, dim);
2185

2286
// Tree parameters
2387
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(15 /* max leaf */);
2488

2589
// KDTree type definition
2690
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
91+
typedef std::vector< std::vector<std::pair<size_t, scalar_t> > > kd_pair;
2792

2893
// Pointer to trees
2994
my_kd_tree_t* index;
@@ -35,47 +100,114 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
35100
// Search params
36101
nanoflann::SearchParams search_params;
37102
// search_params.sorted = true;
38-
std::vector< std::vector<std::pair<size_t, scalar_t> > > list_matches(pcd_query.pts.size());
103+
kd_pair* list_matches = new kd_pair((*pcd_query).pts.size());
104+
105+
// single threaded routine
106+
if (n_threads == 1){
107+
size_t i0 = 0;
108+
double eps;
109+
if (ssize < 10) {
110+
eps = 0.000001;
111+
}
112+
else {
113+
eps = 0;
114+
}
39115

40-
double eps = 0.000001;
116+
for (auto& p : (*pcd_query).pts){
117+
auto p0 = *p;
118+
// Find neighbors
119+
scalar_t* query_pt = new scalar_t[dim];
120+
std::copy(p0.begin(), p0.end(), query_pt);
41121

42-
// indices
43-
size_t i0 = 0;
122+
(*list_matches)[i0].reserve(*max_count);
123+
std::vector<std::pair<size_t, scalar_t> > ret_matches;
44124

45-
for (auto& p : pcd_query.pts){
46-
auto p0 = *p;
47-
// Find neighbors
48-
scalar_t* query_pt = new scalar_t[dim];
49-
std::copy(p0.begin(), p0.end(), query_pt);
125+
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
126+
127+
(*list_matches)[i0] = ret_matches;
128+
if(*max_count < nMatches) *max_count = nMatches;
129+
i0++;
50130

51-
list_matches[i0].reserve(max_count);
52-
std::vector<std::pair<size_t, scalar_t> > ret_matches;
131+
}
132+
}
133+
else {// Multi-threaded routine
134+
std::mutex* mtx = new std::mutex();
135+
std::mutex* mtx_tree = new std::mutex();
53136

54-
const size_t nMatches = index->radiusSearch(query_pt, (scalar_t)(search_radius+eps), ret_matches, search_params);
55-
56-
list_matches[i0] = ret_matches;
57-
if(max_count < nMatches) max_count = nMatches;
58-
i0++;
137+
size_t n_queries = (*pcd_query).pts.size();
138+
size_t actual_threads = std::min((long long)n_threads, (long long)n_queries);
139+
140+
std::thread* tid[actual_threads];
59141

142+
size_t start, end;
143+
size_t length;
144+
if (n_queries) {
145+
length = 1;
146+
}
147+
else {
148+
auto res = std::lldiv((long long)n_queries, (long long)n_threads);
149+
length = (size_t)res.quot;
150+
/*
151+
if (res.rem == 0) {
152+
length = res.quot;
153+
}
154+
else {
155+
length =
156+
}
157+
*/
158+
}
159+
for (size_t t = 0; t < actual_threads; t++) {
160+
//sem->wait();
161+
start = t*length;
162+
if (t == actual_threads-1) {
163+
end = n_queries;
164+
}
165+
else {
166+
end = (t+1)*length;
167+
}
168+
thread_args* targs = new thread_args();
169+
targs->kd_tree = index;
170+
targs->matches = list_matches;
171+
targs->max_count = max_count;
172+
targs->ct_m = mtx;
173+
targs->tree_m = mtx_tree;
174+
targs->search_radius = search_radius;
175+
targs->queries = pcd_query;
176+
targs->start = start;
177+
targs->end = end;
178+
if (ssize < 10) {
179+
targs->small = true;
180+
}
181+
else {
182+
targs->small = false;
183+
}
184+
std::thread* temp = new std::thread(thread_routine<scalar_t>, targs);
185+
tid[t] = temp;
186+
}
187+
188+
for (size_t t = 0; t < actual_threads; t++){
189+
tid[t]->join();
190+
}
60191
}
192+
61193
// Reserve the memory
62194
if(max_num > 0) {
63-
max_count = max_num;
195+
*max_count = max_num;
64196
}
65197

66198
size_t size = 0; // total number of edges
67-
for (auto& inds : list_matches){
68-
if(inds.size() <= max_count)
199+
for (auto& inds : *list_matches){
200+
if(inds.size() <= *max_count)
69201
size += inds.size();
70202
else
71-
size += max_count;
203+
size += *max_count;
72204
}
73205

74206
neighbors_indices->resize(size*2);
75207
size_t i1 = 0; // index of the query points
76208
size_t u = 0; // curent index of the neighbors_indices
77-
for (auto& inds : list_matches){
78-
for (size_t j = 0; j < max_count; j++){
209+
for (auto& inds : *list_matches){
210+
for (size_t j = 0; j < *max_count; j++){
79211
if(j < inds.size()){
80212
(*neighbors_indices)[u] = inds[j].first;
81213
(*neighbors_indices)[u + 1] = i1;
@@ -85,19 +217,19 @@ size_t nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports
85217
i1++;
86218
}
87219

88-
return max_count;
220+
return *max_count;
89221

90222

91223

92224

93225
}
94226

95227
template<typename scalar_t>
96-
size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
97-
vector<scalar_t>& supports,
98-
vector<long>& q_batches,
99-
vector<long>& s_batches,
100-
vector<size_t>*& neighbors_indices,
228+
size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
229+
std::vector<scalar_t>& supports,
230+
std::vector<long>& q_batches,
231+
std::vector<long>& s_batches,
232+
std::vector<size_t>*& neighbors_indices,
101233
double radius, int dim, int64_t max_num){
102234

103235

@@ -117,15 +249,21 @@ size_t batch_nanoflann_neighbors (vector<scalar_t>& queries,
117249
size_t sum_qb = 0;
118250
size_t sum_sb = 0;
119251

120-
double eps = 0.000001;
252+
double eps;
253+
if (supports.size() < 10){
254+
eps = 0.000001;
255+
}
256+
else {
257+
eps = 0;
258+
}
121259
// Nanoflann related variables
122260
// ***************************
123261

124262
// CLoud variable
125263
PointCloud<scalar_t> current_cloud;
126264
PointCloud<scalar_t> query_pcd;
127265
query_pcd.set(queries, dim);
128-
vector<vector<pair<size_t, scalar_t> > > all_inds_dists(query_pcd.pts.size());
266+
std::vector<std::vector<std::pair<size_t, scalar_t> > > all_inds_dists(query_pcd.pts.size());
129267

130268
// Tree parameters
131269
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);

csrc/cpu/utils/neighbors.h

Lines changed: 0 additions & 22 deletions
This file was deleted.

csrc/radius.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <Python.h>
22
#include <torch/script.h>
3+
#include <iostream>
34

45
#ifdef WITH_CUDA
56
#include "cuda/radius_cuda.h"
@@ -11,7 +12,7 @@ PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
1112
#endif
1213

1314
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
14-
torch::optional<torch::Tensor> ptr_y, double r, int64_t max_num_neighbors) {
15+
torch::optional<torch::Tensor> ptr_y, double r, int64_t max_num_neighbors, int64_t n_threads) {
1516
if (x.device().is_cuda()) {
1617
#ifdef WITH_CUDA
1718
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
@@ -37,7 +38,7 @@ torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Te
3738
#endif
3839
} else {
3940
if (!(ptr_x.has_value()) && !(ptr_y.has_value())) {
40-
return radius_cpu(x,y,r,max_num_neighbors);
41+
return radius_cpu(x,y,r,max_num_neighbors, n_threads);
4142
}
4243
if (!(ptr_x.has_value())) {
4344
auto batch_x = torch::zeros({torch::size(x,0)}).to(torch::kLong);

test/radius_test_large.pkl

960 KB
Binary file not shown.

0 commit comments

Comments
 (0)