Skip to content

Commit 801e5f1

Browse files
committed
added position translate
1 parent 96c3cd4 commit 801e5f1

File tree

4 files changed

+71
-32
lines changed

4 files changed

+71
-32
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '0.1.1'
5+
__version__ = '0.2.0'
66
url = 'https://github.com/rusty1s/pytorch_cluster'
77

88
install_requires = ['cffi', 'torch-unique']

test/test_grid.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,62 +7,98 @@
77

88
@pytest.mark.parametrize('tensor', tensors)
99
def test_grid_cluster_cpu(tensor):
10-
position = Tensor(tensor, [0, 9, 2, 8, 3])
10+
position = Tensor(tensor, [2, 6])
1111
size = torch.LongTensor([5])
12-
expected = torch.LongTensor([0, 1, 0, 1, 0])
13-
output = grid_cluster(position, size)
12+
expected = torch.LongTensor([0, 0])
13+
output, _ = grid_cluster(position, size)
14+
assert output.tolist() == expected.tolist()
15+
16+
expected = torch.LongTensor([0, 1])
17+
output, _ = grid_cluster(position, size, offset=0)
18+
assert output.tolist() == expected.tolist()
19+
20+
position = Tensor(tensor, [0, 17, 2, 8, 3])
21+
expected = torch.LongTensor([0, 2, 0, 1, 0])
22+
output, _ = grid_cluster(position, size)
23+
assert output.tolist() == expected.tolist()
24+
25+
output, _ = grid_cluster(position, size, fake_nodes=True)
26+
expected = torch.LongTensor([0, 3, 0, 1, 0])
1427
assert output.tolist() == expected.tolist()
1528

1629
position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]])
1730
size = torch.LongTensor([5, 5])
1831
expected = torch.LongTensor([0, 3, 1, 0, 2])
19-
output = grid_cluster(position, size)
32+
output, _ = grid_cluster(position, size)
2033
assert output.tolist() == expected.tolist()
2134

22-
position = Tensor(tensor, [[0, 9, 2, 2, 8], [0, 9, 8, 2, 3]]).t()
23-
output = grid_cluster(position, size)
35+
position = Tensor(tensor, [[0, 11, 2, 2, 8], [0, 9, 8, 2, 3]]).t()
36+
output, _ = grid_cluster(position, size)
2437
assert output.tolist() == expected.tolist()
2538

26-
output = grid_cluster(position.expand(2, 5, 2), size)
39+
output, _ = grid_cluster(position.expand(2, 5, 2), size)
2740
assert output.tolist() == expected.expand(2, 5).tolist()
2841

2942
position = position.repeat(2, 1)
3043
batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
3144
expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
32-
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
33-
output, reduced_batch = grid_cluster(position, size, batch)
45+
expected_batch2 = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
46+
output, batch2 = grid_cluster(position, size, batch)
47+
assert output.tolist() == expected.tolist()
48+
assert batch2.tolist() == expected_batch2.tolist()
49+
50+
output, C = grid_cluster(position, size, batch, fake_nodes=True)
51+
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
3452
assert output.tolist() == expected.tolist()
35-
assert reduced_batch.tolist() == expected_batch.tolist()
53+
assert C == 6
3654

3755

3856
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
3957
@pytest.mark.parametrize('tensor', tensors)
4058
def test_grid_cluster_gpu(tensor): # pragma: no cover
41-
position = Tensor(tensor, [0, 9, 2, 8, 3]).cuda()
59+
position = Tensor(tensor, [2, 6]).cuda()
4260
size = torch.cuda.LongTensor([5])
43-
expected = torch.cuda.LongTensor([0, 1, 0, 1, 0])
44-
output = grid_cluster(position, size)
61+
expected = torch.LongTensor([0, 0])
62+
output, _ = grid_cluster(position, size)
63+
assert output.cpu().tolist() == expected.tolist()
64+
65+
expected = torch.LongTensor([0, 1])
66+
output, _ = grid_cluster(position, size, offset=0)
67+
assert output.cpu().tolist() == expected.tolist()
68+
69+
position = Tensor(tensor, [0, 17, 2, 8, 3]).cuda()
70+
expected = torch.LongTensor([0, 2, 0, 1, 0])
71+
output, _ = grid_cluster(position, size)
72+
assert output.cpu().tolist() == expected.tolist()
73+
74+
output, _ = grid_cluster(position, size, fake_nodes=True)
75+
expected = torch.LongTensor([0, 3, 0, 1, 0])
4576
assert output.cpu().tolist() == expected.tolist()
4677

4778
position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]])
4879
position = position.cuda()
4980
size = torch.cuda.LongTensor([5, 5])
50-
expected = torch.cuda.LongTensor([0, 3, 1, 0, 2])
51-
output = grid_cluster(position, size)
81+
expected = torch.LongTensor([0, 3, 1, 0, 2])
82+
output, _ = grid_cluster(position, size)
5283
assert output.cpu().tolist() == expected.tolist()
5384

54-
position = Tensor(tensor, [[0, 9, 2, 2, 8], [0, 9, 8, 2, 3]])
85+
position = Tensor(tensor, [[0, 11, 2, 2, 8], [0, 9, 8, 2, 3]])
5586
position = position.cuda().t()
56-
output = grid_cluster(position, size)
87+
output, _ = grid_cluster(position, size)
5788
assert output.cpu().tolist() == expected.tolist()
5889

59-
output = grid_cluster(position.expand(2, 5, 2), size)
90+
output, _ = grid_cluster(position.expand(2, 5, 2), size)
6091
assert output.tolist() == expected.expand(2, 5).tolist()
6192

6293
position = position.repeat(2, 1)
6394
batch = torch.cuda.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
6495
expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
65-
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
66-
output, reduced_batch = grid_cluster(position, size, batch)
96+
expected_batch2 = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
97+
output, batch2 = grid_cluster(position, size, batch)
98+
assert output.cpu().tolist() == expected.tolist()
99+
assert batch2.cpu().tolist() == expected_batch2.tolist()
100+
101+
output, C = grid_cluster(position, size, batch, fake_nodes=True)
102+
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
67103
assert output.cpu().tolist() == expected.tolist()
68-
assert reduced_batch.cpu().tolist() == expected_batch.tolist()
104+
assert C == 6

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .functions.grid import grid_cluster
22

3-
__version__ = '0.1.1'
3+
__version__ = '0.2.0'
44

55
__all__ = ['grid_cluster', '__version__']

torch_cluster/functions/grid.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .utils import get_func, consecutive
44

55

6-
def grid_cluster(position, size, batch=None):
6+
def grid_cluster(position, size, batch=None, fake_nodes=False, offset=None):
77
# Allow one-dimensional positions.
88
if position.dim() == 1:
99
position = position.unsqueeze(-1)
@@ -22,8 +22,12 @@ def grid_cluster(position, size, batch=None):
2222
size = torch.cat([size.new(1).fill_(1), size], dim=-1)
2323

2424
# Translate to minimal positive positions.
25-
min = position.min(dim=-2, keepdim=True)[0]
26-
position = position - min
25+
if offset is None:
26+
min = position.min(dim=-2, keepdim=True)[0]
27+
position = position - min
28+
else:
29+
position = position + offset
30+
assert position.min() >= 0, 'Offset resulting in negative positions'
2731

2832
# Compute cluster count for each dimension.
2933
max = position.max(dim=0)[0]
@@ -43,10 +47,9 @@ def grid_cluster(position, size, batch=None):
4347
func = get_func('grid', position)
4448
func(C, cluster, position, size, c_max)
4549
cluster = cluster.squeeze(dim=-1)
46-
cluster, u = consecutive(cluster)
4750

48-
if batch is None:
49-
return cluster
50-
else:
51-
batch = (u / c_max[1:].prod()).long()
52-
return cluster, batch
51+
if fake_nodes:
52+
return cluster, C // c_max[0]
53+
54+
cluster, u = consecutive(cluster)
55+
return cluster, None if batch is None else (u / (C // c_max[0])).long()

0 commit comments

Comments
 (0)