Skip to content

Commit 642d548

Browse files
committed
new try
1 parent fbd14a9 commit 642d548

File tree

13 files changed

+58
-54
lines changed

13 files changed

+58
-54
lines changed

aten/cluster.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ def grid(pos, size, start=None, end=None):
1212

1313

1414
def graclus(row, col, num_nodes):
15-
return cluster_cpu.graclus(row, col, num_nodes)
15+
lib = cluster_cuda if pos.is_cuda else cluster_cpu
16+
return lib.graclus(row, col, num_nodes)
1617

1718

1819
device = torch.device('cuda')
@@ -23,10 +24,11 @@ def graclus(row, col, num_nodes):
2324
cluster = grid(pos, size)
2425
print('result', cluster.tolist(), cluster.dtype, cluster.device)
2526

26-
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
27-
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
28-
print(row)
29-
print(col)
3027
print('-----------------')
28+
29+
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3], device=device)
30+
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2], device=device)
31+
print('row', row.tolist())
32+
print('col', col.tolist())
3133
cluster = graclus(row, col, 4)
32-
print(cluster)
34+
print('result', cluster.tolist(), cluster.dtype, cluster.device)

aten/cpu/graclus.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
88
std::tie(row, col) = remove_self_loops(row, col);
99
std::tie(row, col) = randperm(row, col, num_nodes);
10+
auto deg = degree(row, num_nodes, row.type().scalarType());
1011

1112
auto cluster = at::full(row.type(), {num_nodes}, -1);
12-
auto deg = degree(row, num_nodes, row.type().scalarType());
1313

1414
auto *row_data = row.data<int64_t>();
1515
auto *col_data = col.data<int64_t>();

aten/cuda/cluster.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,11 @@
11
#include <torch/torch.h>
22

3-
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
4-
at::Tensor end);
5-
63
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
74

8-
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
9-
at::Tensor end) {
10-
CHECK_CUDA(pos);
11-
CHECK_CUDA(size);
12-
CHECK_CUDA(start);
13-
CHECK_CUDA(end);
14-
15-
return grid_cuda(pos, size, start, end);
16-
}
5+
#include "graclus.cpp"
6+
#include "grid.cpp"
177

188
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
9+
m.def("graclus", &graclus, "Graclus (CUDA)");
1910
m.def("grid", &grid, "Grid (CUDA)");
2011
}

aten/cuda/graclus.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <torch/torch.h>
2+
3+
#include "../include/degree.cpp"
4+
#include "../include/loop.cpp"
5+
6+
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
7+
CHECK_CUDA(row);
8+
CHECK_CUDA(col);
9+
10+
std::tie(row, col) = remove_self_loops(row, col);
11+
auto deg = degree(row, num_nodes, row.type().scalarType());
12+
13+
return deg;
14+
}

aten/cuda/grid.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <torch/torch.h>
2+
3+
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
4+
at::Tensor end);
5+
6+
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
7+
at::Tensor end) {
8+
CHECK_CUDA(pos);
9+
CHECK_CUDA(size);
10+
CHECK_CUDA(start);
11+
CHECK_CUDA(end);
12+
13+
return grid_cuda(pos, size, start, end);
14+
}

aten/include/degree.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "degree.h"
1+
#ifndef DEGREE_INC
2+
#define DEGREE_INC
23

34
#include <torch/torch.h>
45

@@ -8,3 +9,5 @@ inline at::Tensor degree(at::Tensor index, int num_nodes,
89
auto one = at::full(zero.type(), {index.size(0)}, 1);
910
return zero.scatter_add_(0, index, one);
1011
}
12+
13+
#endif // DEGREE_INC

aten/include/degree.h

Lines changed: 0 additions & 9 deletions
This file was deleted.

aten/include/loop.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "loop.h"
1+
#ifndef LOOP_INC
2+
#define LOOP_INC
23

34
#include <torch/torch.h>
45

@@ -7,3 +8,5 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
78
auto mask = row != col;
89
return {row.masked_select(mask), col.masked_select(mask)};
910
}
11+
12+
#endif // LOOP_INC

aten/include/loop.h

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)