Commit 92105bf

bugfixes
1 parent cb0e5f6 commit 92105bf

2 files changed: 21 additions & 1 deletion


aten/cuda/cluster.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import torch
+import cluster_cuda
+
+dtype = torch.float
+device = torch.device('cuda')
+
+
+def grid_cluster(pos, size, start=None, end=None):
+    start = pos.t().min(dim=1)[0] if start is None else start
+    end = pos.t().max(dim=1)[0] if end is None else end
+    return cluster_cuda.grid(pos, size, start, end)
+
+
+pos = torch.tensor(
+    [[1, 1], [3, 3], [5, 5], [7, 7]], dtype=dtype, device=device)
+size = torch.tensor([2, 2, 1, 1, 4, 2, 1], dtype=dtype, device=device)
+# print('pos', pos.tolist())
+# print('size', size.tolist())
+cluster = grid_cluster(pos, size)
+print('result', cluster.tolist(), cluster.type())
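
For context, here is a minimal CPU-only sketch of what the cluster_cuda.grid kernel is assumed to compute: each point is mapped to the flat index of the voxel of edge length size that contains it inside the bounding box spanned by start and end. The helper name grid_cluster_reference and the row-major flattening order are illustrative assumptions, not part of this commit.

import torch


def grid_cluster_reference(pos, size, start=None, end=None):
    # Bounding box of the point cloud (one value per dimension), matching
    # the defaults used by grid_cluster above.
    start = pos.min(dim=0)[0] if start is None else start
    end = pos.max(dim=0)[0] if end is None else end
    # Integer voxel coordinate of every point along every dimension.
    coord = ((pos - start) / size).long()
    # Voxel count per dimension and row-major strides used to flatten the
    # multi-dimensional voxel coordinate into a single cluster index
    # (assumed layout, for illustration only).
    num = ((end - start) / size).long() + 1
    stride = torch.cat([num.flip([0]).cumprod(0).flip([0])[1:], num.new_ones(1)])
    return (coord * stride).sum(dim=1)


pos = torch.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.]])
size = torch.tensor([2., 2.])
print(grid_cluster_reference(pos, size))  # tensor([0, 5, 10, 15]) on a 4x4 voxel grid

Points that land in the same voxel receive the same cluster index, which is the behavior the test script above exercises on the GPU.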

aten/cuda/cluster_kernel.cu

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
       cluster.data<int64_t>(),
       at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
       size.toType(pos.type()).data<scalar_t>(),
-      start..toType(pos.type()).data<scalar_t>(),
+      start.toType(pos.type()).data<scalar_t>(),
       end.toType(pos.type()).data<scalar_t>(), num_nodes);
   });
