Skip to content

Commit 6a2e1a0

Browse files
committed
rename
1 parent fb3d8a8 commit 6a2e1a0

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

test/test_grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_grid_cluster_cpu(tensor):
1414
assert output.tolist() == expected.tolist()
1515

1616
expected = torch.LongTensor([0, 1])
17-
output, _ = grid_cluster(position, size, offset=0)
17+
output, _ = grid_cluster(position, size, origin=0)
1818
assert output.tolist() == expected.tolist()
1919

2020
position = Tensor(tensor, [0, 17, 2, 8, 3])
@@ -63,7 +63,7 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
6363
assert output.cpu().tolist() == expected.tolist()
6464

6565
expected = torch.LongTensor([0, 1])
66-
output, _ = grid_cluster(position, size, offset=0)
66+
output, _ = grid_cluster(position, size, origin=0)
6767
assert output.cpu().tolist() == expected.tolist()
6868

6969
position = Tensor(tensor, [0, 17, 2, 8, 3]).cuda()

torch_cluster/functions/grid.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .utils import get_func, consecutive
44

55

6-
def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
6+
def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
77
# Allow one-dimensional positions.
88
if position.dim() == 1:
99
position = position.unsqueeze(-1)
@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
2121
position = torch.cat([batch, position], dim=-1)
2222
size = torch.cat([size.new(1).fill_(1), size], dim=-1)
2323

24-
# Translate to minimal positive positions if no offset is passed.
25-
if offset is None:
24+
# Translate to minimal positive positions if no origin was passed.
25+
if origin is None:
2626
min = position.min(dim=-2, keepdim=True)[0]
2727
position = position - min
2828
else:
29-
position = position + offset
29+
position = position + origin
3030
assert position.min() >= 0, (
31-
'Passed offset resulting in unallowed negative positions')
31+
'Passed origin resulting in unallowed negative positions')
3232

3333
# Compute cluster count for each dimension.
3434
max = position.max(dim=0)[0]

0 commit comments

Comments
 (0)