@@ -33,3 +33,38 @@ def test_nearest(dtype, device):
3333
3434 out = nearest (x , y )
3535 assert out .tolist () == [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ]
36+
37+ # Invalid input: instance 1 only in batch_x
38+ batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
39+ batch_y = tensor ([0 , 0 , 0 , 0 ], torch .long , device )
40+ with pytest .raises (ValueError ):
41+ out = nearest (x , y , batch_x , batch_y )
42+
43+ # Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
44+ with pytest .raises (ValueError ):
45+ out = nearest (x , y , batch_x , batch_y = None )
46+
47+ # Valid input: instance 1 only in batch_y
48+ batch_x = tensor ([0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ], torch .long , device )
49+ batch_y = tensor ([0 , 0 , 1 , 1 ], torch .long , device )
50+ out = nearest (x , y , batch_x , batch_y )
51+ assert out .tolist () == [0 , 0 , 1 , 1 , 0 , 0 , 1 , 1 ]
52+
53+ # Invalid input: instance 2 only in batch_x
54+ # (i.e.instance in the middle missing)
55+ batch_x = tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ], torch .long , device )
56+ batch_y = tensor ([0 , 1 , 3 , 3 ], torch .long , device )
57+ with pytest .raises (ValueError ):
58+ out = nearest (x , y , batch_x , batch_y )
59+
60+ # Invalid input: batch_x unsorted
61+ batch_x = tensor ([0 , 0 , 1 , 0 , 0 , 0 , 0 ], torch .long , device )
62+ batch_y = tensor ([0 , 0 , 1 , 1 ], torch .long , device )
63+ with pytest .raises (ValueError ):
64+ out = nearest (x , y , batch_x , batch_y )
65+
66+ # Invalid input: batch_y unsorted
67+ batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
68+ batch_y = tensor ([0 , 0 , 1 , 0 ], torch .long , device )
69+ with pytest .raises (ValueError ):
70+ out = nearest (x , y , batch_x , batch_y )
0 commit comments