Skip to content

Commit 7985cdd

Browse files
committed
removed self loops in graclus
1 parent d2ee152 commit 7985cdd

File tree

4 files changed

+7
-3
lines changed

4 files changed

+7
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '1.0.1'
5+
__version__ = '1.0.2'
66
url = 'https://github.com/rusty1s/pytorch_cluster'
77

88
install_requires = ['cffi']

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .graclus import graclus_cluster
22
from .grid import grid_cluster
33

4-
__version__ = '1.0.1'
4+
__version__ = '1.0.2'
55

66
__all__ = ['graclus_cluster', 'grid_cluster', '__version__']

torch_cluster/graclus.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .utils.loop import remove_self_loops
12
from .utils.perm import randperm, sort_row, randperm_sort_row
23
from .utils.ffi import graclus
34

@@ -19,7 +20,6 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
1920
>>> weight = torch.Tensor([1, 1, 1, 1])
2021
>>> cluster = graclus_cluster(row, col, weight)
2122
"""
22-
2323
num_nodes = row.max() + 1 if num_nodes is None else num_nodes
2424

2525
if row.is_cuda: # pragma: no cover
@@ -28,6 +28,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
2828
row, col = randperm(row, col)
2929
row, col = randperm_sort_row(row, col, num_nodes)
3030

31+
row, col = remove_self_loops(row, col)
3132
cluster = row.new(num_nodes)
3233
graclus(cluster, row, col, weight)
3334

torch_cluster/utils/loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
def remove_self_loops(row, col):
2+
mask = row != col
3+
return row[mask], col[mask]

0 commit comments

Comments
 (0)