Skip to content

Commit 1edd387

Browse files
committed
pytorch 0.4.0
1 parent b87eab0 commit 1edd387

File tree

7 files changed

+42
-75
lines changed

7 files changed

+42
-75
lines changed

test/tensor.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/test_graclus.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import pytest
44
import torch
5-
import numpy as np
65
from torch_cluster import graclus_cluster
76

8-
from .tensor import tensors
7+
from .utils import dtypes, devices, tensor
98

109
tests = [{
1110
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
@@ -18,51 +17,34 @@
1817

1918

2019
def assert_correct_graclus(row, col, cluster):
21-
row, col = row.cpu().numpy(), col.cpu().numpy()
22-
cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0)
20+
row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
21+
n = cluster.size(0)
2322

2423
# Every node was assigned a cluster.
2524
assert cluster.min() >= 0
2625

2726
# There are no more than two nodes in each cluster.
28-
_, count = np.unique(cluster, return_counts=True)
27+
_, index = torch.unique(cluster, return_inverse=True)
28+
count = torch.zeros_like(cluster)
29+
count.scatter_add_(0, index, torch.ones_like(cluster))
2930
assert (count > 2).max() == 0
3031

3132
# Cluster value is minimal.
32-
assert (cluster <= np.arange(n_nodes, dtype=row.dtype)).sum() == n_nodes
33+
assert (cluster <= torch.arange(n, dtype=cluster.dtype)).sum() == n
3334

3435
# Corresponding clusters must be adjacent.
35-
for n in range(cluster.shape[0]):
36-
x = cluster[col[row == n]] == cluster[n] # Neighbors with same cluster
37-
y = cluster == cluster[n] # Nodes with same cluster
38-
y[n] = 0 # Do not look at cluster of node `n`.
36+
for i in range(n):
37+
x = cluster[col[row == i]] == cluster[i] # Neighbors with same cluster
38+
y = cluster == cluster[i] # Nodes with same cluster.
39+
y[i] = 0 # Do not look at cluster of `i`.
3940
assert x.sum() == y.sum()
4041

4142

42-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
43-
def test_graclus_cluster_cpu(tensor, i):
44-
data = tests[i]
45-
46-
row = torch.LongTensor(data['row'])
47-
col = torch.LongTensor(data['col'])
48-
49-
weight = data.get('weight')
50-
weight = weight if weight is None else getattr(torch, tensor)(weight)
51-
52-
cluster = graclus_cluster(row, col, weight)
53-
assert_correct_graclus(row, col, cluster)
54-
55-
56-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
57-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
58-
def test_graclus_cluster_gpu(tensor, i): # pragma: no cover
59-
data = tests[i]
60-
61-
row = torch.cuda.LongTensor(data['row'])
62-
col = torch.cuda.LongTensor(data['col'])
63-
64-
weight = data.get('weight')
65-
weight = weight if weight is None else getattr(torch.cuda, tensor)(weight)
43+
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
44+
def test_graclus_cluster_cpu(test, dtype, device):
45+
row = tensor(test['row'], torch.long, device)
46+
col = tensor(test['col'], torch.long, device)
47+
weight = tensor(test.get('weight'), dtype, device)
6648

6749
cluster = graclus_cluster(row, col, weight)
6850
assert_correct_graclus(row, col, cluster)

test/test_grid.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from itertools import product
22

33
import pytest
4-
import torch
54
from torch_cluster import grid_cluster
65

7-
from .tensor import tensors
6+
from .utils import dtypes, devices, tensor
87

98
tests = [{
109
'pos': [2, 6],
@@ -27,36 +26,12 @@
2726
}]
2827

2928

30-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
31-
def test_grid_cluster_cpu(tensor, i):
32-
data = tests[i]
33-
34-
pos = getattr(torch, tensor)(data['pos'])
35-
size = getattr(torch, tensor)(data['size'])
36-
37-
start = data.get('start')
38-
start = start if start is None else getattr(torch, tensor)(start)
39-
40-
end = data.get('end')
41-
end = end if end is None else getattr(torch, tensor)(end)
42-
43-
cluster = grid_cluster(pos, size, start, end)
44-
assert cluster.tolist() == data['cluster']
45-
46-
47-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
48-
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
49-
def test_grid_cluster_gpu(tensor, i): # pragma: no cover
50-
data = tests[i]
51-
52-
pos = getattr(torch.cuda, tensor)(data['pos'])
53-
size = getattr(torch.cuda, tensor)(data['size'])
54-
55-
start = data.get('start')
56-
start = start if start is None else getattr(torch.cuda, tensor)(start)
57-
58-
end = data.get('end')
59-
end = end if end is None else getattr(torch.cuda, tensor)(end)
29+
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
30+
def test_grid_cluster_cpu(test, dtype, device):
31+
pos = tensor(test['pos'], dtype, device)
32+
size = tensor(test['size'], dtype, device)
33+
start = tensor(test.get('start'), dtype, device)
34+
end = tensor(test.get('end'), dtype, device)
6035

6136
cluster = grid_cluster(pos, size, start, end)
62-
assert cluster.tolist() == data['cluster']
37+
assert cluster.tolist() == test['cluster']

test/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
from torch.testing import get_all_dtypes
3+
4+
dtypes = get_all_dtypes()
5+
dtypes.remove(torch.half)
6+
7+
devices = [torch.device('cpu')]
8+
9+
if torch.cuda.is_available():
10+
devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
11+
12+
13+
def tensor(x, dtype, device):
14+
return None if x is None else torch.tensor(x, dtype=dtype, device=device)

torch_cluster/graclus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
2929
row, col = randperm_sort_row(row, col, num_nodes)
3030

3131
row, col = remove_self_loops(row, col)
32-
cluster = row.new(num_nodes)
32+
cluster = row.new_empty((num_nodes, ))
3333
graclus(cluster, row, col, weight)
3434

3535
return cluster

torch_cluster/utils/ffi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
def get_func(name, is_cuda, tensor=None):
55
prefix = 'THCC' if is_cuda else 'TH'
6-
prefix += 'Tensor' if tensor is None else type(tensor).__name__
6+
prefix += 'Tensor' if tensor is None else tensor.type().split('.')[-1]
77
return getattr(ffi, '{}_{}'.format(prefix, name))
88

99

torch_cluster/utils/perm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
def randperm(row, col):
55
# Randomly reorder row and column indices.
6-
edge_rid = torch.randperm(row.size(0)).type_as(row)
6+
edge_rid = torch.randperm(row.size(0))
77
return row[edge_rid], col[edge_rid]
88

99

@@ -16,7 +16,7 @@ def sort_row(row, col):
1616

1717
def randperm_sort_row(row, col, num_nodes):
1818
# Randomly change row indices to new values.
19-
node_rid = torch.randperm(num_nodes).type_as(row)
19+
node_rid = torch.randperm(num_nodes)
2020
row = node_rid[row]
2121

2222
# Sort row and column indices row-wise.

0 commit comments

Comments
 (0)