Skip to content

Commit 3d682e5

Browse files
committed
additional checks; attempt to fix windows build error
1 parent 4dbba3f commit 3d682e5

File tree

3 files changed

+46
-36
lines changed

3 files changed

+46
-36
lines changed

csrc/cpu/radius_cpu.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,25 @@ torch::Tensor batch_radius_cpu(torch::Tensor query,
6464
torch::Tensor support_batch,
6565
double radius, int64_t max_num) {
6666

67+
CHECK_CPU(query);
68+
CHECK_CPU(support);
69+
CHECK_CPU(query_batch);
70+
CHECK_CPU(support_batch);
71+
6772
torch::Tensor out;
6873
auto data_qb = query_batch.data_ptr<int64_t>();
6974
auto data_sb = support_batch.data_ptr<int64_t>();
75+
7076
std::vector<long> query_batch_stl = std::vector<long>(data_qb, data_qb+query_batch.size(0));
7177
std::vector<long> size_query_batch_stl;
78+
CHECK_INPUT(std::is_sorted(query_batch_stl.begin(),query_batch_stl.end()));
7279
get_size_batch(query_batch_stl, size_query_batch_stl);
80+
7381
std::vector<long> support_batch_stl = std::vector<long>(data_sb, data_sb+support_batch.size(0));
7482
std::vector<long> size_support_batch_stl;
83+
CHECK_INPUT(std::is_sorted(support_batch_stl.begin(),support_batch_stl.end()));
7584
get_size_batch(support_batch_stl, size_support_batch_stl);
85+
7686
std::vector<size_t>* neighbors_indices = new std::vector<size_t>();
7787
auto options = torch::TensorOptions().dtype(torch::kLong).device(torch::kCPU);
7888
int max_count = 0;

csrc/cpu/utils/neighbors.cpp

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
7979
// CLoud variable
8080
PointCloud<scalar_t> pcd;
8181
pcd.set(supports, dim);
82-
//Cloud query
82+
// Cloud query
8383
PointCloud<scalar_t>* pcd_query = new PointCloud<scalar_t>();
8484
(*pcd_query).set(queries, dim);
8585

@@ -95,7 +95,6 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
9595
index = new my_kd_tree_t(dim, pcd, tree_params);
9696
index->buildIndex();
9797
// Search neigbors indices
98-
// ***********************
9998

10099
// Search params
101100
nanoflann::SearchParams search_params;
@@ -137,7 +136,7 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
137136
size_t n_queries = (*pcd_query).pts.size();
138137
size_t actual_threads = std::min((long long)n_threads, (long long)n_queries);
139138

140-
std::thread* tid[actual_threads];
139+
std::vector<std::thread*> tid(actual_threads);
141140

142141
size_t start, end;
143142
size_t length;
@@ -147,17 +146,8 @@ size_t nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>
147146
else {
148147
auto res = std::lldiv((long long)n_queries, (long long)n_threads);
149148
length = (size_t)res.quot;
150-
/*
151-
if (res.rem == 0) {
152-
length = res.quot;
153-
}
154-
else {
155-
length =
156-
}
157-
*/
158149
}
159150
for (size_t t = 0; t < actual_threads; t++) {
160-
//sem->wait();
161151
start = t*length;
162152
if (t == actual_threads-1) {
163153
end = n_queries;
@@ -233,12 +223,10 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
233223
double radius, int dim, int64_t max_num){
234224

235225

236-
// Initiate variables
237-
// ******************
238-
// indices
226+
// indices
239227
size_t i0 = 0;
240228

241-
// Square radius
229+
// Square radius
242230
const scalar_t r2 = static_cast<scalar_t>(radius*radius);
243231

244232
// Counting vector
@@ -257,7 +245,6 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
257245
eps = 0;
258246
}
259247
// Nanoflann related variables
260-
// ***************************
261248

262249
// CLoud variable
263250
PointCloud<scalar_t> current_cloud;
@@ -271,21 +258,20 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
271258
// KDTree type definition
272259
typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Adaptor<scalar_t, PointCloud<scalar_t> > , PointCloud<scalar_t>> my_kd_tree_t;
273260

274-
// Pointer to trees
261+
// Pointer to trees
275262
my_kd_tree_t* index;
276263
// Build KDTree for the first batch element
277264
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
278265
index = new my_kd_tree_t(dim, current_cloud, tree_params);
279266
index->buildIndex();
280-
// Search neigbors indices
281-
// ***********************
282-
// Search params
267+
// Search neigbors indices
268+
// Search params
283269
nanoflann::SearchParams search_params;
284270
search_params.sorted = true;
285271

286272
for (auto& p : query_pcd.pts){
287273
auto p0 = *p;
288-
// Check if we changed batch
274+
// Check if we changed batch
289275

290276
scalar_t* query_pt = new scalar_t[dim];
291277
std::copy(p0.begin(), p0.end(), query_pt);
@@ -295,19 +281,19 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
295281
sum_sb += s_batches[b];
296282
b++;
297283

298-
// Change the points
284+
// Change the points
299285
current_cloud.pts.clear();
300286
current_cloud.set_batch(supports, sum_sb, s_batches[b], dim);
301-
// Build KDTree of the current element of the batch
287+
// Build KDTree of the current element of the batch
302288
delete index;
303289
index = new my_kd_tree_t(dim, current_cloud, tree_params);
304290
index->buildIndex();
305291
}
306-
// Initial guess of neighbors size
292+
// Initial guess of neighbors size
307293
all_inds_dists[i0].reserve(max_count);
308-
// Find neighbors
294+
// Find neighbors
309295
size_t nMatches = index->radiusSearch(query_pt, r2+eps, all_inds_dists[i0], search_params);
310-
// Update max count
296+
// Update max count
311297

312298
std::vector<std::pair<size_t, float> > indices_dists;
313299
nanoflann::RadiusResultSet<float,size_t> resultSet(r2, indices_dists);
@@ -316,14 +302,17 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
316302

317303
if (nMatches > max_count)
318304
max_count = nMatches;
319-
// Increment query idx
305+
// Increment query idx
320306
i0++;
321307
}
308+
309+
310+
322311
// how many neighbors do we keep
323312
if(max_num > 0) {
324313
max_count = max_num;
325314
}
326-
// Reserve the memory
315+
// Reserve the memory
327316

328317
size_t size = 0; // total number of edges
329318
for (auto& inds_dists : all_inds_dists){
@@ -332,6 +321,7 @@ size_t batch_nanoflann_neighbors (std::vector<scalar_t>& queries,
332321
else
333322
size += max_count;
334323
}
324+
335325
neighbors_indices->resize(size * 2);
336326
i0 = 0;
337327
sum_sb = 0;

torch_cluster/radius.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22
import torch
3+
import numpy as np
34

45

56
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
@@ -15,16 +16,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
1516
y (Tensor): Node feature matrix
1617
:math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
1718
r (float): The radius.
18-
batch_x (LongTensor, optional): Batch vector
19+
batch_x (LongTensor, optional): Batch vector (must be sorted)
1920
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
2021
node to a specific example. (default: :obj:`None`)
21-
batch_y (LongTensor, optional): Batch vector
22+
batch_y (LongTensor, optional): Batch vector (must be sorted)
2223
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
2324
node to a specific example. (default: :obj:`None`)
2425
max_num_neighbors (int, optional): The maximum number of neighbors to
2526
return for each element in :obj:`y`. (default: :obj:`32`)
26-
n_threads (int): number of threads when the input is on CPU.
27-
(default: :obj:`1`)
27+
n_threads (int): number of threads when the input is on CPU. Note
28+
that this has no effect when batch_x or batch_y is not None, or
29+
x is on GPU. (default: :obj:`1`)
2830
2931
.. code-block:: python
3032
@@ -41,9 +43,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
4143
x = x.view(-1, 1) if x.dim() == 1 else x
4244
y = y.view(-1, 1) if y.dim() == 1 else y
4345

46+
def is_sorted(x):
47+
return (np.diff(x.detach().cpu()) >= 0).all()
48+
4449
if x.is_cuda:
4550
if batch_x is not None:
4651
assert x.size(0) == batch_x.numel()
52+
assert is_sorted(batch_x)
4753
batch_size = int(batch_x.max()) + 1
4854

4955
deg = x.new_zeros(batch_size, dtype=torch.long)
@@ -56,6 +62,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
5662

5763
if batch_y is not None:
5864
assert y.size(0) == batch_y.numel()
65+
assert is_sorted(batch_y)
5966
batch_size = int(batch_y.max()) + 1
6067

6168
deg = y.new_zeros(batch_size, dtype=torch.long)
@@ -72,11 +79,13 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
7279
assert x.dim() == 2
7380
if batch_x is not None:
7481
assert batch_x.dim() == 1
82+
assert is_sorted(batch_x)
7583
assert x.size(0) == batch_x.size(0)
7684

7785
assert y.dim() == 2
7886
if batch_y is not None:
7987
assert batch_y.dim() == 1
88+
assert is_sorted(batch_y)
8089
assert y.size(0) == batch_y.size(0)
8190
assert x.size(1) == y.size(1)
8291

@@ -97,7 +106,7 @@ def radius_graph(x: torch.Tensor, r: float,
97106
x (Tensor): Node feature matrix
98107
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
99108
r (float): The radius.
100-
batch (LongTensor, optional): Batch vector
109+
batch (LongTensor, optional): Batch vector (must be sorted)
101110
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
102111
node to a specific example. (default: :obj:`None`)
103112
loop (bool, optional): If :obj:`True`, the graph will contain
@@ -107,8 +116,9 @@ def radius_graph(x: torch.Tensor, r: float,
107116
flow (string, optional): The flow direction when using in combination
108117
with message passing (:obj:`"source_to_target"` or
109118
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
110-
n_threads (int): number of threads when the input is on CPU.
111-
(default: :obj:`1`)
119+
n_threads (int): number of threads when the input is on CPU. Note
120+
that this has no effect when batch_x or batch_y is not None, or
121+
x is on GPU. (default: :obj:`1`)
112122
113123
:rtype: :class:`LongTensor`
114124

0 commit comments

Comments
 (0)