Skip to content

Commit ce9a53d

Browse files
committed
return C with batch
1 parent 6a2e1a0 commit ce9a53d

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '0.2.0'
5+
__version__ = '0.2.1'
66
url = 'https://github.com/rusty1s/pytorch_cluster'
77

88
install_requires = ['cffi', 'torch-unique']

test/test_grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor):
5050
output, C = grid_cluster(position, size, batch, fake_nodes=True)
5151
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
5252
assert output.tolist() == expected.tolist()
53-
assert C == 6
53+
assert C == 12
5454

5555

5656
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
101101
output, C = grid_cluster(position, size, batch, fake_nodes=True)
102102
expected = torch.LongTensor([0, 5, 1, 0, 2, 6, 11, 7, 6, 8])
103103
assert output.cpu().tolist() == expected.tolist()
104-
assert C == 6
104+
assert C == 12

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .functions.grid import grid_cluster
22

3-
__version__ = '0.2.0'
3+
__version__ = '0.2.1'
44

55
__all__ = ['grid_cluster', '__version__']

torch_cluster/functions/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
5050
cluster = cluster.squeeze(dim=-1)
5151

5252
if fake_nodes:
53-
return cluster, C // c_max[0]
53+
return cluster, C
5454

5555
cluster, u = consecutive(cluster)
5656
return cluster, None if batch is None else (u / (C // c_max[0])).long()

0 commit comments

Comments
 (0)