
Commit 4e2e69b

major clean up
1 parent 1bbf8bd commit 4e2e69b


15 files changed: +241 -1024 lines changed


csrc/cpu/knn_cpu.cpp

Lines changed: 52 additions & 128 deletions
@@ -1,130 +1,54 @@
-#include "radius_cpu.h"
-#include <algorithm>
-#include "utils.h"
-#include <cstdint>
-
-
-torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
-                      int64_t k, int64_t n_threads){
-
-  CHECK_CPU(query);
-  CHECK_CPU(support);
-
-  torch::Tensor out;
-  std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
-  auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
-  int max_count = 0;
-
-  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
-
-    auto data_q = query.data_ptr<scalar_t>();
-    auto data_s = support.data_ptr<scalar_t>();
-    std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
-        data_q + query.size(0)*query.size(1));
-    std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
-        data_s + support.size(0)*support.size(1));
-
-    int dim = torch::size(query, 1);
-
-    max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, 0, dim, 0, n_threads, k, 0);
-
-  });
-
-  size_t* neighbors_indices_ptr = neighbors_indices->data();
-
-  const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
-  out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
-  out = out.t();
-
-  auto result = torch::zeros_like(out);
-
-  auto index = torch::tensor({1,0});
-
-  result.index_copy_(0, index, out);
-
-  return result;
-}
+#include "knn_cpu.h"
 
-
-void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
-
-  res.resize(batch[batch.size()-1]-batch[0]+1, 0);
-  long ind = batch[0];
-  long incr = 1;
-  for(unsigned long i=1; i < batch.size(); i++){
-
-    if(batch[i] == ind)
-      incr++;
-    else{
-      res[ind-batch[0]] = incr;
-      incr = 1;
-      ind = batch[i];
-    }
-  }
-  res[ind-batch[0]] = incr;
+#include "utils.h"
+#include "utils/neighbors.cpp"
+
+torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
+                      torch::optional<torch::Tensor> ptr_x,
+                      torch::optional<torch::Tensor> ptr_y, int64_t k,
+                      int64_t num_workers) {
+
+  CHECK_CPU(x);
+  CHECK_INPUT(x.dim() == 2);
+  CHECK_CPU(y);
+  CHECK_INPUT(y.dim() == 2);
+
+  if (ptr_x.has_value()) {
+    CHECK_CPU(ptr_x.value());
+    CHECK_INPUT(ptr_x.value().dim() == 1);
+  }
+  if (ptr_y.has_value()) {
+    CHECK_CPU(ptr_y.value());
+    CHECK_INPUT(ptr_y.value().dim() == 1);
+  }
+
+  std::vector<size_t> *out_vec = new std::vector<size_t>();
+
+  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
+    auto x_data = x.data_ptr<scalar_t>();
+    auto y_data = y.data_ptr<scalar_t>();
+    auto x_vec = std::vector<scalar_t>(x_data, x_data + x.numel());
+    auto y_vec = std::vector<scalar_t>(y_data, y_data + y.numel());
+
+    if (!ptr_x.has_value()) {
+      nanoflann_neighbors<scalar_t>(y_vec, x_vec, out_vec, 0, x.size(-1), 0,
+                                    num_workers, k, 0);
+    } else {
+      auto sx = (ptr_x.value().narrow(0, 1, ptr_x.value().numel() - 1) -
+                 ptr_x.value().narrow(0, 0, ptr_x.value().numel() - 1));
+      auto sy = (ptr_y.value().narrow(0, 1, ptr_y.value().numel() - 1) -
+                 ptr_y.value().narrow(0, 0, ptr_y.value().numel() - 1));
+      auto sx_data = sx.data_ptr<int64_t>();
+      auto sy_data = sy.data_ptr<int64_t>();
+      auto sx_vec = std::vector<long>(sx_data, sx_data + sx.numel());
+      auto sy_vec = std::vector<long>(sy_data, sy_data + sy.numel());
+      batch_nanoflann_neighbors<scalar_t>(y_vec, x_vec, sy_vec, sx_vec, out_vec,
+                                          k, x.size(-1), 0, k, 0);
+    }
+  });
+
+  const int64_t size = out_vec->size() / 2;
+  auto out = torch::from_blob(out_vec->data(), {size, 2},
+                              x.options().dtype(torch::kLong));
+  return out.t().index_select(0, torch::tensor({1, 0}));
 }
-
-torch::Tensor batch_knn_cpu(torch::Tensor support,
-                            torch::Tensor query,
-                            torch::Tensor support_batch,
-                            torch::Tensor query_batch,
-                            int64_t k) {
-
-  CHECK_CPU(query);
-  CHECK_CPU(support);
-  CHECK_CPU(query_batch);
-  CHECK_CPU(support_batch);
-
-  torch::Tensor out;
-  auto data_qb = query_batch.data_ptr<int64_t>();
-  auto data_sb = support_batch.data_ptr<int64_t>();
-
-  std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb + query_batch.size(0));
-  std::vector<long> size_query_batch_stl;
-  CHECK_INPUT(std::is_sorted(query_batch_stl.begin(), query_batch_stl.end()));
-  get_size_batch(query_batch_stl, size_query_batch_stl);
-
-  std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb + support_batch.size(0));
-  std::vector<long> size_support_batch_stl;
-  CHECK_INPUT(std::is_sorted(support_batch_stl.begin(), support_batch_stl.end()));
-  get_size_batch(support_batch_stl, size_support_batch_stl);
-
-  std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
-  auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
-  int max_count = 0;
-
-  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
-    auto data_q = query.data_ptr<scalar_t>();
-    auto data_s = support.data_ptr<scalar_t>();
-    std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
-        data_q + query.size(0)*query.size(1));
-    std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
-        data_s + support.size(0)*support.size(1));
-
-    int dim = torch::size(query, 1);
-    max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
-                                                    supports_stl,
-                                                    size_query_batch_stl,
-                                                    size_support_batch_stl,
-                                                    neighbors_indices,
-                                                    0,
-                                                    dim,
-                                                    0,
-                                                    k, 0);
-  });
-
-  size_t* neighbors_indices_ptr = neighbors_indices->data();
-
-
-  const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
-  out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
-  out = out.t();
-
-  auto result = torch::zeros_like(out);
-
-  auto index = torch::tensor({1,0});
-
-  result.index_copy_(0, index, out);
-
-  return result;
-}
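
Note on the batched path above: the sorted per-point batch vectors and the get_size_batch helper are gone; the function now takes CSR-style ptr_x / ptr_y offsets and derives per-example sizes with narrow. A minimal sketch of that equivalence, using only libtorch (the names batch and ptr below are illustrative, not part of this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Old-style per-point batch assignment (sorted): points 0-2 belong to
  // example 0, points 3-4 to example 1.
  auto batch = torch::tensor({0, 0, 0, 1, 1}, torch::kLong);

  // New-style CSR offsets as expected by ptr_x / ptr_y: example i owns the
  // points in the half-open range [ptr[i], ptr[i + 1]).
  auto ptr = torch::tensor({0, 3, 5}, torch::kLong);

  // Per-example sizes, computed the same way the new knn_cpu does it.
  auto sizes = ptr.narrow(0, 1, ptr.numel() - 1) -
               ptr.narrow(0, 0, ptr.numel() - 1);
  std::cout << sizes << std::endl;  // 3, 2
  return 0;
}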

csrc/cpu/knn_cpu.h

Lines changed: 4 additions & 10 deletions
@@ -1,14 +1,8 @@
 #pragma once
 
 #include <torch/extension.h>
-#include "utils/neighbors.cpp"
-#include <iostream>
 
-torch::Tensor knn_cpu(torch::Tensor support, torch::Tensor query,
-                      int64_t k, int64_t n_threads);
-
-torch::Tensor batch_knn_cpu(torch::Tensor support,
-                            torch::Tensor query,
-                            torch::Tensor support_batch,
-                            torch::Tensor query_batch,
-                            int64_t k);
+torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
+                      torch::optional<torch::Tensor> ptr_x,
+                      torch::optional<torch::Tensor> ptr_y, int64_t k,
+                      int64_t num_workers);
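
The header now declares a single entry point whose optional ptr arguments cover both the flat and the batched case. As a sketch of how such a declaration could be exposed to Python via schema inference, one might register it like this (the operator namespace my_cluster and this registration block are assumptions for illustration, not code from this commit):

// Hypothetical registration sketch; the namespace "my_cluster" is made up.
#include <torch/library.h>

#include "knn_cpu.h"

TORCH_LIBRARY(my_cluster, m) {
  // Schema is inferred from the C++ signature, including the optional
  // ptr_x / ptr_y tensors; callable as torch.ops.my_cluster.knn_cpu(...).
  m.def("knn_cpu", &knn_cpu);
}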

csrc/cpu/radius_cpu.cpp

Lines changed: 52 additions & 127 deletions
@@ -1,130 +1,55 @@
 #include "radius_cpu.h"
-#include <algorithm>
-#include "utils.h"
-#include <cstdint>
-
-
-torch::Tensor radius_cpu(torch::Tensor support, torch::Tensor query,
-                         double radius, int64_t max_num, int64_t n_threads){
-
-  CHECK_CPU(query);
-  CHECK_CPU(support);
-
-  torch::Tensor out;
-  std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
-  auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
-  int max_count = 0;
-
-  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "radius_cpu", [&] {
-
-    auto data_q = query.data_ptr<scalar_t>();
-    auto data_s = support.data_ptr<scalar_t>();
-    std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
-        data_q + query.size(0)*query.size(1));
-    std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
-        data_s + support.size(0)*support.size(1));
-
-    int dim = torch::size(query, 1);
-
-    max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices, radius, dim, max_num, n_threads, 0, 1);
-
-  });
-
-  size_t* neighbors_indices_ptr = neighbors_indices->data();
-
-  const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
-  out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
-  out = out.t();
-
-  auto result = torch::zeros_like(out);
-
-  auto index = torch::tensor({1,0});
-
-  result.index_copy_(0, index, out);
-
-  return result;
-}
 
-
-void get_size_batch(const std::vector<long>& batch, std::vector<long>& res){
-
-  res.resize(batch[batch.size()-1]-batch[0]+1, 0);
-  long ind = batch[0];
-  long incr = 1;
-  for(unsigned long i=1; i < batch.size(); i++){
-
-    if(batch[i] == ind)
-      incr++;
-    else{
-      res[ind-batch[0]] = incr;
-      incr = 1;
-      ind = batch[i];
-    }
-  }
-  res[ind-batch[0]] = incr;
+#include "utils.h"
+#include "utils/neighbors.cpp"
+
+torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
+                         torch::optional<torch::Tensor> ptr_x,
+                         torch::optional<torch::Tensor> ptr_y, double r,
+                         int64_t max_num_neighbors, int64_t num_workers) {
+
+  CHECK_CPU(x);
+  CHECK_INPUT(x.dim() == 2);
+  CHECK_CPU(y);
+  CHECK_INPUT(y.dim() == 2);
+
+  if (ptr_x.has_value()) {
+    CHECK_CPU(ptr_x.value());
+    CHECK_INPUT(ptr_x.value().dim() == 1);
+  }
+  if (ptr_y.has_value()) {
+    CHECK_CPU(ptr_y.value());
+    CHECK_INPUT(ptr_y.value().dim() == 1);
+  }
+
+  std::vector<size_t> *out_vec = new std::vector<size_t>();
+
+  AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
+    auto x_data = x.data_ptr<scalar_t>();
+    auto y_data = y.data_ptr<scalar_t>();
+    auto x_vec = std::vector<scalar_t>(x_data, x_data + x.numel());
+    auto y_vec = std::vector<scalar_t>(y_data, y_data + y.numel());
+
+    if (!ptr_x.has_value()) {
+      nanoflann_neighbors<scalar_t>(y_vec, x_vec, out_vec, r, x.size(-1),
+                                    max_num_neighbors, num_workers, 0, 1);
+    } else {
+      auto sx = (ptr_x.value().narrow(0, 1, ptr_x.value().numel() - 1) -
+                 ptr_x.value().narrow(0, 0, ptr_x.value().numel() - 1));
+      auto sy = (ptr_y.value().narrow(0, 1, ptr_y.value().numel() - 1) -
+                 ptr_y.value().narrow(0, 0, ptr_y.value().numel() - 1));
+      auto sx_data = sx.data_ptr<int64_t>();
+      auto sy_data = sy.data_ptr<int64_t>();
+      auto sx_vec = std::vector<long>(sx_data, sx_data + sx.numel());
+      auto sy_vec = std::vector<long>(sy_data, sy_data + sy.numel());
+      batch_nanoflann_neighbors<scalar_t>(y_vec, x_vec, sy_vec, sx_vec, out_vec,
+                                          r, x.size(-1), max_num_neighbors, 0,
+                                          1);
+    }
+  });
+
+  const int64_t size = out_vec->size() / 2;
+  auto out = torch::from_blob(out_vec->data(), {size, 2},
+                              x.options().dtype(torch::kLong));
+  return out.t().index_select(0, torch::tensor({1, 0}));
 }
-
-torch::Tensor batch_radius_cpu(torch::Tensor support,
-                               torch::Tensor query,
-                               torch::Tensor support_batch,
-                               torch::Tensor query_batch,
-                               double radius, int64_t max_num) {
-
-  CHECK_CPU(query);
-  CHECK_CPU(support);
-  CHECK_CPU(query_batch);
-  CHECK_CPU(support_batch);
-
-  torch::Tensor out;
-  auto data_qb = query_batch.data_ptr<int64_t>();
-  auto data_sb = support_batch.data_ptr<int64_t>();
-
-  std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb + query_batch.size(0));
-  std::vector<long> size_query_batch_stl;
-  CHECK_INPUT(std::is_sorted(query_batch_stl.begin(), query_batch_stl.end()));
-  get_size_batch(query_batch_stl, size_query_batch_stl);
-
-  std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb + support_batch.size(0));
-  std::vector<long> size_support_batch_stl;
-  CHECK_INPUT(std::is_sorted(support_batch_stl.begin(), support_batch_stl.end()));
-  get_size_batch(support_batch_stl, size_support_batch_stl);
-
-  std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
-  auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
-  int max_count = 0;
-
-  AT_DISPATCH_ALL_TYPES(query.scalar_type(), "batch_radius_cpu", [&] {
-    auto data_q = query.data_ptr<scalar_t>();
-    auto data_s = support.data_ptr<scalar_t>();
-    std::vector<scalar_t> queries_stl = std::vector<scalar_t>(data_q,
-        data_q + query.size(0)*query.size(1));
-    std::vector<scalar_t> supports_stl = std::vector<scalar_t>(data_s,
-        data_s + support.size(0)*support.size(1));
-
-    int dim = torch::size(query, 1);
-    max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl,
-                                                    supports_stl,
-                                                    size_query_batch_stl,
-                                                    size_support_batch_stl,
-                                                    neighbors_indices,
-                                                    radius,
-                                                    dim,
-                                                    max_num,
-                                                    0, 1);
-  });
-
-  size_t* neighbors_indices_ptr = neighbors_indices->data();
-
-
-  const long long tsize = static_cast<long long>(neighbors_indices->size()/2);
-  out = torch::from_blob(neighbors_indices_ptr, {tsize, 2}, options=options);
-  out = out.t();
-
-  auto result = torch::zeros_like(out);
-
-  auto index = torch::tensor({1,0});
-
-  result.index_copy_(0, index, out);
-
-  return result;
-}
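
Both rewritten kernels return the neighbor pairs directly as a [2, num_pairs] long tensor instead of going through zeros_like and index_copy_. A usage sketch of the new radius_cpu signature, assuming the extension is compiled and radius_cpu.h declares the function as defined above (the values below are illustrative):

#include <iostream>

#include "radius_cpu.h"

int main() {
  // x plays the role of the old "support" set, y of the old "query" set
  // (y_vec is passed first to nanoflann_neighbors above).
  auto x = torch::rand({8, 3});
  auto y = torch::rand({4, 3});

  // Flat call: no ptr tensors, radius 0.5, at most 32 neighbors per query.
  auto edge = radius_cpu(x, y, torch::nullopt, torch::nullopt,
                         /*r=*/0.5, /*max_num_neighbors=*/32,
                         /*num_workers=*/1);
  std::cout << edge.sizes() << std::endl;  // [2, number_of_pairs]

  // Batched call: CSR offsets describing a single example with 8 / 4 points.
  auto ptr_x = torch::tensor({0, 8}, torch::kLong);
  auto ptr_y = torch::tensor({0, 4}, torch::kLong);
  auto edge_batched = radius_cpu(x, y, ptr_x, ptr_y, 0.5, 32, 1);
  return 0;
}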

0 commit comments
