Skip to content

Commit 16b976c

Browse files
committed
new cuda layout
1 parent 9a9d973 commit 16b976c

File tree

10 files changed

+113
-98
lines changed

10 files changed

+113
-98
lines changed

aten/cuda/cluster.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
#include <torch/torch.h>
22

3-
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
3+
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
4+
at::Tensor end);
45

5-
#include "graclus.cpp"
6-
#include "grid.cpp"
6+
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes);
7+
8+
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
9+
int num_nodes);
710

811
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
9-
m.def("graclus", &graclus, "Graclus (CUDA)");
1012
m.def("grid", &grid, "Grid (CUDA)");
13+
m.def("graclus", &graclus, "Graclus (CUDA)");
14+
m.def("weighted_graclus", &weighted_graclus, "Weightes Graclus (CUDA)");
1115
}

aten/cuda/color.cuh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
#include "common.cuh"
6+
7+
#define BLUE_PROB 0.53406
8+
9+
__global__ void color_kernel(int64_t *cluster, size_t num_nodes) {
10+
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
11+
const size_t stride = blockDim.x * gridDim.x;
12+
for (ptrdiff_t i = index; i < num_nodes; i += stride) {
13+
}
14+
}
15+
16+
inline bool color(at::Tensor cluster) {
17+
color_kernel<scalar_t><<<BLOCKS(cluster.size(0)), THREADS>>>(
18+
cluster.data<int64_t>(), cluster.size(0));
19+
20+
return true;
21+
}

aten/cuda/common.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
#define THREADS 1024
6+
#define BLOCKS(N) (N + THREADS - 1) / THREADS
7+
8+
inline at::Tensor degree(at::Tensor index, int num_nodes) {
9+
auto zero = at::zeros(index.type(), {num_nodes});
10+
auto one = at::ones(index.type(), {index.size(0)});
11+
return zero.scatter_add_(0, index, one);
12+
}

aten/cuda/graclus.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

aten/cuda/graclus_kernel.cu

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include <ATen/ATen.h>
2+
3+
#include "color.cuh"
4+
#include "common.cuh"
5+
6+
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
7+
// Remove self-loops.
8+
auto mask = row != col;
9+
row = row.masked_select(mask);
10+
col.masked_select(mask);
11+
12+
// Sort by row index.
13+
at::Tensor perm;
14+
std::tie(row, perm) = row.sort();
15+
col = col.index_select(0, perm);
16+
17+
// Generate helper vectors.
18+
auto cluster = at::full(row.type(), {num_nodes}, -1);
19+
auto prop = at::full(row.type(), {num_nodes}, -1);
20+
auto deg = degree(row, num_nodes);
21+
auto cum_deg = deg.cumsum(0);
22+
23+
color(cluster);
24+
25+
/* while (!color(cluster)) { */
26+
/* propose(cluster, prop, row, col, weight, deg, cum_deg); */
27+
/* response(cluster, prop, row, col, weight, deg, cum_deg); */
28+
/* } */
29+
30+
return cluster;
31+
}
32+
33+
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
34+
int num_nodes) {
35+
// Remove self-loops.
36+
auto mask = row != col;
37+
row = row.masked_select(mask);
38+
col = col.masked_select(mask);
39+
weight = weight.masked_select(mask);
40+
41+
// Sort by row index.
42+
at::Tensor perm;
43+
std::tie(row, perm) = row.sort();
44+
col = col.index_select(0, perm);
45+
weight = weight.index_select(0, perm);
46+
47+
// Generate helper vectors.
48+
auto cluster = at::full(row.type(), {num_nodes}, -1);
49+
auto prop = at::full(row.type(), {num_nodes}, -1);
50+
auto deg = degree(row, num_nodes);
51+
auto cum_deg = deg.cumsum(0);
52+
53+
color(cluster);
54+
55+
/* while (!color(cluster)) { */
56+
/* weighted_propose(cluster, prop, row, col, weight, deg, cum_deg); */
57+
/* weighted_response(cluster, prop, row, col, weight, deg, cum_deg); */
58+
/* } */
59+
60+
return cluster;
61+
}

aten/cuda/grid.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

aten/cuda/grid_kernel.cu

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,38 @@
11
#include <ATen/ATen.h>
22
#include <ATen/cuda/detail/IndexUtils.cuh>
33

4-
#define THREADS 1024
5-
#define BLOCKS(N) (N + THREADS - 1) / THREADS
4+
#include "common.cuh"
65

76
template <typename scalar_t>
87
__global__ void
9-
grid_cuda_kernel(int64_t *cluster,
10-
at::cuda::detail::TensorInfo<scalar_t, int> pos,
11-
scalar_t *__restrict__ size, scalar_t *__restrict__ start,
12-
scalar_t *__restrict__ end, size_t num_nodes) {
8+
grid_kernel(int64_t *cluster, at::cuda::detail::TensorInfo<scalar_t, int> pos,
9+
scalar_t *__restrict__ size, scalar_t *__restrict__ start,
10+
scalar_t *__restrict__ end, size_t num_nodes) {
1311
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
1412
const size_t stride = blockDim.x * gridDim.x;
1513
for (ptrdiff_t i = index; i < num_nodes; i += stride) {
1614
int64_t c = 0, k = 1;
1715
scalar_t tmp;
1816
for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
19-
tmp = (pos.data[i * pos.strides[0] + d * pos.strides[1]]) - start[d];
17+
tmp = pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
2018
c += (int64_t)(tmp / size[d]) * k;
2119
k += (int64_t)((end[d] - start[d]) / size[d]);
2220
}
2321
cluster[i] = c;
2422
}
2523
}
2624

27-
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
28-
at::Tensor end) {
29-
auto num_nodes = pos.size(0);
30-
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
25+
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
26+
at::Tensor end) {
27+
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {pos.size(0)});
3128

32-
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
33-
grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
29+
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {
30+
grid_kernel<scalar_t><<<BLOCKS(pos.size(0)), THREADS>>>(
3431
cluster.data<int64_t>(),
3532
at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
3633
size.toType(pos.type()).data<scalar_t>(),
3734
start.toType(pos.type()).data<scalar_t>(),
38-
end.toType(pos.type()).data<scalar_t>(), num_nodes);
35+
end.toType(pos.type()).data<scalar_t>(), pos.size(0));
3936
});
4037

4138
return cluster;

aten/include/degree.cpp

Lines changed: 0 additions & 13 deletions
This file was deleted.

aten/include/loop.cpp

Lines changed: 0 additions & 12 deletions
This file was deleted.

aten/include/perm.cpp

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)