|
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 | import torch |
5 | | -import numpy as np |
6 | 5 | from torch_cluster import graclus_cluster |
7 | 6 |
|
8 | | -from .tensor import tensors |
| 7 | +from .utils import dtypes, devices, tensor |
9 | 8 |
|
10 | 9 | tests = [{ |
11 | 10 | 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], |
|
18 | 17 |
|
19 | 18 |
|
20 | 19 | def assert_correct_graclus(row, col, cluster): |
21 | | - row, col = row.cpu().numpy(), col.cpu().numpy() |
22 | | - cluster, n_nodes = cluster.cpu().numpy(), cluster.size(0) |
| 20 | + row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu') |
| 21 | + n = cluster.size(0) |
23 | 22 |
|
24 | 23 | # Every node was assigned a cluster. |
25 | 24 | assert cluster.min() >= 0 |
26 | 25 |
|
27 | 26 | # There are no more than two nodes in each cluster. |
28 | | - _, count = np.unique(cluster, return_counts=True) |
| 27 | + _, index = torch.unique(cluster, return_inverse=True) |
| 28 | + count = torch.zeros_like(cluster) |
| 29 | + count.scatter_add_(0, index, torch.ones_like(cluster)) |
29 | 30 | assert (count > 2).max() == 0 |
30 | 31 |
|
31 | 32 | # Cluster value is minimal. |
32 | | - assert (cluster <= np.arange(n_nodes, dtype=row.dtype)).sum() == n_nodes |
| 33 | + assert (cluster <= torch.arange(n, dtype=cluster.dtype)).sum() == n |
33 | 34 |
|
34 | 35 | # Corresponding clusters must be adjacent. |
35 | | - for n in range(cluster.shape[0]): |
36 | | - x = cluster[col[row == n]] == cluster[n] # Neighbors with same cluster |
37 | | - y = cluster == cluster[n] # Nodes with same cluster |
38 | | - y[n] = 0 # Do not look at cluster of node `n`. |
| 36 | + for i in range(n): |
| 37 | + x = cluster[col[row == i]] == cluster[i] # Neighbors with same cluster |
| 38 | + y = cluster == cluster[i] # Nodes with same cluster. |
| 39 | + y[i] = 0 # Do not look at cluster of `i`. |
39 | 40 | assert x.sum() == y.sum() |
40 | 41 |
|
41 | 42 |
|
42 | | -@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests)))) |
43 | | -def test_graclus_cluster_cpu(tensor, i): |
44 | | - data = tests[i] |
45 | | - |
46 | | - row = torch.LongTensor(data['row']) |
47 | | - col = torch.LongTensor(data['col']) |
48 | | - |
49 | | - weight = data.get('weight') |
50 | | - weight = weight if weight is None else getattr(torch, tensor)(weight) |
51 | | - |
52 | | - cluster = graclus_cluster(row, col, weight) |
53 | | - assert_correct_graclus(row, col, cluster) |
54 | | - |
55 | | - |
56 | | -@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') |
57 | | -@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests)))) |
58 | | -def test_graclus_cluster_gpu(tensor, i): # pragma: no cover |
59 | | - data = tests[i] |
60 | | - |
61 | | - row = torch.cuda.LongTensor(data['row']) |
62 | | - col = torch.cuda.LongTensor(data['col']) |
63 | | - |
64 | | - weight = data.get('weight') |
65 | | - weight = weight if weight is None else getattr(torch.cuda, tensor)(weight) |
| 43 | +@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) |
| 44 | +def test_graclus_cluster_cpu(test, dtype, device): |
| 45 | + row = tensor(test['row'], torch.long, device) |
| 46 | + col = tensor(test['col'], torch.long, device) |
| 47 | + weight = tensor(test.get('weight'), dtype, device) |
66 | 48 |
|
67 | 49 | cluster = graclus_cluster(row, col, weight) |
68 | 50 | assert_correct_graclus(row, col, cluster) |
0 commit comments