Skip to content

Commit 59dfc41

Browse files
committed
cleaning up according to flake8; c++ api now working (except options)
1 parent 303e889 commit 59dfc41

File tree

6 files changed

+50
-264
lines changed

6 files changed

+50
-264
lines changed

csrc/cpu/radius_cpu.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,26 @@
22

33
#include "utils.h"
44

5-
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
5+
torch::Tensor radius_cpu(torch::Tensor q, torch::Tensor s,
66
torch::Tensor ptr_x, torch::Tensor ptr_y,
77
float radius, int max_num){
88

9-
CHECK_CPU(query);
10-
CHECK_CPU(support);
9+
CHECK_CPU(q);
10+
CHECK_CPU(s);
1111

1212
/*
1313
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
1414
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
15-
auto batch_x = ptr_x.clone();
16-
auto batch_y = ptr_y.clone();
15+
*/
1716

18-
batch_x._mul(2*radius);
19-
batch_y._mul(2*radius);
17+
auto batch_x = ptr_x.clone().reshape({-1, 1});
18+
auto batch_y = ptr_y.clone().reshape({-1, 1});
2019

21-
auto query = torch::cat({query,batch_x},-1);
22-
auto support = torch::cat({support,batch_y},-1);
23-
*/
20+
batch_x.mul_(2*radius);
21+
batch_y.mul_(2*radius);
22+
23+
auto query = torch::cat({q,batch_x},-1);
24+
auto support = torch::cat({s,batch_y},-1);
2425

2526
torch::Tensor out;
2627
std::vector<long> neighbors_indices;
@@ -43,9 +44,18 @@ torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
4344
});
4445

4546
long* neighbors_indices_ptr = neighbors_indices.data();
46-
out = torch::from_blob(neighbors_indices_ptr, {neighbors_indices.size()/2, 2}, options=options);
4747

48-
return out.t().clone();
48+
const long long tsize = static_cast<long long>(neighbors_indices.size()/2);
49+
out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
50+
out = out.t();
51+
52+
auto result = torch::zeros_like(out);
53+
54+
auto index = torch::tensor({0,1});
55+
56+
result.index_copy_(0, index, out);
57+
58+
return result;
4959
}
5060

5161
void get_size_batch(const vector<long>& batch, vector<long>& res){

csrc/cpu/radius_cpu.h

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,11 @@
11
#pragma once
22

33
#include <torch/extension.h>
4-
#include <ATen/ATen.h>
54
#include "utils/neighbors.h"
65
#include "utils/neighbors.cpp"
76
#include <iostream>
87
#include "compat.h"
98

10-
torch::Tensor radius_cpu(torch::Tensor query,
11-
torch::Tensor support,torch::Tensor ptr_x,
12-
torch::Tensor ptr_y,
13-
float radius, int max_num);
14-
/*
15-
using namespace pybind11::literals;
16-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
17-
m.def("radius_search",
18-
&radius_search,
19-
"compute the radius search of a point cloud using nanoflann"
20-
"-query : a pytorch tensor of size N1 x d,. used to query the nearest neighbors"
21-
"- support : a pytorch tensor of size N2 x d. used to build the tree"
22-
"- radius : float number, size of the ball for the radius search."
23-
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors will be computed). "
24-
" - mode : int number that indicate which format for the neighborhood"
25-
" mode=0 mean a matrix of neighbors(-1 for shadow neighbors)"
26-
"mode=1 means a matrix of edges of size Num_edge x 2"
27-
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2.",
28-
"query"_a, "support"_a, "radius"_a, "dim"_a, "max_num"_a=-1, "mode"_a=0);
29-
m.def("batch_radius_search",
30-
&batch_radius_search,
31-
"compute the radius search of a point cloud for each batch using nanoflann"
32-
"-query : a pytorch tensor (float) of size N1 x d,. used to query the nearest neighbors"
33-
"- support : a pytorch tensor(float) of size N2 x d. used to build the tree"
34-
"- query_batch : a pytorch tensor(long) contains indices of the batch of the query size N1"
35-
"NB : the batch must be sorted"
36-
"- support_batch: a pytorch tensor(long) contains indices of the batch of the support size N2"
37-
"NB: the batch must be sorted"
38-
"-radius: float number, size of the ball for the radius search."
39-
"- max_num : int number, indicate the maximum of neaghbors allowed(if -1 then all the possible neighbors wrt the radius will be computed)."
40-
"- mode : int number that indicate which format for the neighborhood"
41-
"mode=0 mean a matrix of neighbors(N2 for shadow neighbors)"
42-
"mode=1 means a matrix of edges of size Num_edge x 2"
43-
"return a tensor of size N1 x M where M is either max_num or the maximum number of neighbors found if mode = 0, if mode=1 return a tensor of size Num_edge x 2.",
44-
"query"_a, "support"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "dim"_a, "max_num"_a=-1, "mode"_a=0);
45-
}
46-
*/
9+
torch::Tensor radius_cpu(torch::Tensor query, torch::Tensor support,
10+
torch::Tensor ptr_x, torch::Tensor ptr_y,
11+
float radius, int max_num);

csrc/cpu/utils/cloud.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ struct PointCloud
2626

2727
// pts = std::vector<Point>((Point*)new_pts, (Point*)new_pts+new_pts.size()/3);
2828
std::vector<std::vector<scalar_t>> temp(new_pts.size()/dim);
29-
for(unsigned int i=0; i < new_pts.size(); i++){
29+
for(size_t i=0; i < new_pts.size(); i++){
3030
if(i%dim == 0){
3131

3232
//Point point;
3333
std::vector<scalar_t> point(dim);
3434
//std::vector<scalar_t> vect(sizeof(scalar_t)*dim, 0)
3535
//point.pt = temp;
36-
for (unsigned int j = 0; j < dim; j++) {
36+
for (size_t j = 0; j < (size_t)dim; j++) {
3737
point[j]=new_pts[i+j];
3838
//point.pt[j] = new_pts[i+j];
3939
}
@@ -47,9 +47,8 @@ struct PointCloud
4747
std::vector<std::vector<scalar_t>> temp(size);
4848
for(int i=0; i < size; i++){
4949
//std::vector<scalar_t> temp(sizeof(scalar_t)*dim, 0);
50-
//point.pt = temp;
5150
std::vector<scalar_t> point(dim);
52-
for (unsigned int j = 0; j < dim; j++) {
51+
for (size_t j = 0; j < (size_t)dim; j++) {
5352
point[j] = new_pts[dim*(begin+i)+j];
5453
}
5554

csrc/cpu/utils/neighbors.cpp

Lines changed: 12 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
// Taken from https://github.com/HuguesTHOMAS/KPConv
2+
// 3D Version https://github.com/HuguesTHOMAS/KPConv
33

44
#include "neighbors.h"
55

@@ -10,16 +10,10 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
1010
// Initiate variables
1111
// ******************
1212

13-
// square radius
14-
1513
const scalar_t search_radius = static_cast<scalar_t>(radius*radius);
1614

17-
// indices
18-
int i0 = 0;
19-
2015
// Counting vector
21-
int max_count = 1;
22-
float d2;
16+
size_t max_count = 1;
2317

2418
// Nanoflann related variables
2519
// ***************************
@@ -32,7 +26,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
3226
pcd_query.set(queries, dim);
3327

3428
// Tree parameters
35-
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);
29+
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(15 /* max leaf */);
3630

3731
// KDTree type definition
3832
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
@@ -51,6 +45,9 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
5145

5246
float eps = 0.00001;
5347

48+
// indices
49+
size_t i0 = 0;
50+
5451
for (auto& p0 : pcd_query.pts){
5552

5653
// Find neighbors
@@ -62,7 +59,6 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
6259

6360
list_matches[i0].reserve(max_count);
6461
std::vector<std::pair<size_t, scalar_t> > ret_matches;
65-
//nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
6662

6763
const size_t nMatches = index->radiusSearch(&query_pt[0], search_radius+eps, ret_matches, search_params);
6864
list_matches[i0] = ret_matches;
@@ -84,19 +80,19 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
8480
max_count = max_num;
8581
}
8682

87-
int size = 0; // total number of edges
83+
size_t size = 0; // total number of edges
8884
for (auto& inds : list_matches){
8985
if(inds.size() <= max_count)
9086
size += inds.size();
9187
else
9288
size += max_count;
9389
}
94-
90+
9591
neighbors_indices.resize(size*2);
96-
int i1 = 0; // index of the query points
97-
int u = 0; // curent index of the neighbors_indices
92+
size_t i1 = 0; // index of the query points
93+
size_t u = 0; // curent index of the neighbors_indices
9894
for (auto& inds : list_matches){
99-
for (int j = 0; j < max_count; j++){
95+
for (size_t j = 0; j < max_count; j++){
10096
if(j < inds.size()){
10197
neighbors_indices[u] = inds[j].first;
10298
neighbors_indices[u + 1] = i1;
@@ -111,165 +107,4 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
111107

112108

113109

114-
}
115-
template<typename scalar_t>
116-
int batch_nanoflann_neighbors (vector<scalar_t>& queries,
117-
vector<scalar_t>& supports,
118-
vector<long>& q_batches,
119-
vector<long>& s_batches,
120-
vector<long>& neighbors_indices,
121-
float radius, int dim, int max_num,
122-
int mode){
123-
124-
125-
// Initiate variables
126-
// ******************
127-
// indices
128-
int i0 = 0;
129-
130-
// Square radius
131-
float r2 = radius * radius;
132-
133-
// Counting vector
134-
int max_count = 0;
135-
float d2;
136-
137-
138-
// batch index
139-
long b = 0;
140-
long sum_qb = 0;
141-
long sum_sb = 0;
142-
143-
// Nanoflann related variables
144-
// ***************************
145-
146-
// CLoud variable
147-
PointCloud<scalar_t> current_cloud;
148-
PointCloud<scalar_t> query_pcd;
149-
query_pcd.set(queries, dim);
150-
vector<vector<pair<size_t, scalar_t> > > all_inds_dists(query_pcd.pts.size());
151-
152-
// Tree parameters
153-
nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */);
154-
155-
// KDTree type definition
156-
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
157-
158-
// Pointer to trees
159-
my_kd_tree_t* index;
160-
// Build KDTree for the first batch element
161-
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
162-
index = new my_kd_tree_t(dim, current_cloud, tree_params);
163-
index->buildIndex();
164-
// Search neigbors indices
165-
// ***********************
166-
// Search params
167-
nanoflann::SearchParams search_params;
168-
search_params.sorted = true;
169-
170-
for (auto& p0 : query_pcd.pts){
171-
// Check if we changed batch
172-
173-
scalar_t query_pt[dim];
174-
std::copy(p0.begin(), p0.end(), query_pt);
175-
176-
/*
177-
std::cout << "\n ========== \n";
178-
for(int i=0; i < dim; i++)
179-
std::cout << query_pt[i] << '\n';
180-
std::cout << "\n ========== \n";
181-
*/
182-
183-
if (i0 == sum_qb + q_batches[b]){
184-
sum_qb += q_batches[b];
185-
sum_sb += s_batches[b];
186-
b++;
187-
188-
// Change the points
189-
current_cloud.pts.clear();
190-
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
191-
// Build KDTree of the current element of the batch
192-
delete index;
193-
index = new my_kd_tree_t(dim, current_cloud, tree_params);
194-
index->buildIndex();
195-
}
196-
// Initial guess of neighbors size
197-
all_inds_dists[i0].reserve(max_count);
198-
// Find neighbors
199-
size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);
200-
// Update max count
201-
202-
std::vector<std::pair<size_t, float> > indices_dists;
203-
nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
204-
205-
index->findNeighbors(resultSet, query_pt, search_params);
206-
207-
if (nMatches > max_count)
208-
max_count = nMatches;
209-
// Increment query idx
210-
i0++;
211-
}
212-
// how many neighbors do we keep
213-
if(max_num > 0) {
214-
max_count = max_num;
215-
}
216-
// Reserve the memory
217-
if(mode == 0){
218-
neighbors_indices.resize(query_pcd.pts.size() * max_count);
219-
i0 = 0;
220-
sum_sb = 0;
221-
sum_qb = 0;
222-
b = 0;
223-
224-
for (auto& inds_dists : all_inds_dists){// Check if we changed batch
225-
226-
if (i0 == sum_qb + q_batches[b]){
227-
sum_qb += q_batches[b];
228-
sum_sb += s_batches[b];
229-
b++;
230-
}
231-
232-
for (int j = 0; j < max_count; j++){
233-
if (j < inds_dists.size())
234-
neighbors_indices[i0 * max_count + j] = inds_dists[j].first + sum_sb;
235-
else
236-
neighbors_indices[i0 * max_count + j] = supports.size();
237-
238-
}
239-
240-
i0++;
241-
}
242-
delete index;
243-
}
244-
else if(mode == 1){
245-
int size = 0; // total number of edges
246-
for (auto& inds_dists : all_inds_dists){
247-
if(inds_dists.size() <= max_count)
248-
size += inds_dists.size();
249-
else
250-
size += max_count;
251-
}
252-
neighbors_indices.resize(size * 2);
253-
i0 = 0;
254-
sum_sb = 0;
255-
sum_qb = 0;
256-
b = 0;
257-
int u = 0;
258-
for (auto& inds_dists : all_inds_dists){
259-
if (i0 == sum_qb + q_batches[b]){
260-
sum_qb += q_batches[b];
261-
sum_sb += s_batches[b];
262-
b++;
263-
}
264-
for (int j = 0; j < max_count; j++){
265-
if (j < inds_dists.size()){
266-
neighbors_indices[u] = inds_dists[j].first + sum_sb;
267-
neighbors_indices[u + 1] = i0;
268-
u += 2;
269-
}
270-
}
271-
i0++;
272-
}
273-
}
274-
return max_count;
275-
}
110+
}

csrc/cpu/utils/neighbors.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,4 @@ using namespace std;
1010

1111
template<typename scalar_t>
1212
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
13-
vector<long>& neighbors_indices, float radius, int dim, int max_num, int mode);
14-
15-
16-
template<typename scalar_t>
17-
int batch_nanoflann_neighbors(vector<scalar_t>& queries,
18-
vector<scalar_t>& supports,
19-
vector<long>& q_batches,
20-
vector<long>& s_batches,
21-
vector<long>& neighbors_indices,
22-
float radius, int dim, int max_num, int mode);
13+
vector<long>& neighbors_indices, float radius, int dim, int max_num, int mode);

0 commit comments

Comments
 (0)