Skip to content

Commit 4d72a05

Browse files
committed
randperm
1 parent 24a770f commit 4d72a05

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

aten/cpu/cluster.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
#include <torch/torch.h>
22

33

4-
at::Tensor degree(at::Tensor index, int64_t num_nodes) {
5-
auto one = at::ones_like(index);
6-
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
7-
return zero.scatter_add_(0, index, one);
8-
}
9-
10-
11-
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
12-
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
13-
auto deg = degree(row, num_nodes);
14-
4+
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col) {
155
/* at::Tensor perm; */
166
/* std::tie(row, perm) = row.sort(); */
177
/* col = col.index_select(0, perm); */
188

199
/* TODO: randperm */
2010
/* TODO: randperm_sort_row */
11+
return { row, col };
12+
}
13+
14+
15+
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
16+
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
17+
return zero.scatter_add_(0, index, at::ones_like(index));
18+
}
19+
20+
21+
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
22+
std::tie(row, col) = randperm(row, col);
23+
auto deg = degree(row, num_nodes);
24+
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
2125

2226
auto *row_data = row.data<int64_t>();
2327
auto *col_data = col.data<int64_t>();

0 commit comments

Comments
 (0)