
Commit 4bc9a76

Merge commit with 2 parents: 92105bf + 0587704

File tree

9 files changed: +127, -98 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ dist/
 .eggs/
 *.egg-info/
 .coverage
+*.so

aten/cpu/cluster.cpp

Lines changed: 2 additions & 85 deletions
@@ -1,90 +1,7 @@
 #include <torch/torch.h>
 
-inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
-                                                            at::Tensor col) {
-  auto mask = row != col;
-  row = row.masked_select(mask);
-  col = col.masked_select(mask);
-  return {row, col};
-}
-
-inline std::tuple<at::Tensor, at::Tensor>
-randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
-  // Randomly reorder row and column indices.
-  auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
-  row = row.index_select(0, perm);
-  col = col.index_select(0, perm);
-
-  // Randomly swap row values.
-  auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
-  row = node_rid.index_select(0, row);
-
-  // Sort row and column indices row-wise.
-  std::tie(row, perm) = row.sort();
-  col = col.index_select(0, perm);
-
-  // Revert row value swaps.
-  row = std::get<1>(node_rid.sort()).index_select(0, row);
-
-  return {row, col};
-}
-
-inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
-  auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
-  return zero.scatter_add_(0, index, at::ones_like(index));
-}
-
-at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
-  std::tie(row, col) = remove_self_loops(row, col);
-  std::tie(row, col) = randperm(row, col, num_nodes);
-
-  auto deg = degree(row, num_nodes);
-  auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
-
-  auto *row_data = row.data<int64_t>();
-  auto *col_data = col.data<int64_t>();
-  auto *deg_data = deg.data<int64_t>();
-  auto *cluster_data = cluster.data<int64_t>();
-
-  int64_t e_idx = 0, d_idx, r, c;
-  while (e_idx < row.size(0)) {
-    r = row_data[e_idx];
-    if (cluster_data[r] < 0) {
-      cluster_data[r] = r;
-      for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
-        c = col_data[e_idx + d_idx];
-        if (cluster_data[c] < 0) {
-          cluster_data[r] = std::min(r, c);
-          cluster_data[c] = std::min(r, c);
-          break;
-        }
-      }
-    }
-    e_idx += deg_data[r];
-  }
-
-  return cluster;
-}
-
-at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
-                at::Tensor end) {
-  size = size.toType(pos.type());
-  start = start.toType(pos.type());
-  end = end.toType(pos.type());
-
-  pos = pos - start.view({1, -1});
-  auto num_voxels = ((end - start) / size).toType(at::kLong);
-  num_voxels = (num_voxels + 1).cumsum(0);
-  num_voxels -= num_voxels.data<int64_t>()[0];
-  num_voxels.data<int64_t>()[0] = 1;
-
-  auto cluster = pos / size.view({1, -1});
-  cluster = cluster.toType(at::kLong);
-  cluster *= num_voxels.view({1, -1});
-  cluster = cluster.sum(1);
-
-  return cluster;
-}
+#include "graclus.cpp"
+#include "grid.cpp"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("graclus", &graclus, "Graclus (CPU)");

aten/cpu/cluster.py

Lines changed: 8 additions & 8 deletions
@@ -13,14 +13,14 @@ def graclus_cluster(row, col, num_nodes):
     return cluster_cpu.graclus(row, col, num_nodes)
 
 
-# pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
-# size = torch.tensor([2, 2])
-# start = torch.tensor([0, 0])
-# end = torch.tensor([7, 7])
-# print('pos', pos.tolist())
-# print('size', size.tolist())
-# cluster = grid_cluster(pos, size)
-# print('result', cluster.tolist(), cluster.dtype)
+pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
+size = torch.tensor([2, 2])
+start = torch.tensor([0, 0])
+end = torch.tensor([7, 7])
+print('pos', pos.tolist())
+print('size', size.tolist())
+cluster = grid_cluster(pos, size)
+print('result', cluster.tolist(), cluster.dtype)
 
 row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
 col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])

aten/cpu/degree.cpp

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+#ifndef DEGREE_CPP
+#define DEGREE_CPP
+
+#include <torch/torch.h>
+
+inline at::Tensor degree(at::Tensor index, int num_nodes,
+                         at::ScalarType scalar_type) {
+  auto zero = at::full(index.type().toScalarType(scalar_type), {num_nodes}, 0);
+  auto one = at::full(zero.type(), {index.size(0)}, 1);
+  return zero.scatter_add_(0, index, one);
+}
+
+#endif // DEGREE_CPP

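Note: degree counts, for each node, how many edges start there by scatter-adding a one per edge into a zero-initialized tensor of length num_nodes. A minimal Python sketch of the same computation (plain PyTorch, not the extension's API), reusing the example row tensor from aten/cpu/cluster.py:

import torch

row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
num_nodes = 4

# One entry per node; add 1 for every edge whose source is that node.
deg = torch.zeros(num_nodes, dtype=torch.long)
deg.scatter_add_(0, row, torch.ones_like(row))
print(deg.tolist())  # [2, 3, 3, 2]
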
aten/cpu/graclus.cpp

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#include <torch/torch.h>
+
+#include "degree.cpp"
+#include "loop.cpp"
+#include "perm.cpp"
+
+at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
+  std::tie(row, col) = remove_self_loops(row, col);
+  std::tie(row, col) = randperm(row, col, num_nodes);
+
+  auto cluster = at::full(row.type(), {num_nodes}, -1);
+  auto deg = degree(row, num_nodes, row.type().scalarType());
+
+  auto *row_data = row.data<int64_t>();
+  auto *col_data = col.data<int64_t>();
+  auto *deg_data = deg.data<int64_t>();
+  auto *cluster_data = cluster.data<int64_t>();
+
+  int64_t e_idx = 0, d_idx, r, c;
+  while (e_idx < row.size(0)) {
+    r = row_data[e_idx];
+    if (cluster_data[r] < 0) {
+      cluster_data[r] = r;
+      for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
+        c = col_data[e_idx + d_idx];
+        if (cluster_data[c] < 0) {
+          cluster_data[r] = std::min(r, c);
+          cluster_data[c] = std::min(r, c);
+          break;
+        }
+      }
+    }
+    e_idx += deg_data[r];
+  }
+
+  return cluster;
+}

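Note: graclus performs a greedy matching over the shuffled edge list. Edges are grouped contiguously per source node (deg_data[r] edges starting at offset e_idx), each still-unmatched node is merged with its first unmatched neighbor, and both take the smaller node index as their cluster id; nodes left unmatched become singleton clusters. A rough Python sketch of just this loop (an illustration, not the extension's API; it skips the self-loop removal and random permutation steps and reuses the example edge list from aten/cpu/cluster.py):

import torch

row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
num_nodes = 4

deg = torch.zeros(num_nodes, dtype=torch.long)
deg.scatter_add_(0, row, torch.ones_like(row))
cluster = torch.full((num_nodes,), -1, dtype=torch.long)

e_idx = 0
while e_idx < row.size(0):
    r = int(row[e_idx])
    if cluster[r] < 0:                    # r not matched yet
        cluster[r] = r                    # default: singleton cluster
        for d_idx in range(int(deg[r])):
            c = int(col[e_idx + d_idx])
            if cluster[c] < 0:            # first unmatched neighbor wins
                cluster[r] = min(r, c)
                cluster[c] = min(r, c)
                break
    e_idx += int(deg[r])                  # jump to the next node's edges

print(cluster.tolist())  # [0, 0, 2, 2] for this (unshuffled) input
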
aten/cpu/grid.cpp

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#include <torch/torch.h>
+
+at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
+                at::Tensor end) {
+  size = size.toType(pos.type());
+  start = start.toType(pos.type());
+  end = end.toType(pos.type());
+
+  pos = pos - start.view({1, -1});
+  auto num_voxels = ((end - start) / size).toType(at::kLong);
+  num_voxels = (num_voxels + 1).cumsum(0);
+  num_voxels -= num_voxels.data<int64_t>()[0];
+  num_voxels.data<int64_t>()[0] = 1;
+
+  auto cluster = pos / size.view({1, -1});
+  cluster = cluster.toType(at::kLong);
+  cluster *= num_voxels.view({1, -1});
+  cluster = cluster.sum(1);
+
+  return cluster;
+}

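Note: grid assigns each point to a fixed-size voxel. Per dimension it computes an integer voxel coordinate from (pos - start) / size, then flattens the coordinates into a single cluster id using the per-dimension multipliers built in num_voxels (here [1, 4]). A minimal Python sketch of that computation (an illustration, not the extension's API; it uses the example values from aten/cpu/cluster.py, cast to float so the division is explicit):

import torch

pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], dtype=torch.float)
size = torch.tensor([2, 2], dtype=torch.float)
start = torch.tensor([0, 0], dtype=torch.float)
end = torch.tensor([7, 7], dtype=torch.float)

# Per-dimension multipliers that flatten 2-D voxel coordinates into one id.
num_voxels = ((end - start) / size).long()       # [3, 3]
num_voxels = (num_voxels + 1).cumsum(0)          # [4, 8]
num_voxels -= num_voxels[0].item()               # [0, 4]
num_voxels[0] = 1                                # [1, 4]

# Integer voxel coordinate per dimension, then flatten and sum.
coord = ((pos - start.view(1, -1)) / size.view(1, -1)).long()
cluster = (coord * num_voxels.view(1, -1)).sum(dim=1)
print(cluster.tolist())  # [0, 5, 10, 15]
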
aten/cpu/loop.cpp

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+#ifndef LOOP_CPP
+#define LOOP_CPP
+
+#include <torch/torch.h>
+
+inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
+                                                            at::Tensor col) {
+  auto mask = row != col;
+  return {row.masked_select(mask), col.masked_select(mask)};
+}
+
+#endif // LOOP_CPP

aten/cpu/perm.cpp

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#ifndef PERM_CPP
+#define PERM_CPP
+
+#include <torch/torch.h>
+
+inline std::tuple<at::Tensor, at::Tensor>
+randperm(at::Tensor row, at::Tensor col, int num_nodes) {
+  // Randomly reorder row and column indices.
+  auto perm = at::randperm(row.type(), row.size(0));
+  row = row.index_select(0, perm);
+  col = col.index_select(0, perm);
+
+  // Randomly swap row values.
+  auto node_rid = at::randperm(row.type(), num_nodes);
+  row = node_rid.index_select(0, row);
+
+  // Sort row and column indices row-wise.
+  std::tie(row, perm) = row.sort();
+  col = col.index_select(0, perm);
+
+  // Revert row value swaps.
+  row = std::get<1>(node_rid.sort()).index_select(0, row);
+
+  return {row, col};
+}
+
+#endif // PERM_CPP

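Note: randperm shuffles the edge list so that graclus visits nodes in a random order while keeping each node's edges contiguous: the edges are shuffled, node ids are relabelled by a random permutation, the edges are sorted by the relabelled row, and the relabelling is then undone. The revert works because sorting a permutation of 0..n-1 yields indices that form its inverse. A small Python sketch of the same steps (an illustration, not the extension's API; reuses the example edge list from aten/cpu/cluster.py):

import torch

row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
num_nodes = 4

# Shuffle the edge order.
perm = torch.randperm(row.size(0))
row, col = row[perm], col[perm]

# Relabel nodes with a random permutation and sort edges by the new labels,
# so each node's edges stay contiguous but node groups appear in random order.
node_rid = torch.randperm(num_nodes)
row = node_rid[row]
row, perm = row.sort()
col = col[perm]

# Map the labels back: argsort(node_rid) is the inverse permutation.
row = node_rid.sort()[1][row]
print(row.tolist(), col.tolist())
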
aten/cuda/cluster_kernel.cu

Lines changed: 6 additions & 5 deletions
@@ -5,10 +5,11 @@
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
 
 template <typename scalar_t>
-__global__ void grid_cuda_kernel(
-    int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
-    const scalar_t *__restrict__ size, const scalar_t *__restrict__ start,
-    const scalar_t *__restrict__ end, const size_t num_nodes) {
+__global__ void
+grid_cuda_kernel(int64_t *cluster,
+                 at::cuda::detail::TensorInfo<scalar_t, int> pos,
+                 scalar_t *__restrict__ size, scalar_t *__restrict__ start,
+                 scalar_t *__restrict__ end, size_t num_nodes) {
   const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;
   for (ptrdiff_t i = index; i < num_nodes; i += stride) {
@@ -25,7 +26,7 @@ __global__ void grid_cuda_kernel(
 
 at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                      at::Tensor end) {
-  const auto num_nodes = pos.size(0);
+  auto num_nodes = pos.size(0);
   auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
 
   AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
