1111library will only support one particular configuration of these flags.
1212"""
1313
14+ import functools
1415import os
1516
16- supported_versions = [
17+ supported_versions = (
1718 "2021.12" ,
1819 "2022.12" ,
19- ]
20+ )
2021
21- STANDARD_VERSION = "2022.12"
22+ STANDARD_VERSION = default_version = "2022.12"
2223
2324DATA_DEPENDENT_SHAPES = True
2425
25- all_extensions = [
26+ all_extensions = (
2627 "linalg" ,
2728 "fft" ,
28- ]
29+ )
2930
3031extension_versions = {
3132 "linalg" : "2021.12" ,
3233 "fft" : "2022.12" ,
3334}
3435
35- ENABLED_EXTENSIONS = [
36+ ENABLED_EXTENSIONS = default_extensions = (
3637 "linalg" ,
3738 "fft" ,
38- ]
39+ )
40+
41+ # Public functions
3942
4043def set_array_api_strict_flags (
4144 * ,
@@ -136,8 +139,8 @@ def set_array_api_strict_flags(
136139# We have to do this separately or it won't get added as the docstring
137140set_array_api_strict_flags .__doc__ = set_array_api_strict_flags .__doc__ .format (
138141 supported_versions = supported_versions ,
139- default_version = STANDARD_VERSION ,
140- default_extensions = ENABLED_EXTENSIONS ,
142+ default_version = default_version ,
143+ default_extensions = default_extensions ,
141144)
142145
143146def get_array_api_strict_flags ():
@@ -160,7 +163,7 @@ def get_array_api_strict_flags():
160163 >>> from array_api_strict import get_array_api_strict_flags
161164 >>> flags = get_array_api_strict_flags()
162165 >>> flags
163- {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': [ 'linalg', 'fft'] }
166+ {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ( 'linalg', 'fft') }
164167
165168 See Also
166169 --------
@@ -181,6 +184,8 @@ def reset_array_api_strict_flags():
181184 """
182185 Reset the array-api-strict flags to their default values.
183186
187+ This will also reset any flags that were set by environment variables.
188+
184189 .. note::
185190
186191 This function is **not** part of the array API standard. It only exists
@@ -201,9 +206,9 @@ def reset_array_api_strict_flags():
201206
202207 """
203208 global STANDARD_VERSION , DATA_DEPENDENT_SHAPES , ENABLED_EXTENSIONS
204- STANDARD_VERSION = "2022.12"
209+ STANDARD_VERSION = default_version
205210 DATA_DEPENDENT_SHAPES = True
206- ENABLED_EXTENSIONS = [ "linalg" , "fft" ]
211+ ENABLED_EXTENSIONS = default_extensions
207212
208213
209214class ArrayApiStrictFlags :
@@ -241,18 +246,22 @@ def __enter__(self):
241246 def __exit__ (self , exc_type , exc_value , traceback ):
242247 set_array_api_strict_flags (** self .old_flags )
243248
244- # Set the flags from the environment variables
245- if "ARRAY_API_STRICT_STANDARD_VERSION" in os .environ :
246- set_array_api_strict_flags (
247- standard_version = os .environ ["ARRAY_API_STRICT_STANDARD_VERSION" ]
248- )
249-
250- if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
251- set_array_api_strict_flags (
252- data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
253- )
254-
255- if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os .environ :
256- set_array_api_strict_flags (
257- enabled_extensions = os .environ ["ARRAY_API_STRICT_ENABLED_EXTENSIONS" ].split ("," )
258- )
249+ # Private functions
250+
251+ def set_flags_from_environment ():
252+ if "ARRAY_API_STRICT_STANDARD_VERSION" in os .environ :
253+ set_array_api_strict_flags (
254+ standard_version = os .environ ["ARRAY_API_STRICT_STANDARD_VERSION" ]
255+ )
256+
257+ if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os .environ :
258+ set_array_api_strict_flags (
259+ data_dependent_shapes = os .environ ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" ].lower () == "true"
260+ )
261+
262+ if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os .environ :
263+ set_array_api_strict_flags (
264+ enabled_extensions = os .environ ["ARRAY_API_STRICT_ENABLED_EXTENSIONS" ].split ("," )
265+ )
266+
267+ set_flags_from_environment ()
0 commit comments