|
1 | 1 | #include <torch/torch.h> |
2 | 2 |
|
3 | | -inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, |
4 | | - at::Tensor col) { |
5 | | - auto mask = row != col; |
6 | | - row = row.masked_select(mask); |
7 | | - col = col.masked_select(mask); |
8 | | - return {row, col}; |
9 | | -} |
10 | | - |
11 | | -inline std::tuple<at::Tensor, at::Tensor> |
12 | | -randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) { |
13 | | - // Randomly reorder row and column indices. |
14 | | - auto perm = at::randperm(torch::CPU(at::kLong), row.size(0)); |
15 | | - row = row.index_select(0, perm); |
16 | | - col = col.index_select(0, perm); |
17 | | - |
18 | | - // Randomly swap row values. |
19 | | - auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes); |
20 | | - row = node_rid.index_select(0, row); |
21 | | - |
22 | | - // Sort row and column indices row-wise. |
23 | | - std::tie(row, perm) = row.sort(); |
24 | | - col = col.index_select(0, perm); |
25 | | - |
26 | | - // Revert row value swaps. |
27 | | - row = std::get<1>(node_rid.sort()).index_select(0, row); |
28 | | - |
29 | | - return {row, col}; |
30 | | -} |
31 | | - |
32 | | -inline at::Tensor degree(at::Tensor index, int64_t num_nodes) { |
33 | | - auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes}); |
34 | | - return zero.scatter_add_(0, index, at::ones_like(index)); |
35 | | -} |
36 | | - |
37 | | -at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { |
38 | | - std::tie(row, col) = remove_self_loops(row, col); |
39 | | - std::tie(row, col) = randperm(row, col, num_nodes); |
40 | | - |
41 | | - auto deg = degree(row, num_nodes); |
42 | | - auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1); |
43 | | - |
44 | | - auto *row_data = row.data<int64_t>(); |
45 | | - auto *col_data = col.data<int64_t>(); |
46 | | - auto *deg_data = deg.data<int64_t>(); |
47 | | - auto *cluster_data = cluster.data<int64_t>(); |
48 | | - |
49 | | - int64_t e_idx = 0, d_idx, r, c; |
50 | | - while (e_idx < row.size(0)) { |
51 | | - r = row_data[e_idx]; |
52 | | - if (cluster_data[r] < 0) { |
53 | | - cluster_data[r] = r; |
54 | | - for (d_idx = 0; d_idx < deg_data[r]; d_idx++) { |
55 | | - c = col_data[e_idx + d_idx]; |
56 | | - if (cluster_data[c] < 0) { |
57 | | - cluster_data[r] = std::min(r, c); |
58 | | - cluster_data[c] = std::min(r, c); |
59 | | - break; |
60 | | - } |
61 | | - } |
62 | | - } |
63 | | - e_idx += deg_data[r]; |
64 | | - } |
65 | | - |
66 | | - return cluster; |
67 | | -} |
68 | | - |
69 | | -at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, |
70 | | - at::Tensor end) { |
71 | | - size = size.toType(pos.type()); |
72 | | - start = start.toType(pos.type()); |
73 | | - end = end.toType(pos.type()); |
74 | | - |
75 | | - pos = pos - start.view({1, -1}); |
76 | | - auto num_voxels = ((end - start) / size).toType(at::kLong); |
77 | | - num_voxels = (num_voxels + 1).cumsum(0); |
78 | | - num_voxels -= num_voxels.data<int64_t>()[0]; |
79 | | - num_voxels.data<int64_t>()[0] = 1; |
80 | | - |
81 | | - auto cluster = pos / size.view({1, -1}); |
82 | | - cluster = cluster.toType(at::kLong); |
83 | | - cluster *= num_voxels.view({1, -1}); |
84 | | - cluster = cluster.sum(1); |
85 | | - |
86 | | - return cluster; |
87 | | -} |
| 3 | +#include "graclus.cpp" |
| 4 | +#include "grid.cpp" |
88 | 5 |
|
89 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
90 | 7 | m.def("graclus", &graclus, "Graclus (CPU)"); |
|
0 commit comments