Skip to content

Commit a315a06

Browse files
committed
first aten try
1 parent de54cce commit a315a06

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

aten/cpu/cluster.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <torch/torch.h>
2+
3+
4+
at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
5+
return row;
6+
}
7+
8+
9+
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
10+
if (!start.defined()) start = std::get<0>(pos.min(1));
11+
if (!end.defined()) end = std::get<0>(pos.max(1));
12+
13+
size = size.toType(pos.type());
14+
start = start.toType(pos.type());
15+
end = end.toType(pos.type());
16+
17+
pos = pos - start.view({ 1, -1 });
18+
auto num_voxels = ((end - start) / size).toType(at::kLong);
19+
num_voxels = (num_voxels + 1).cumsum(0);
20+
num_voxels = num_voxels - num_voxels[0];
21+
num_voxels[0] = 1;
22+
23+
auto cluster = pos / size.view({ 1, -1 });
24+
cluster = cluster.toType(at::kLong);
25+
cluster *= num_voxels.view({ 1, -1 });
26+
cluster = cluster.sum(1);
27+
28+
return cluster;
29+
}
30+
31+
32+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
33+
m.def("graclus", &graclus, "Graclus (CPU)", py::arg("row"), py::arg("col"), py::arg("weight"));
34+
m.def("grid", &grid, "Grid (CPU)", py::arg("pos"), py::arg("size"), py::arg("start"),
35+
py::arg("end"));
36+
}

aten/cpu/cluster.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
3+
import cluster_cpu
4+
5+
6+
def grid_cluster(pos, size, start, end):
7+
return cluster_cpu.grid(pos, size, start, end)
8+
9+
10+
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], dtype=torch.uint8)
11+
size = torch.tensor([2, 2])
12+
start = torch.tensor([0, 0])
13+
end = torch.tensor([7, 7])
14+
print('pos', pos.tolist())
15+
print('size', size.tolist())
16+
cluster = grid_cluster(pos, size, start, end)
17+
print('result', cluster.tolist(), cluster.dtype)

aten/cpu/setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from setuptools import setup
2+
from torch.utils.cpp_extension import BuildExtension, CppExtension
3+
4+
setup(
5+
name='cluster',
6+
ext_modules=[CppExtension('cluster_cpu', ['cluster.cpp'])],
7+
cmdclass={'build_ext': BuildExtension},
8+
)

0 commit comments

Comments
 (0)