File tree Expand file tree Collapse file tree 3 files changed +17
-9
lines changed
Expand file tree Collapse file tree 3 files changed +17
-9
lines changed Original file line number Diff line number Diff line change 1717
1818import operator
1919from enum import IntEnum
20- import warnings
2120
2221from ._creation_functions import asarray
2322from ._dtypes import (
@@ -502,8 +501,6 @@ def __array_namespace__(
502501
503502 """
504503 set_array_api_strict_flags (api_version = api_version )
505- if api_version == "2021.12" :
506- warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
507504 import array_api_strict
508505 return array_api_strict
509506
Original file line number Diff line number Diff line change 1313
1414import functools
1515import os
16+ import warnings
1617
1718import array_api_strict
1819
@@ -62,6 +63,9 @@ def set_array_api_strict_flags(
6263 versions are: ``{supported_versions}``. The default version number is
6364 ``{default_version!r}``.
6465
66+ Note that 2021.12 is supported, but currently gives the same thing as
67+ 2022.12 (except that the fft extension will be disabled).
68+
6569 - `data_dependent_shapes`: Whether data-dependent shapes are enabled in
6670 array-api-strict.
6771
@@ -118,6 +122,8 @@ def set_array_api_strict_flags(
118122 if api_version is not None :
119123 if api_version not in supported_versions :
120124 raise ValueError (f"Unsupported standard version { api_version !r} " )
125+ if api_version == "2021.12" :
126+ warnings .warn ("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12" )
121127 API_VERSION = api_version
122128 array_api_strict .__array_api_version__ = API_VERSION
123129
Original file line number Diff line number Diff line change @@ -32,8 +32,12 @@ def test_flags():
3232 'data_dependent_shapes' : False ,
3333 'enabled_extensions' : ('fft' ,),
3434 }
35- # Make sure setting the version to 2021.12 disables fft
36- set_array_api_strict_flags (api_version = '2021.12' )
35+ # Make sure setting the version to 2021.12 disables fft and issues a
36+ # warning.
37+ with pytest .warns (UserWarning ) as record :
38+ set_array_api_strict_flags (api_version = '2021.12' )
39+ assert len (record ) == 1
40+ assert '2021.12' in str (record [0 ].message )
3741 flags = get_array_api_strict_flags ()
3842 assert flags == {
3943 'api_version' : '2021.12' ,
@@ -51,10 +55,11 @@ def test_flags():
5155 enabled_extensions = ('linalg' , 'fft' )))
5256
5357 # Test resetting flags
54- set_array_api_strict_flags (
55- api_version = '2021.12' ,
56- data_dependent_shapes = False ,
57- enabled_extensions = ())
58+ with pytest .warns (UserWarning ):
59+ set_array_api_strict_flags (
60+ api_version = '2021.12' ,
61+ data_dependent_shapes = False ,
62+ enabled_extensions = ())
5863 reset_array_api_strict_flags ()
5964 flags = get_array_api_strict_flags ()
6065 assert flags == {
You can’t perform that action at this time.
0 commit comments