@@ -10,21 +10,16 @@ def test_where_with_scalars():
1010 x = xp .asarray ([1 , 2 , 3 , 1 ])
1111
1212 # Versions up to and including 2023.12 don't support scalar arguments
13- with pytest .raises (AttributeError , match = "object has no attribute 'dtype'" ):
14- xp .where (x == 1 , 42 , 44 )
13+ with ArrayAPIStrictFlags (api_version = '2023.12' ):
14+ with pytest .raises (AttributeError , match = "object has no attribute 'dtype'" ):
15+ xp .where (x == 1 , 42 , 44 )
1516
1617 # Versions after 2023.12 support scalar arguments
17- with (pytest .warns (
18- UserWarning ,
19- match = "The 2024.12 version of the array API specification is in draft status"
20- ),
21- ArrayAPIStrictFlags (api_version = draft_version ),
22- ):
23- x_where = xp .where (x == 1 , xp .asarray (42 ), 44 )
24-
25- expected = xp .asarray ([42 , 44 , 44 , 42 ])
26- assert xp .all (x_where == expected )
27-
28- # The spec does not allow both x1 and x2 to be scalars
29- with pytest .raises (ValueError , match = "One of" ):
30- xp .where (x == 1 , 42 , 44 )
18+ x_where = xp .where (x == 1 , xp .asarray (42 ), 44 )
19+
20+ expected = xp .asarray ([42 , 44 , 44 , 42 ])
21+ assert xp .all (x_where == expected )
22+
23+ # The spec does not allow both x1 and x2 to be scalars
24+ with pytest .raises (ValueError , match = "One of" ):
25+ xp .where (x == 1 , 42 , 44 )
0 commit comments