1- from .utils .ffi import grid
1+ import torch
2+ import grid_cpu
3+
4+ if torch .cuda .is_available ():
5+ import grid_cuda
26
37
48def grid_cluster (pos , size , start = None , end = None ):
@@ -8,9 +12,9 @@ def grid_cluster(pos, size, start=None, end=None):
812 Args:
913 pos (Tensor): D-dimensional position of points.
1014 size (Tensor): Size of a voxel in each dimension.
11- start (Tensor or int , optional): Start position of the grid (in each
15+ start (Tensor, optional): Start position of the grid (in each
1216 dimension). (default: :obj:`None`)
13- end (Tensor or int , optional): End position of the grid (in each
17+ end (Tensor, optional): End position of the grid (in each
1418 dimension). (default: :obj:`None`)
1519
1620 Examples::
@@ -21,18 +25,12 @@ def grid_cluster(pos, size, start=None, end=None):
2125 """
2226
2327 pos = pos .unsqueeze (- 1 ) if pos .dim () == 1 else pos
24-
25- assert pos .size (1 ) == size .size (0 ), (
26- 'Last dimension of position tensor must have same size as size tensor' )
27-
2828 start = pos .t ().min (dim = 1 )[0 ] if start is None else start
2929 end = pos .t ().max (dim = 1 )[0 ] if end is None else end
30- pos , end = pos - start , end - start
31-
32- size = size .type_as (pos )
33- count = (end / size ).long () + 1
3430
35- cluster = count .new (pos .size (0 ))
36- grid (cluster , pos , size , count )
31+ if pos .is_cuda :
32+ cluster = grid_cuda .grid (pos , size , start , end )
33+ else :
34+ cluster = grid_cpu .grid (pos , size , start , end )
3735
3836 return cluster
0 commit comments