|
3 | 3 | from numpy.testing import assert_raises |
4 | 4 | import numpy as np |
5 | 5 |
|
| 6 | +import pytest |
| 7 | + |
6 | 8 | from .. import all |
7 | 9 | from .._creation_functions import ( |
8 | 10 | asarray, |
9 | 11 | arange, |
10 | 12 | empty, |
11 | 13 | empty_like, |
12 | 14 | eye, |
| 15 | + from_dlpack, |
13 | 16 | full, |
14 | 17 | full_like, |
15 | 18 | linspace, |
|
21 | 24 | ) |
22 | 25 | from .._dtypes import float32, float64 |
23 | 26 | from .._array_object import Array, CPU_DEVICE |
24 | | - |
| 27 | +from .._flags import set_array_api_strict_flags |
25 | 28 |
|
26 | 29 | def test_asarray_errors(): |
27 | 30 | # Test various protections against incorrect usage |
@@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors(): |
188 | 191 | meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32)) |
189 | 192 |
|
190 | 193 | assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) |
| 194 | + |
| 195 | + |
| 196 | +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) |
| 197 | +def from_dlpack_2023_12(api_version): |
| 198 | + if api_version != '2022.12': |
| 199 | + with pytest.warns(UserWarning): |
| 200 | + set_array_api_strict_flags(api_version=api_version) |
| 201 | + else: |
| 202 | + set_array_api_strict_flags(api_version=api_version) |
| 203 | + |
| 204 | + a = asarray([1., 2., 3.], dtype=float64) |
| 205 | + # Never an error |
| 206 | + capsule = a.__dlpack__() |
| 207 | + from_dlpack(capsule) |
| 208 | + |
| 209 | + exception = NotImplementedError if api_version >= '2023.12' else ValueError |
| 210 | + pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE)) |
| 211 | + pytest.raises(exception, lambda: from_dlpack(capsule, device=None)) |
| 212 | + pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) |
| 213 | + pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) |
| 214 | + pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) |
0 commit comments