Skip to content

Commit 4e7de16

Browse files
committed
remove pickle dependency
1 parent 0d66377 commit 4e7de16

File tree

4 files changed

+47
-17
lines changed

4 files changed

+47
-17
lines changed

test/knn_test_large.pkl

-632 KB
Binary file not shown.

test/radius_test_large.pkl

-960 KB
Binary file not shown.

test/test_knn.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44
import torch
55
from torch_cluster import knn, knn_graph
6-
import pickle
76
from .utils import grad_dtypes, devices, tensor
87

98

@@ -61,29 +60,41 @@ def test_knn_graph(dtype, device):
6160

6261
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
6362
def test_knn_graph_large(dtype, device):
64-
d = pickle.load(open("test/knn_test_large.pkl", "rb"))
65-
x = d['x'].to(device)
66-
k = d['k']
67-
truth = d['edges']
68-
69-
row, col = knn_graph(x, k=k, flow='source_to_target',
70-
batch=None, n_threads=24)
63+
x = torch.tensor([[-1.0320, 0.2380, 0.2380],
64+
[-1.3050, -0.0930, 0.6420],
65+
[-0.3190, -0.0410, 1.2150],
66+
[1.1400, -0.5390, -0.3140],
67+
[0.8410, 0.8290, 0.6090],
68+
[-1.4380, -0.2420, -0.3260],
69+
[-2.2980, 0.7160, 0.9320],
70+
[-1.3680, -0.4390, 0.1380],
71+
[-0.6710, 0.6060, 1.1800],
72+
[0.3950, -0.0790, 1.4920]],).to(device)
73+
k = 3
74+
truth = set({(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4),
75+
(5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2),
76+
(8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2),
77+
(1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7),
78+
(1, 7), (3, 4)})
79+
80+
row, col = knn_graph(x, k=k, flow='target_to_source',
81+
batch=None, n_threads=24, loop=False)
7182

7283
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
7384
list(col.cpu().numpy()))])
7485

7586
assert(truth == edges)
7687

77-
row, col = knn_graph(x, k=k, flow='source_to_target',
78-
batch=None, n_threads=12)
88+
row, col = knn_graph(x, k=k, flow='target_to_source',
89+
batch=None, n_threads=12, loop=False)
7990

8091
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
8192
list(col.cpu().numpy()))])
8293

8394
assert(truth == edges)
8495

85-
row, col = knn_graph(x, k=k, flow='source_to_target',
86-
batch=None, n_threads=1)
96+
row, col = knn_graph(x, k=k, flow='target_to_source',
97+
batch=None, n_threads=1, loop=False)
8798

8899
edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()),
89100
list(col.cpu().numpy()))])

test/test_radius.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
import torch
55
from torch_cluster import radius, radius_graph
66
from .utils import grad_dtypes, devices, tensor
7-
import pickle
7+
import scipy.spatial
8+
9+
10+
@torch.jit.script
11+
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
12+
if col.size(0) > count:
13+
col = col[torch.randperm(col.size(0), dtype=torch.long)][:count]
14+
return col
815

916

1017
def coalesce(index):
@@ -594,10 +601,22 @@ def test_radius_graph_ndim(dtype, device):
594601

595602
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
596603
def test_radius_graph_large(dtype, device):
597-
d = pickle.load(open("test/radius_test_large.pkl", "rb"))
598-
x = d['x'].to(device)
599-
r = d['r']
600-
truth = d['edges']
604+
x = torch.randn((8192*4, 6))
605+
r = 0.5
606+
607+
tree = scipy.spatial.cKDTree(x.detach().cpu().numpy())
608+
col = tree.query_ball_point(x.detach().cpu().numpy(), r)
609+
col = [torch.tensor(c, dtype=torch.long) for c in col]
610+
col = [sample(c, 32) for c in col]
611+
row = [torch.full_like(c, i) for i, c in enumerate(col)]
612+
row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
613+
mask = col < int(tree.n)
614+
row_truth, col_truth = torch.stack([row[mask], col[mask]], dim=0)
615+
mask = row_truth != col_truth
616+
row_truth, col_truth = row_truth[mask], col_truth[mask]
617+
618+
truth = (set([(i, j) for (i, j) in zip(list(row_truth.cpu().numpy()),
619+
list(col_truth.cpu().numpy()))]))
601620

602621
row, col = radius_graph(x, r=r, flow='source_to_target',
603622
batch=None, n_threads=24)

0 commit comments

Comments
 (0)