88from .utils import grad_dtypes , devices , tensor
99
1010
11+ def to_set (edge_index ):
12+ return set ([(i , j ) for i , j in edge_index .t ().tolist ()])
13+
14+
1115@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
1216def test_knn (dtype , device ):
1317 x = tensor ([
@@ -28,18 +32,15 @@ def test_knn(dtype, device):
2832 batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
2933 batch_y = tensor ([0 , 1 ], torch .long , device )
3034
31- row , col = knn (x , y , 2 )
32- assert row .tolist () == [0 , 0 , 1 , 1 ]
33- assert col .tolist () == [2 , 3 , 0 , 1 ]
35+ edge_index = knn (x , y , 2 )
36+ assert to_set (edge_index ) == set ([(0 , 2 ), (0 , 3 ), (1 , 0 ), (1 , 1 )])
3437
35- row , col = knn (x , y , 2 , batch_x , batch_y )
36- assert row .tolist () == [0 , 0 , 1 , 1 ]
37- assert col .tolist () == [2 , 3 , 4 , 5 ]
38+ edge_index = knn (x , y , 2 , batch_x , batch_y )
39+ assert to_set (edge_index ) == set ([(0 , 2 ), (0 , 3 ), (1 , 4 ), (1 , 5 )])
3840
3941 if x .is_cuda :
40- row , col = knn (x , y , 2 , batch_x , batch_y , cosine = True )
41- assert row .tolist () == [0 , 0 , 1 , 1 ]
42- assert col .tolist () == [0 , 1 , 4 , 5 ]
42+ edge_index = knn (x , y , 2 , batch_x , batch_y , cosine = True )
43+ assert to_set (edge_index ) == set ([(0 , 0 ), (0 , 1 ), (1 , 4 ), (1 , 5 )])
4344
4445
4546@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
@@ -51,25 +52,24 @@ def test_knn_graph(dtype, device):
5152 [+ 1 , - 1 ],
5253 ], dtype , device )
5354
54- row , col = knn_graph (x , k = 2 , flow = 'target_to_source' )
55- assert row . tolist ( ) == [ 0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ]
56- assert col . tolist () == [ 1 , 3 , 0 , 2 , 1 , 3 , 0 , 2 ]
55+ edge_index = knn_graph (x , k = 2 , flow = 'target_to_source' )
56+ assert to_set ( edge_index ) == set ([( 0 , 1 ), ( 0 , 3 ), ( 1 , 0 ), ( 1 , 2 ), ( 2 , 1 ),
57+ ( 2 , 3 ), ( 3 , 0 ), ( 3 , 2 )])
5758
58- row , col = knn_graph (x , k = 2 , flow = 'source_to_target' )
59- assert row . tolist ( ) == [ 1 , 3 , 0 , 2 , 1 , 3 , 0 , 2 ]
60- assert col . tolist () == [ 0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ]
59+ edge_index = knn_graph (x , k = 2 , flow = 'source_to_target' )
60+ assert to_set ( edge_index ) == set ([( 1 , 0 ), ( 3 , 0 ), ( 0 , 1 ), ( 2 , 1 ), ( 1 , 2 ),
61+ ( 3 , 2 ), ( 0 , 3 ), ( 2 , 3 )])
6162
6263
6364@pytest .mark .parametrize ('dtype,device' , product (grad_dtypes , devices ))
6465def test_knn_graph_large (dtype , device ):
6566 x = torch .randn (1000 , 3 )
6667
67- row , col = knn_graph (x , k = 5 , flow = 'target_to_source' , loop = True ,
68- num_workers = 6 )
69- pred = set ([(i , j ) for i , j in zip (row .tolist (), col .tolist ())])
68+ edge_index = knn_graph (x , k = 5 , flow = 'target_to_source' , loop = True ,
69+ num_workers = 6 )
7070
7171 tree = scipy .spatial .cKDTree (x .numpy ())
7272 _ , col = tree .query (x .cpu (), k = 5 )
7373 truth = set ([(i , j ) for i , ns in enumerate (col ) for j in ns ])
7474
75- assert pred == truth
75+ assert to_set ( edge_index ) == truth
0 commit comments