|
1 | 1 | #include <torch/torch.h> |
2 | 2 |
|
3 | 3 |
|
4 | | -at::Tensor degree(at::Tensor index, int64_t num_nodes) { |
5 | | - auto one = at::ones_like(index); |
6 | | - auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes }); |
7 | | - return zero.scatter_add_(0, index, one); |
8 | | -} |
9 | | - |
10 | | - |
11 | | -at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { |
12 | | - auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1); |
13 | | - auto deg = degree(row, num_nodes); |
14 | | - |
| 4 | +inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col) { |
15 | 5 | /* at::Tensor perm; */ |
16 | 6 | /* std::tie(row, perm) = row.sort(); */ |
17 | 7 | /* col = col.index_select(0, perm); */ |
18 | 8 |
|
19 | 9 | /* TODO: randperm */ |
20 | 10 | /* TODO: randperm_sort_row */ |
| 11 | + return { row, col }; |
| 12 | +} |
| 13 | + |
| 14 | + |
| 15 | +inline at::Tensor degree(at::Tensor index, int64_t num_nodes) { |
| 16 | + auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes }); |
| 17 | + return zero.scatter_add_(0, index, at::ones_like(index)); |
| 18 | +} |
| 19 | + |
| 20 | + |
| 21 | +at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { |
| 22 | + std::tie(row, col) = randperm(row, col); |
| 23 | + auto deg = degree(row, num_nodes); |
| 24 | + auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1); |
21 | 25 |
|
22 | 26 | auto *row_data = row.data<int64_t>(); |
23 | 27 | auto *col_data = col.data<int64_t>(); |
|
0 commit comments