Skip to content

Commit 24a770f

Browse files
committed
graclus cpu implementation
1 parent f0a4d21 commit 24a770f

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

aten/cpu/cluster.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,47 @@
11
#include <torch/torch.h>
22

33

4-
at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
5-
return row;
4+
at::Tensor degree(at::Tensor index, int64_t num_nodes) {
5+
auto one = at::ones_like(index);
6+
auto zero = at::zeros(torch::CPU(at::kLong), { num_nodes });
7+
return zero.scatter_add_(0, index, one);
8+
}
9+
10+
11+
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
12+
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
13+
auto deg = degree(row, num_nodes);
14+
15+
/* at::Tensor perm; */
16+
/* std::tie(row, perm) = row.sort(); */
17+
/* col = col.index_select(0, perm); */
18+
19+
/* TODO: randperm */
20+
/* TODO: randperm_sort_row */
21+
22+
auto *row_data = row.data<int64_t>();
23+
auto *col_data = col.data<int64_t>();
24+
auto *deg_data = deg.data<int64_t>();
25+
auto *cluster_data = cluster.data<int64_t>();
26+
27+
int64_t n_idx = 0, e_idx = 0, d_idx, r, c;
28+
while (e_idx < row.size(0)) {
29+
r = row_data[e_idx];
30+
if (cluster_data[r] < 0) {
31+
cluster_data[r] = r;
32+
for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
33+
c = col_data[e_idx + d_idx];
34+
if (cluster_data[c] < 0) {
35+
cluster_data[c] = r;
36+
break;
37+
}
38+
}
39+
}
40+
e_idx += deg_data[n_idx];
41+
n_idx++;
42+
}
43+
44+
return cluster;
645
}
746

847

@@ -14,8 +53,8 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
1453
pos = pos - start.view({ 1, -1 });
1554
auto num_voxels = ((end - start) / size).toType(at::kLong);
1655
num_voxels = (num_voxels + 1).cumsum(0);
17-
num_voxels = num_voxels - num_voxels[0];
18-
num_voxels[0] = 1;
56+
num_voxels -= num_voxels.data<int64_t>()[0];
57+
num_voxels.data<int64_t>()[0] = 1;
1958

2059
auto cluster = pos / size.view({ 1, -1 });
2160
cluster = cluster.toType(at::kLong);

aten/cpu/cluster.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,21 @@ def grid_cluster(pos, size, start=None, end=None):
99
return cluster_cpu.grid(pos, size, start, end)
1010

1111

12-
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
13-
size = torch.tensor([2, 2])
14-
start = torch.tensor([0, 0])
15-
end = torch.tensor([7, 7])
16-
print('pos', pos.tolist())
17-
print('size', size.tolist())
18-
cluster = grid_cluster(pos, size)
19-
print('result', cluster.tolist(), cluster.dtype)
12+
def graclus_cluster(row, col, num_nodes):
13+
return cluster_cpu.graclus(row, col, num_nodes)
14+
15+
16+
# pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
17+
# size = torch.tensor([2, 2])
18+
# start = torch.tensor([0, 0])
19+
# end = torch.tensor([7, 7])
20+
# print('pos', pos.tolist())
21+
# print('size', size.tolist())
22+
# cluster = grid_cluster(pos, size)
23+
# print('result', cluster.tolist(), cluster.dtype)
24+
25+
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
26+
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
27+
print(row)
28+
29+
print(graclus_cluster(row, col, 4))

0 commit comments

Comments
 (0)