Skip to content

Commit f0a4d21

Browse files
committed
none arguments
1 parent fa36a83 commit f0a4d21

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

aten/cpu/cluster.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
77

88

99
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
10-
if (!start.defined()) start = std::get<0>(pos.min(1));
11-
if (!end.defined()) end = std::get<0>(pos.max(1));
12-
1310
size = size.toType(pos.type());
1411
start = start.toType(pos.type());
1512
end = end.toType(pos.type());
@@ -30,7 +27,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
3027

3128

3229
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
33-
m.def("graclus", &graclus, "Graclus (CPU)", py::arg("row"), py::arg("col"), py::arg("weight"));
34-
m.def("grid", &grid, "Grid (CPU)", py::arg("pos"), py::arg("size"), py::arg("start"),
35-
py::arg("end"));
30+
m.def("graclus", &graclus, "Graclus (CPU)");
31+
m.def("grid", &grid, "Grid (CPU)");
3632
}

aten/cpu/cluster.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
import cluster_cpu
44

55

6-
def grid_cluster(pos, size, start, end):
6+
def grid_cluster(pos, size, start=None, end=None):
7+
start = pos.t().min(dim=1)[0] if start is None else start
8+
end = pos.t().max(dim=1)[0] if end is None else end
79
return cluster_cpu.grid(pos, size, start, end)
810

911

10-
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], dtype=torch.uint8)
12+
pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
1113
size = torch.tensor([2, 2])
1214
start = torch.tensor([0, 0])
1315
end = torch.tensor([7, 7])
1416
print('pos', pos.tolist())
1517
print('size', size.tolist())
16-
cluster = grid_cluster(pos, size, start, end)
18+
cluster = grid_cluster(pos, size)
1719
print('result', cluster.tolist(), cluster.dtype)

0 commit comments

Comments
 (0)