Skip to content

Commit 1111319

Browse files
committed
options support; correctness across dimensions; more testing
1 parent 25d62f0 commit 1111319

File tree

8 files changed

+682
-49
lines changed

8 files changed

+682
-49
lines changed

csrc/cpu/radius_cpu.cpp

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,12 @@
11
#include "radius_cpu.h"
2-
2+
#include <algorithm>
33
#include "utils.h"
44

5-
torch::Tensor radius_cpu(torch::Tensor q, torch::Tensor s,
6-
torch::Tensor ptr_x, torch::Tensor ptr_y,
5+
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
76
float radius, int max_num){
87

9-
CHECK_CPU(q);
10-
CHECK_CPU(s);
11-
12-
/*
13-
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
14-
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
15-
*/
16-
17-
auto batch_x = ptr_x.clone().reshape({-1, 1});
18-
auto batch_y = ptr_y.clone().reshape({-1, 1});
19-
20-
batch_x.mul_(2*radius);
21-
batch_y.mul_(2*radius);
22-
23-
auto query = torch::cat({q,batch_x},-1);
24-
auto support = torch::cat({s,batch_y},-1);
8+
CHECK_CPU(query);
9+
CHECK_CPU(support);
2510

2611
torch::Tensor out;
2712
std::vector<long> neighbors_indices;
@@ -58,6 +43,7 @@ torch::Tensor radius_cpu(torch::Tensor q, torch::Tensor s,
5843
return result;
5944
}
6045

46+
6147
void get_size_batch(const vector<long>& batch, vector<long>& res){
6248

6349
res.resize(batch[batch.size()-1]-batch[0]+1, 0);
@@ -74,4 +60,54 @@ void get_size_batch(const vector<long>& batch, vector<long>& res){
7460
}
7561
}
7662
res[ind-batch[0]] = incr;
63+
}
64+
65+
/// Batched radius search on CPU.
///
/// For every query point, finds the support points that belong to the same
/// batch element and lie within `radius`, keeping at most `max_num` neighbours
/// per query when `max_num > 0` (all of them otherwise).
///
/// @param query          2-D tensor of query points, one point per row.
/// @param support        2-D tensor of support points with the same number of
///                       columns as `query`.
/// @param query_batch    1-D int64 tensor holding one batch id per query row.
/// @param support_batch  1-D int64 tensor holding one batch id per support row.
/// @param radius         Search radius.
/// @param max_num        Per-query cap on the number of neighbours (<= 0: no cap).
/// @return int64 tensor with two rows: row 0 holds support indices, row 1 the
///         matching query indices.
///
/// NOTE(review): per-batch sizes come from `get_size_batch`, which appears to
/// expect batch ids sorted in ascending order -- confirm against its full
/// definition.
torch::Tensor batch_radius_cpu(torch::Tensor query,
                               torch::Tensor support,
                               torch::Tensor query_batch,
                               torch::Tensor support_batch,
                               float radius, int max_num) {

    CHECK_CPU(query);
    CHECK_CPU(support);
    CHECK_CPU(query_batch);
    CHECK_CPU(support_batch);

    TORCH_CHECK(query.dim() == 2 && support.dim() == 2,
                "batch_radius_cpu: query and support must be 2-D tensors");
    TORCH_CHECK(query.size(1) == support.size(1),
                "batch_radius_cpu: query and support must have the same point dimension");
    TORCH_CHECK(query_batch.numel() == query.size(0) &&
                support_batch.numel() == support.size(0),
                "batch_radius_cpu: batch vectors must hold one entry per point");

    auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);

    // get_size_batch indexes the last element of its input, so an empty batch
    // vector must never reach it. No points also means no edges.
    if (query.size(0) == 0 || support.size(0) == 0) {
        return torch::empty({2, 0}, options);
    }

    // The raw pointers below are read as dense row-major buffers.
    query = query.contiguous();
    support = support.contiguous();
    query_batch = query_batch.contiguous();
    support_batch = support_batch.contiguous();

    auto data_qb = query_batch.data_ptr<long>();
    auto data_sb = support_batch.data_ptr<long>();

    // Number of points in each batch element, for queries and supports.
    std::vector<long> query_batch_stl(data_qb, data_qb + query_batch.numel());
    std::vector<long> size_query_batch_stl;
    get_size_batch(query_batch_stl, size_query_batch_stl);

    std::vector<long> support_batch_stl(data_sb, data_sb + support_batch.numel());
    std::vector<long> size_support_batch_stl;
    get_size_batch(support_batch_stl, size_support_batch_stl);

    // Flat list of (support index, query index) pairs filled by the search.
    std::vector<long> neighbors_indices;

    AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_search", [&] {
        auto data_q = query.data_ptr<scalar_t>();
        auto data_s = support.data_ptr<scalar_t>();
        std::vector<scalar_t> queries_stl(data_q, data_q + query.numel());
        std::vector<scalar_t> supports_stl(data_s, data_s + support.numel());

        const int dim = static_cast<int>(query.size(1));
        batch_nanoflann_neighbors<scalar_t>(queries_stl,
                                            supports_stl,
                                            size_query_batch_stl,
                                            size_support_batch_stl,
                                            neighbors_indices,
                                            radius,
                                            dim,
                                            max_num);
    });

    const long long tsize = static_cast<long long>(neighbors_indices.size() / 2);

    // from_blob only aliases the vector's storage; the clone below copies the
    // data so the returned tensor outlives `neighbors_indices`.
    torch::Tensor out = torch::from_blob(neighbors_indices.data(), {tsize, 2}, options);
    return out.t().clone();
}

csrc/cpu/radius_cpu.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,10 @@
77
#include "compat.h"
88

99
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
10-
torch::Tensor ptr_x, torch::Tensor ptr_y,
11-
float radius, int max_num);
10+
float radius, int max_num);
11+
12+
torch::Tensor batch_radius_cpu(torch::Tensor query,
13+
torch::Tensor support,
14+
torch::Tensor query_batch,
15+
torch::Tensor support_batch,
16+
float radius, int max_num);

csrc/cpu/utils/cloud.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,13 @@ struct PointCloud
2424

2525
void set(std::vector<scalar_t> new_pts, int dim){
2626

27-
// pts = std::vector<Point>((Point*)new_pts, (Point*)new_pts+new_pts.size()/3);
2827
std::vector<std::vector<scalar_t>> temp(new_pts.size()/dim);
2928
for(size_t i=0; i < new_pts.size(); i++){
3029
if(i%dim == 0){
31-
32-
//Point point;
3330
std::vector<scalar_t> point(dim);
34-
//std::vector<scalar_t> vect(sizeof(scalar_t)*dim, 0)
35-
//point.pt = temp;
31+
3632
for (size_t j = 0; j < (size_t)dim; j++) {
3733
point[j]=new_pts[i+j];
38-
//point.pt[j] = new_pts[i+j];
3934
}
4035
temp[i/dim] = point;
4136
}
@@ -46,7 +41,6 @@ struct PointCloud
4641
void set_batch(std::vector<scalar_t> new_pts, int begin, int size, int dim){
4742
std::vector<std::vector<scalar_t>> temp(size);
4843
for(int i=0; i < size; i++){
49-
//std::vector<scalar_t> temp(sizeof(scalar_t)*dim, 0);
5044
std::vector<scalar_t> point(dim);
5145
for (size_t j = 0; j < (size_t)dim; j++) {
5246
point[j] = new_pts[dim*(begin+i)+j];

csrc/cpu/utils/neighbors.cpp

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
4040

4141
// Search params
4242
nanoflann::SearchParams search_params;
43-
search_params.sorted = true;
43+
// search_params.sorted = true;
4444
std::vector< std::vector<std::pair<size_t, scalar_t> > > list_matches(pcd_query.pts.size());
4545

46-
float eps = 0.00001;
46+
float eps = 0.000001;
4747

4848
// indices
4949
size_t i0 = 0;
@@ -61,6 +61,12 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
6161
std::vector<std::pair<size_t, scalar_t> > ret_matches;
6262

6363
const size_t nMatches = index->radiusSearch(query_pt, search_radius+eps, ret_matches, search_params);
64+
65+
//cout << "radiusSearch(): radius=" << search_radius << " -> " << nMatches << " matches\n";
66+
//for (size_t i = 0; i < nMatches; i++)
67+
// cout << "idx["<< i << "]=" << ret_matches[i].first << " dist["<< i << "]=" << ret_matches[i].second << endl;
68+
//cout << "\n";
69+
6470
list_matches[i0] = ret_matches;
6571
if(max_count < nMatches) max_count = nMatches;
6672
i0++;
@@ -107,4 +113,139 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
107113

108114

109115

110-
}
116+
}
117+
118+
template<typename scalar_t>
119+
int batch_nanoflann_neighbors (vector<scalar_t>& queries,
120+
vector<scalar_t>& supports,
121+
vector<long>& q_batches,
122+
vector<long>& s_batches,
123+
vector<long>& neighbors_indices,
124+
float radius, int dim, int max_num){
125+
126+
127+
// Initiate variables
128+
// ******************
129+
// indices
130+
int i0 = 0;
131+
132+
// Square radius
133+
const scalar_t r2 = static_cast<scalar_t>(radius*radius);
134+
135+
// Counting vector
136+
int max_count = 0;
137+
float d2;
138+
139+
140+
// batch index
141+
long b = 0;
142+
long sum_qb = 0;
143+
long sum_sb = 0;
144+
145+
float eps = 0.000001;
146+
// Nanoflann related variables
147+
// ***************************
148+
149+
// CLoud variable
150+
PointCloud<scalar_t> current_cloud;
151+
PointCloud<scalar_t> query_pcd;
152+
query_pcd.set(queries, dim);
153+
vector<vector<pair<size_t, scalar_t> > > all_inds_dists(query_pcd.pts.size());
154+
155+
// Tree parameters
156+
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);
157+
158+
// KDTree type definition
159+
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
160+
161+
// Pointer to trees
162+
my_kd_tree_t* index;
163+
// Build KDTree for the first batch element
164+
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
165+
index = new my_kd_tree_t(dim, current_cloud, tree_params);
166+
index->buildIndex();
167+
// Search neigbors indices
168+
// ***********************
169+
// Search params
170+
nanoflann::SearchParams search_params;
171+
search_params.sorted = true;
172+
173+
for (auto& p0 : query_pcd.pts){
174+
// Check if we changed batch
175+
176+
scalar_t query_pt[dim];
177+
std::copy(p0.begin(), p0.end(), query_pt);
178+
179+
/*
180+
std::cout << "\n ========== \n";
181+
for(int i=0; i < dim; i++)
182+
std::cout << query_pt[i] << '\n';
183+
std::cout << "\n ========== \n";
184+
*/
185+
186+
if (i0 == sum_qb + q_batches[b]){
187+
sum_qb += q_batches[b];
188+
sum_sb += s_batches[b];
189+
b++;
190+
191+
// Change the points
192+
current_cloud.pts.clear();
193+
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
194+
// Build KDTree of the current element of the batch
195+
delete index;
196+
index = new my_kd_tree_t(dim, current_cloud, tree_params);
197+
index->buildIndex();
198+
}
199+
// Initial guess of neighbors size
200+
all_inds_dists[i0].reserve(max_count);
201+
// Find neighbors
202+
size_t nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
203+
// Update max count
204+
205+
std::vector<std::pair<size_t, float> > indices_dists;
206+
nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
207+
208+
index->findNeighbors(resultSet, query_pt, search_params);
209+
210+
if (nMatches > max_count)
211+
max_count = nMatches;
212+
// Increment query idx
213+
i0++;
214+
}
215+
// how many neighbors do we keep
216+
if(max_num > 0) {
217+
max_count = max_num;
218+
}
219+
// Reserve the memory
220+
221+
int size = 0; // total number of edges
222+
for (auto& inds_dists : all_inds_dists){
223+
if(inds_dists.size() <= max_count)
224+
size += inds_dists.size();
225+
else
226+
size += max_count;
227+
}
228+
neighbors_indices.resize(size * 2);
229+
i0 = 0;
230+
sum_sb = 0;
231+
sum_qb = 0;
232+
b = 0;
233+
int u = 0;
234+
for (auto& inds_dists : all_inds_dists){
235+
if (i0 == sum_qb + q_batches[b]){
236+
sum_qb += q_batches[b];
237+
sum_sb += s_batches[b];
238+
b++;
239+
}
240+
for (int j = 0; j < max_count; j++){
241+
if (j < inds_dists.size()){
242+
neighbors_indices[u] = inds_dists[j].first + sum_sb;
243+
neighbors_indices[u + 1] = i0;
244+
u += 2;
245+
}
246+
}
247+
i0++;
248+
}
249+
250+
return max_count;
251+
}

csrc/cpu/utils/neighbors.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ using namespace std;
1010

1111
template<typename scalar_t>
1212
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
13-
vector<long>& neighbors_indices, float radius, int dim, int max_num, int mode);
13+
vector<long>& neighbors_indices, float radius, int dim, int max_num);
14+
15+
template<typename scalar_t>
16+
int batch_nanoflann_neighbors (vector<scalar_t>& queries,
17+
vector<scalar_t>& supports,
18+
vector<long>& q_batches,
19+
vector<long>& s_batches,
20+
vector<long>& neighbors_indices,
21+
float radius, int dim, int max_num);

csrc/radius.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,48 @@
1010
PyMODINIT_FUNC PyInit__radius(void) { return NULL; }
1111
#endif
1212

13-
/// Dispatches a radius search to the CUDA or CPU implementation.
///
/// @param x                  Point tensor; its device selects the backend.
/// @param y                  Second point tensor, forwarded to the backend.
/// @param ptr_x              Optional batch tensor for `x`. When absent, the
///                           CUDA path passes `{0, x.size(0)}` and the CPU
///                           path passes an all-zero int64 tensor with one
///                           entry per row of `x`.
/// @param ptr_y              Optional batch tensor for `y`, defaulted the same
///                           way from `y`.
/// @param r                  Search radius.
/// @param max_num_neighbors  Neighbour limit forwarded to the backend.
/// @return The tensor produced by the selected backend.
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::optional<torch::Tensor> ptr_x,
                     torch::optional<torch::Tensor> ptr_y, double r, int64_t max_num_neighbors) {
  if (x.device().is_cuda()) {
#ifdef WITH_CUDA
    // Default range covering every row. Both elements are int64 so the
    // initializer list has a single scalar type, and the result is moved to
    // the device of the points, not to whichever CUDA device is current.
    const auto whole_range = [](const torch::Tensor &points) {
      return torch::tensor({int64_t{0}, points.size(0)}, torch::kLong).to(points.device());
    };
    auto batch_x = ptr_x.has_value() ? ptr_x.value() : whole_range(x);
    auto batch_y = ptr_y.has_value() ? ptr_y.value() : whole_range(y);
    return radius_cuda(x, y, batch_x, batch_y, r, max_num_neighbors);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    // Without any batch information the plain (unbatched) search is enough.
    if (!ptr_x.has_value() && !ptr_y.has_value()) {
      return radius_cpu(x, y, r, max_num_neighbors);
    }
    // A missing side is assigned batch id 0 for every one of its points.
    auto batch_x = ptr_x.has_value() ? ptr_x.value()
                                     : torch::zeros({x.size(0)}, torch::kLong);
    auto batch_y = ptr_y.has_value() ? ptr_y.value()
                                     : torch::zeros({y.size(0)}, torch::kLong);
    return batch_radius_cpu(x, y, batch_x, batch_y, r, max_num_neighbors);
  }
}
2557

0 commit comments

Comments
 (0)