Skip to content

Commit 878e119

Browse files
committed
new api
1 parent cba1cdb commit 878e119

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ from torch_cluster import graclus_cluster
5656

5757
row = torch.tensor([0, 1, 1, 2])
5858
col = torch.tensor([1, 0, 2, 1])
59-
weight = torch.tensor([1, 1, 1, 1]) # Optional edge weights.
59+
weight = torch.Tensor([1, 1, 1, 1]) # Optional edge weights.
6060

6161
cluster = graclus_cluster(row, col, weight)
6262
```
@@ -74,8 +74,8 @@ A clustering algorithm, which overlays a regular grid of user-defined size over
7474
import torch
7575
from torch_cluster import grid_cluster
7676

77-
pos = torch.tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
78-
size = torch.tensor([5, 5])
77+
pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
78+
size = torch.Tensor([5, 5])
7979

8080
cluster = grid_cluster(pos, size)
8181
```

torch_cluster/grid.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from .utils.ffi import grid
1+
import torch
2+
import grid_cpu
3+
4+
if torch.cuda.is_available():
5+
import grid_cuda
26

37

48
def grid_cluster(pos, size, start=None, end=None):
@@ -8,9 +12,9 @@ def grid_cluster(pos, size, start=None, end=None):
812
Args:
913
pos (Tensor): D-dimensional position of points.
1014
size (Tensor): Size of a voxel in each dimension.
11-
start (Tensor or int, optional): Start position of the grid (in each
15+
start (Tensor, optional): Start position of the grid (in each
1216
dimension). (default: :obj:`None`)
13-
end (Tensor or int, optional): End position of the grid (in each
17+
end (Tensor, optional): End position of the grid (in each
1418
dimension). (default: :obj:`None`)
1519
1620
Examples::
@@ -21,18 +25,12 @@ def grid_cluster(pos, size, start=None, end=None):
2125
"""
2226

2327
pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
24-
25-
assert pos.size(1) == size.size(0), (
26-
'Last dimension of position tensor must have same size as size tensor')
27-
2828
start = pos.t().min(dim=1)[0] if start is None else start
2929
end = pos.t().max(dim=1)[0] if end is None else end
30-
pos, end = pos - start, end - start
31-
32-
size = size.type_as(pos)
33-
count = (end / size).long() + 1
3430

35-
cluster = count.new(pos.size(0))
36-
grid(cluster, pos, size, count)
31+
if pos.is_cuda:
32+
cluster = grid_cuda.grid(pos, size, start, end)
33+
else:
34+
cluster = grid_cpu.grid(pos, size, start, end)
3735

3836
return cluster

0 commit comments

Comments
 (0)