11#include " radius_cpu.h"
2- #include < algorithm>
3- #include " utils.h"
4- #include < cstdint>
5-
6-
7- torch::Tensor radius_cpu (torch::Tensor support, torch::Tensor query,
8- double radius, int64_t max_num, int64_t n_threads){
9-
10- CHECK_CPU (query);
11- CHECK_CPU (support);
12-
13- torch::Tensor out;
14- std::vector<size_t >* neighbors_indices = new std::vector<size_t >();
15- auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
16- int max_count = 0 ;
17-
18- AT_DISPATCH_ALL_TYPES (query.scalar_type (), " radius_cpu" , [&] {
19-
20- auto data_q = query.data_ptr <scalar_t >();
21- auto data_s = support.data_ptr <scalar_t >();
22- std::vector<scalar_t > queries_stl = std::vector<scalar_t >(data_q,
23- data_q + query.size (0 )*query.size (1 ));
24- std::vector<scalar_t > supports_stl = std::vector<scalar_t >(data_s,
25- data_s + support.size (0 )*support.size (1 ));
26-
27- int dim = torch::size (query, 1 );
28-
29- max_count = nanoflann_neighbors<scalar_t >(queries_stl, supports_stl ,neighbors_indices, radius, dim, max_num, n_threads, 0 , 1 );
30-
31- });
32-
33- size_t * neighbors_indices_ptr = neighbors_indices->data ();
34-
35- const long long tsize = static_cast <long long >(neighbors_indices->size ()/2 );
36- out = torch::from_blob (neighbors_indices_ptr, {tsize, 2 }, options=options);
37- out = out.t ();
38-
39- auto result = torch::zeros_like (out);
40-
41- auto index = torch::tensor ({1 ,0 });
42-
43- result.index_copy_ (0 , index, out);
44-
45- return result;
46- }
472
48-
49- void get_size_batch (const std::vector<long >& batch, std::vector<long >& res){
50-
51- res.resize (batch[batch.size ()-1 ]-batch[0 ]+1 , 0 );
52- long ind = batch[0 ];
53- long incr = 1 ;
54- for (unsigned long i=1 ; i < batch.size (); i++){
55-
56- if (batch[i] == ind)
57- incr++;
58- else {
59- res[ind-batch[0 ]] = incr;
60- incr =1 ;
61- ind = batch[i];
62- }
63- }
64- res[ind-batch[0 ]] = incr;
3+ #include " utils.h"
4+ #include " utils/neighbors.cpp"
5+
6+ torch::Tensor radius_cpu (torch::Tensor x, torch::Tensor y,
7+ torch::optional<torch::Tensor> ptr_x,
8+ torch::optional<torch::Tensor> ptr_y, double r,
9+ int64_t max_num_neighbors, int64_t num_workers) {
10+
11+ CHECK_CPU (x);
12+ CHECK_INPUT (x.dim () == 2 );
13+ CHECK_CPU (y);
14+ CHECK_INPUT (y.dim () == 2 );
15+
16+ if (ptr_x.has_value ()) {
17+ CHECK_CPU (ptr_x.value ());
18+ CHECK_INPUT (ptr_x.value ().dim () == 1 );
19+ }
20+ if (ptr_y.has_value ()) {
21+ CHECK_CPU (ptr_y.value ());
22+ CHECK_INPUT (ptr_y.value ().dim () == 1 );
23+ }
24+
25+ std::vector<size_t > *out_vec = new std::vector<size_t >();
26+
27+ AT_DISPATCH_ALL_TYPES (x.scalar_type (), " radius_cpu" , [&] {
28+ auto x_data = x.data_ptr <scalar_t >();
29+ auto y_data = y.data_ptr <scalar_t >();
30+ auto x_vec = std::vector<scalar_t >(x_data, x_data + x.numel ());
31+ auto y_vec = std::vector<scalar_t >(y_data, y_data + y.numel ());
32+
33+ if (!ptr_x.has_value ()) {
34+ nanoflann_neighbors<scalar_t >(y_vec, x_vec, out_vec, r, x.size (-1 ),
35+ max_num_neighbors, num_workers, 0 , 1 );
36+ } else {
37+ auto sx = (ptr_x.value ().narrow (0 , 1 , ptr_x.value ().numel () - 1 ) -
38+ ptr_x.value ().narrow (0 , 0 , ptr_x.value ().numel () - 1 ));
39+ auto sy = (ptr_y.value ().narrow (0 , 1 , ptr_y.value ().numel () - 1 ) -
40+ ptr_y.value ().narrow (0 , 0 , ptr_y.value ().numel () - 1 ));
41+ auto sx_data = sx.data_ptr <int64_t >();
42+ auto sy_data = sy.data_ptr <int64_t >();
43+ auto sx_vec = std::vector<long >(sx_data, sx_data + sx.numel ());
44+ auto sy_vec = std::vector<long >(sy_data, sy_data + sy.numel ());
45+ batch_nanoflann_neighbors<scalar_t >(y_vec, x_vec, sy_vec, sx_vec, out_vec,
46+ r, x.size (-1 ), max_num_neighbors, 0 ,
47+ 1 );
48+ }
49+ });
50+
51+ const int64_t size = out_vec->size () / 2 ;
52+ auto out = torch::from_blob (out_vec->data (), {size, 2 },
53+ x.options ().dtype (torch::kLong ));
54+ return out.t ().index_select (0 , torch::tensor ({1 , 0 }));
6555}
66-
67- torch::Tensor batch_radius_cpu (torch::Tensor support,
68- torch::Tensor query,
69- torch::Tensor support_batch,
70- torch::Tensor query_batch,
71- double radius, int64_t max_num) {
72-
73- CHECK_CPU (query);
74- CHECK_CPU (support);
75- CHECK_CPU (query_batch);
76- CHECK_CPU (support_batch);
77-
78- torch::Tensor out;
79- auto data_qb = query_batch.data_ptr <int64_t >();
80- auto data_sb = support_batch.data_ptr <int64_t >();
81-
82- std::vector<long > query_batch_stl = std::vector<long >(data_qb, data_qb+query_batch.size (0 ));
83- std::vector<long > size_query_batch_stl;
84- CHECK_INPUT (std::is_sorted (query_batch_stl.begin (),query_batch_stl.end ()));
85- get_size_batch (query_batch_stl, size_query_batch_stl);
86-
87- std::vector<long > support_batch_stl = std::vector<long >(data_sb, data_sb+support_batch.size (0 ));
88- std::vector<long > size_support_batch_stl;
89- CHECK_INPUT (std::is_sorted (support_batch_stl.begin (),support_batch_stl.end ()));
90- get_size_batch (support_batch_stl, size_support_batch_stl);
91-
92- std::vector<size_t >* neighbors_indices = new std::vector<size_t >();
93- auto options = torch::TensorOptions ().dtype (torch::kLong ).device (torch::kCPU );
94- int max_count = 0 ;
95-
96- AT_DISPATCH_ALL_TYPES (query.scalar_type (), " batch_radius_cpu" , [&] {
97- auto data_q = query.data_ptr <scalar_t >();
98- auto data_s = support.data_ptr <scalar_t >();
99- std::vector<scalar_t > queries_stl = std::vector<scalar_t >(data_q,
100- data_q + query.size (0 )*query.size (1 ));
101- std::vector<scalar_t > supports_stl = std::vector<scalar_t >(data_s,
102- data_s + support.size (0 )*support.size (1 ));
103-
104- int dim = torch::size (query, 1 );
105- max_count = batch_nanoflann_neighbors<scalar_t >(queries_stl,
106- supports_stl,
107- size_query_batch_stl,
108- size_support_batch_stl,
109- neighbors_indices,
110- radius,
111- dim,
112- max_num,
113- 0 , 1 );
114- });
115-
116- size_t * neighbors_indices_ptr = neighbors_indices->data ();
117-
118-
119- const long long tsize = static_cast <long long >(neighbors_indices->size ()/2 );
120- out = torch::from_blob (neighbors_indices_ptr, {tsize, 2 }, options=options);
121- out = out.t ();
122-
123- auto result = torch::zeros_like (out);
124-
125- auto index = torch::tensor ({1 ,0 });
126-
127- result.index_copy_ (0 , index, out);
128-
129- return result;
130- }
0 commit comments