@@ -370,3 +370,124 @@ def test_disabled_extensions():
370370 exec ('from array_api_strict import *' , ns )
371371 assert 'linalg' not in ns
372372 assert 'fft' not in ns
373+
374+
375+ def test_environment_variables ():
376+ # Test that the environment variables work as expected
377+ subprocess_tests = [
378+ # ARRAY_API_STRICT_API_VERSION
379+ ('''\
380+ import array_api_strict as xp
381+ assert xp.__array_api_version__ == '2022.12'
382+
383+ assert xp.get_array_api_strict_flags()['api_version'] == '2022.12'
384+
385+ ''' , {}),
386+ * [
387+ (f'''\
388+ import array_api_strict as xp
389+ assert xp.__array_api_version__ == '{ version } '
390+
391+ assert xp.get_array_api_strict_flags()['api_version'] == '{ version } '
392+
393+ if { version } == '2021.12':
394+ assert hasattr(xp, 'linalg')
395+ assert not hasattr(xp, 'fft')
396+
397+ ''' , {"ARRAY_API_STRICT_API_VERSION" : version }) for version in ('2021.12' , '2022.12' , '2023.12' )],
398+
399+ # ARRAY_API_STRICT_BOOLEAN_INDEXING
400+ ('''\
401+ import array_api_strict as xp
402+
403+ a = xp.ones(3)
404+ mask = xp.asarray([True, False, True])
405+
406+ assert xp.all(a[mask] == xp.asarray([1., 1.]))
407+ assert xp.get_array_api_strict_flags()['boolean_indexing'] == True
408+ ''' , {}),
409+ * [(f'''\
410+ import array_api_strict as xp
411+
412+ a = xp.ones(3)
413+ mask = xp.asarray([True, False, True])
414+
415+ if { boolean_indexing } :
416+ assert xp.all(a[mask] == xp.asarray([1., 1.]))
417+ else:
418+ try:
419+ a[mask]
420+ except RuntimeError:
421+ pass
422+ else:
423+ assert False
424+
425+ assert xp.get_array_api_strict_flags()['boolean_indexing'] == { boolean_indexing }
426+ ''' , {"ARRAY_API_STRICT_BOOLEAN_INDEXING" : boolean_indexing })
427+ for boolean_indexing in ('True' , 'False' )],
428+
429+ # ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES
430+ ('''\
431+ import array_api_strict as xp
432+
433+ a = xp.ones(3)
434+ xp.unique_all(a)
435+
436+ assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True
437+ ''' , {}),
438+ * [(f'''\
439+ import array_api_strict as xp
440+
441+ a = xp.ones(3)
442+ if { data_dependent_shapes } :
443+ xp.unique_all(a)
444+ else:
445+ try:
446+ xp.unique_all(a)
447+ except RuntimeError:
448+ pass
449+ else:
450+ assert False
451+
452+ assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == { data_dependent_shapes }
453+ ''' , {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" : data_dependent_shapes })
454+ for data_dependent_shapes in ('True' , 'False' )],
455+
456+ # ARRAY_API_STRICT_ENABLED_EXTENSIONS
457+ ('''\
458+ import array_api_strict as xp
459+ assert hasattr(xp, 'linalg')
460+ assert hasattr(xp, 'fft')
461+
462+ assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft')
463+ ''' , {}),
464+ * [(f'''\
465+ import array_api_strict as xp
466+
467+ assert hasattr(xp, 'linalg') == ('linalg' in { extensions .split (',' )} )
468+ assert hasattr(xp, 'fft') == ('fft' in { extensions .split (',' )} )
469+
470+ assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == { sorted (set (extensions .split (',' ))- {'' })}
471+ ''' , {"ARRAY_API_STRICT_ENABLED_EXTENSIONS" : extensions })
472+ for extensions in ('' , 'linalg' , 'fft' , 'linalg,fft' )],
473+ ]
474+
475+ for test , env in subprocess_tests :
476+ try :
477+ subprocess .run ([sys .executable , '-c' , test ], check = True ,
478+ capture_output = True , encoding = 'utf-8' , env = env )
479+ except subprocess .CalledProcessError as e :
480+ print (e .stdout , end = '' )
481+ # Ensure the exception is shown in the output log
482+ raise AssertionError (f"""\
483+ STDOUT:
484+ { e .stderr }
485+
486+ STDERR:
487+ { e .stderr }
488+
489+ TEST:
490+ { test }
491+
492+ ENV:
493+ { env } """ )
0 commit comments