33from .utils import get_func , consecutive
44
55
6- def grid_cluster (position , size , batch = None , offset = None , fake_nodes = False ):
6+ def grid_cluster (position , size , batch = None , origin = None , fake_nodes = False ):
77 # Allow one-dimensional positions.
88 if position .dim () == 1 :
99 position = position .unsqueeze (- 1 )
@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
2121 position = torch .cat ([batch , position ], dim = - 1 )
2222 size = torch .cat ([size .new (1 ).fill_ (1 ), size ], dim = - 1 )
2323
24- # Translate to minimal positive positions if no offset is passed.
25- if offset is None :
24+ # Translate to minimal positive positions if no origin was passed.
25+ if origin is None :
2626 min = position .min (dim = - 2 , keepdim = True )[0 ]
2727 position = position - min
2828 else :
29- position = position + offset
29+ position = position + origin
3030 assert position .min () >= 0 , (
31- 'Passed offset resulting in unallowed negative positions' )
31+ 'Passed origin resulting in unallowed negative positions' )
3232
3333 # Compute cluster count for each dimension.
3434 max = position .max (dim = 0 )[0 ]
0 commit comments