Skip to content

Commit fbd14a9

Browse files
committed
setup for cuda and cpu
1 parent 4bc9a76 commit fbd14a9

File tree

13 files changed

+83
-84
lines changed

13 files changed

+83
-84
lines changed

aten/cluster.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
3+
import cluster_cpu
4+
import cluster_cuda
5+
6+
7+
def grid(pos, size, start=None, end=None):
8+
lib = cluster_cuda if pos.is_cuda else cluster_cpu
9+
start = pos.t().min(dim=1)[0] if start is None else start
10+
end = pos.t().max(dim=1)[0] if end is None else end
11+
return lib.grid(pos, size, start, end)
12+
13+
14+
def graclus(row, col, num_nodes):
15+
return cluster_cpu.graclus(row, col, num_nodes)
16+
17+
18+
device = torch.device('cuda')
19+
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], device=device)
20+
size = torch.tensor([2, 2], device=device)
21+
print('pos', pos.tolist())
22+
print('size', size.tolist())
23+
cluster = grid(pos, size)
24+
print('result', cluster.tolist(), cluster.dtype, cluster.device)
25+
26+
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
27+
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
28+
print(row)
29+
print(col)
30+
print('-----------------')
31+
cluster = graclus(row, col, 4)
32+
print(cluster)

aten/cpu/cluster.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

aten/cpu/graclus.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <torch/torch.h>
22

3-
#include "degree.cpp"
4-
#include "loop.cpp"
5-
#include "perm.cpp"
3+
#include "../include/degree.cpp"
4+
#include "../include/loop.cpp"
5+
#include "../include/perm.cpp"
66

77
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
88
std::tie(row, col) = remove_self_loops(row, col);

aten/cpu/setup.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

aten/cuda/cluster.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

aten/cuda/setup.py

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#ifndef DEGREE_CPP
2-
#define DEGREE_CPP
1+
#include "degree.h"
32

43
#include <torch/torch.h>
54

@@ -9,5 +8,3 @@ inline at::Tensor degree(at::Tensor index, int num_nodes,
98
auto one = at::full(zero.type(), {index.size(0)}, 1);
109
return zero.scatter_add_(0, index, one);
1110
}
12-
13-
#endif // DEGREE_CPP

aten/include/degree.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef DEGREE_INC
2+
#define DEGREE_INC
3+
4+
#include <torch/torch.h>
5+
6+
inline at::Tensor degree(at::Tensor index, int num_nodes,
7+
at::ScalarType scalar_type);
8+
9+
#endif // DEGREE_INC
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#ifndef LOOP_CPP
2-
#define LOOP_CPP
1+
#include "loop.h"
32

43
#include <torch/torch.h>
54

@@ -8,5 +7,3 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
87
auto mask = row != col;
98
return {row.masked_select(mask), col.masked_select(mask)};
109
}
11-
12-
#endif // LOOP_CPP

aten/include/loop.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef LOOP_INC
2+
#define LOOP_INC
3+
4+
#include <torch/torch.h>
5+
6+
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
7+
at::Tensor col);
8+
9+
#endif // LOOP_INC

0 commit comments

Comments
 (0)