|
3 | 3 | import pytest |
4 | 4 | import torch |
5 | 5 | from torch_cluster import knn, knn_graph |
6 | | -import pickle |
7 | 6 | from .utils import grad_dtypes, devices, tensor |
8 | 7 |
|
9 | 8 |
|
@@ -61,29 +60,41 @@ def test_knn_graph(dtype, device): |
61 | 60 |
|
62 | 61 | @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) |
63 | 62 | def test_knn_graph_large(dtype, device): |
64 | | - d = pickle.load(open("test/knn_test_large.pkl", "rb")) |
65 | | - x = d['x'].to(device) |
66 | | - k = d['k'] |
67 | | - truth = d['edges'] |
68 | | - |
69 | | - row, col = knn_graph(x, k=k, flow='source_to_target', |
70 | | - batch=None, n_threads=24) |
| 63 | + x = torch.tensor([[-1.0320, 0.2380, 0.2380], |
| 64 | + [-1.3050, -0.0930, 0.6420], |
| 65 | + [-0.3190, -0.0410, 1.2150], |
| 66 | + [1.1400, -0.5390, -0.3140], |
| 67 | + [0.8410, 0.8290, 0.6090], |
| 68 | + [-1.4380, -0.2420, -0.3260], |
| 69 | + [-2.2980, 0.7160, 0.9320], |
| 70 | + [-1.3680, -0.4390, 0.1380], |
| 71 | + [-0.6710, 0.6060, 1.1800], |
| 72 | + [0.3950, -0.0790, 1.4920]],).to(device) |
| 73 | + k = 3 |
| 74 | + truth = set({(4, 8), (2, 8), (9, 8), (8, 0), (0, 7), (2, 1), (9, 4), |
| 75 | + (5, 1), (4, 9), (2, 9), (8, 1), (1, 5), (5, 0), (3, 2), |
| 76 | + (8, 2), (7, 1), (6, 0), (3, 9), (0, 5), (7, 5), (4, 2), |
| 77 | + (1, 0), (0, 1), (7, 0), (6, 8), (9, 2), (6, 1), (5, 7), |
| 78 | + (1, 7), (3, 4)}) |
| 79 | + |
| 80 | + row, col = knn_graph(x, k=k, flow='target_to_source', |
| 81 | + batch=None, n_threads=24, loop=False) |
71 | 82 |
|
72 | 83 | edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), |
73 | 84 | list(col.cpu().numpy()))]) |
74 | 85 |
|
75 | 86 | assert(truth == edges) |
76 | 87 |
|
77 | | - row, col = knn_graph(x, k=k, flow='source_to_target', |
78 | | - batch=None, n_threads=12) |
| 88 | + row, col = knn_graph(x, k=k, flow='target_to_source', |
| 89 | + batch=None, n_threads=12, loop=False) |
79 | 90 |
|
80 | 91 | edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), |
81 | 92 | list(col.cpu().numpy()))]) |
82 | 93 |
|
83 | 94 | assert(truth == edges) |
84 | 95 |
|
85 | | - row, col = knn_graph(x, k=k, flow='source_to_target', |
86 | | - batch=None, n_threads=1) |
| 96 | + row, col = knn_graph(x, k=k, flow='target_to_source', |
| 97 | + batch=None, n_threads=1, loop=False) |
87 | 98 |
|
88 | 99 | edges = set([(i, j) for (i, j) in zip(list(row.cpu().numpy()), |
89 | 100 | list(col.cpu().numpy()))]) |
|
0 commit comments