Skip to content

Commit 5bb8d17

Browse files
authored
update (#154)
1 parent 6f22228 commit 5bb8d17

File tree

10 files changed

+41
-31
lines changed

10 files changed

+41
-31
lines changed

test/__init__.py

Whitespace-only changes.

test/test_fps.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import torch
55
from torch import Tensor
66
from torch_cluster import fps
7-
8-
from .utils import grad_dtypes, devices, tensor
7+
from torch_cluster.testing import devices, grad_dtypes, tensor
98

109

1110
@torch.jit.script

test/test_graclus.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import pytest
44
import torch
55
from torch_cluster import graclus_cluster
6-
7-
from .utils import dtypes, devices, tensor
6+
from torch_cluster.testing import devices, dtypes, tensor
87

98
tests = [{
109
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
@@ -42,6 +41,9 @@ def assert_correct(row, col, cluster):
4241

4342
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
4443
def test_graclus_cluster(test, dtype, device):
44+
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
45+
return
46+
4547
row = tensor(test['row'], torch.long, device)
4648
col = tensor(test['col'], torch.long, device)
4749
weight = tensor(test.get('weight'), dtype, device)

test/test_grid.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from itertools import product
22

33
import pytest
4+
import torch
45
from torch_cluster import grid_cluster
5-
6-
from .utils import dtypes, devices, tensor
6+
from torch_cluster.testing import devices, dtypes, tensor
77

88
tests = [{
99
'pos': [2, 6],
@@ -28,6 +28,9 @@
2828

2929
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
3030
def test_grid_cluster(test, dtype, device):
31+
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
32+
return
33+
3134
pos = tensor(test['pos'], dtype, device)
3235
size = tensor(test['size'], dtype, device)
3336
start = tensor(test.get('start'), dtype, device)

test/test_knn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from itertools import product
22

33
import pytest
4-
import torch
54
import scipy.spatial
5+
import torch
66
from torch_cluster import knn, knn_graph
7-
8-
from .utils import grad_dtypes, devices, tensor
7+
from torch_cluster.testing import devices, grad_dtypes, tensor
98

109

1110
def to_set(edge_index):

test/test_nearest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import pytest
44
import torch
55
from torch_cluster import nearest
6-
7-
from .utils import grad_dtypes, devices, tensor
6+
from torch_cluster.testing import devices, grad_dtypes, tensor
87

98

109
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))

test/test_radius.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from itertools import product
22

33
import pytest
4-
import torch
54
import scipy.spatial
5+
import torch
66
from torch_cluster import radius, radius_graph
7-
8-
from .utils import grad_dtypes, devices, tensor
7+
from torch_cluster.testing import devices, grad_dtypes, tensor
98

109

1110
def to_set(edge_index):

test/test_rw.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import pytest
22
import torch
33
from torch_cluster import random_walk
4-
5-
from .utils import devices, tensor
4+
from torch_cluster.testing import devices, tensor
65

76

87
@pytest.mark.parametrize('device', devices)
@@ -41,7 +40,10 @@ def test_rw_large_with_edge_indices(device):
4140
walk_length = 10
4241

4342
node_seq, edge_seq = random_walk(
44-
row, col, start, walk_length,
43+
row,
44+
col,
45+
start,
46+
walk_length,
4547
return_edge_indices=True,
4648
)
4749
assert node_seq[:, 0].tolist() == start.tolist()
@@ -63,7 +65,10 @@ def test_rw_small_with_edge_indices(device):
6365
walk_length = 4
6466

6567
node_seq, edge_seq = random_walk(
66-
row, col, start, walk_length,
68+
row,
69+
col,
70+
start,
71+
walk_length,
6772
num_nodes=3,
6873
return_edge_indices=True,
6974
)

test/utils.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

torch_cluster/testing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Any
2+
3+
import torch
4+
5+
dtypes = [
6+
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
7+
torch.long
8+
]
9+
grad_dtypes = [torch.half, torch.float, torch.double]
10+
11+
devices = [torch.device('cpu')]
12+
if torch.cuda.is_available():
13+
devices += [torch.device('cuda:0')]
14+
15+
16+
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
17+
return None if x is None else torch.tensor(x, dtype=dtype, device=device)

0 commit comments

Comments
 (0)