Skip to content

Commit bcb25c1

Browse files
committed
TST: test kron and expand_dims
1 parent e9d38b1 commit bcb25c1

File tree

3 files changed

+125
-6
lines changed

3 files changed

+125
-6
lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from ._funcs import atleast_nd
3+
from ._funcs import atleast_nd, expand_dims, kron
44

55
__version__ = "0.1.2.dev0"
66

7-
__all__ = ["__version__", "atleast_nd"]
7+
__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"]

src/array_api_extra/_funcs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
if TYPE_CHECKING:
66
from ._typing import Array, ModuleType
77

8-
__all__ = ["atleast_nd"]
8+
__all__ = ["atleast_nd", "expand_dims", "kron"]
99

1010

1111
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
@@ -65,6 +65,7 @@ def expand_dims(
6565
a : array
6666
axis : int or tuple of ints
6767
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68+
If multiple positions are provided, they should be unique.
6869
Default: ``(0,)``.
6970
xp : array_namespace
7071
The standard-compatible namespace for `a`.
@@ -118,8 +119,11 @@ def expand_dims(
118119
"""
119120
if not isinstance(axis, tuple):
120121
axis = (axis,)
122+
if len(set(axis)) != len(axis):
123+
err_msg = "Duplicate dimensions specified in `axis`."
124+
raise ValueError(err_msg)
121125
for i in axis:
122-
a = xp.expand_dims(a, axis=i, xp=xp)
126+
a = xp.expand_dims(a, axis=i)
123127
return a
124128

125129

tests/test_funcs.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from __future__ import annotations
22

3+
import contextlib
4+
from typing import TYPE_CHECKING, Any
5+
36
# array-api-strict#6
47
import array_api_strict as xp # type: ignore[import-untyped]
5-
from numpy.testing import assert_array_equal
8+
import pytest
9+
from numpy.testing import assert_array_equal, assert_equal
10+
11+
from array_api_extra import atleast_nd, expand_dims, kron
612

7-
from array_api_extra import atleast_nd
13+
if TYPE_CHECKING:
14+
Array = Any # To be changed to a Protocol later (see array-api#589)
815

916

1017
class TestAtLeastND:
@@ -67,3 +74,111 @@ def test_5D(self):
6774

6875
y = atleast_nd(x, ndim=9, xp=xp)
6976
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
77+
78+
79+
class TestKron:
80+
def test_basic(self):
81+
# Using 0-dimensional array
82+
a = xp.asarray(1)
83+
b = xp.asarray([[1, 2], [3, 4]])
84+
k = xp.asarray([[1, 2], [3, 4]])
85+
assert_array_equal(kron(a, b, xp=xp), k)
86+
a = xp.asarray([[1, 2], [3, 4]])
87+
b = xp.asarray(1)
88+
assert_array_equal(kron(a, b, xp=xp), k)
89+
90+
# Using 1-dimensional array
91+
a = xp.asarray([3])
92+
b = xp.asarray([[1, 2], [3, 4]])
93+
k = xp.asarray([[3, 6], [9, 12]])
94+
assert_array_equal(kron(a, b, xp=xp), k)
95+
a = xp.asarray([[1, 2], [3, 4]])
96+
b = xp.asarray([3])
97+
assert_array_equal(kron(a, b, xp=xp), k)
98+
99+
# Using 3-dimensional array
100+
a = xp.asarray([[[1]], [[2]]])
101+
b = xp.asarray([[1, 2], [3, 4]])
102+
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
103+
assert_array_equal(kron(a, b, xp=xp), k)
104+
a = xp.asarray([[1, 2], [3, 4]])
105+
b = xp.asarray([[[1]], [[2]]])
106+
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
107+
assert_array_equal(kron(a, b, xp=xp), k)
108+
109+
def test_kron_smoke(self):
110+
a = xp.ones([3, 3])
111+
b = xp.ones([3, 3])
112+
k = xp.ones([9, 9])
113+
114+
assert_array_equal(kron(a, b, xp=xp), k)
115+
116+
@pytest.mark.parametrize(
117+
("shape_a", "shape_b"),
118+
[
119+
((1, 1), (1, 1)),
120+
((1, 2, 3), (4, 5, 6)),
121+
((2, 2), (2, 2, 2)),
122+
((1, 0), (1, 1)),
123+
((2, 0, 2), (2, 2)),
124+
((2, 0, 0, 2), (2, 0, 2)),
125+
],
126+
)
127+
def test_kron_shape(self, shape_a, shape_b):
128+
a = xp.ones(shape_a)
129+
b = xp.ones(shape_b)
130+
normalised_shape_a = xp.asarray(
131+
(1,) * max(0, len(shape_b) - len(shape_a)) + shape_a
132+
)
133+
normalised_shape_b = xp.asarray(
134+
(1,) * max(0, len(shape_a) - len(shape_b)) + shape_b
135+
)
136+
expected_shape = tuple(
137+
int(dim) for dim in xp.multiply(normalised_shape_a, normalised_shape_b)
138+
)
139+
140+
k = kron(a, b, xp=xp)
141+
assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")
142+
143+
144+
class TestExpandDims:
145+
def test_functionality(self):
146+
def _squeeze_all(b: Array) -> Array:
147+
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
148+
for axis in range(b.ndim):
149+
with contextlib.suppress(ValueError):
150+
b = xp.squeeze(b, axis=axis)
151+
return b
152+
153+
s = (2, 3, 4, 5)
154+
a = xp.empty(s)
155+
for axis in range(-5, 4):
156+
b = expand_dims(a, axis=axis, xp=xp)
157+
assert b.shape[axis] == 1
158+
assert _squeeze_all(b).shape == s
159+
160+
def test_axis_tuple(self):
161+
a = xp.empty((3, 3, 3))
162+
assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
163+
assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
164+
assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
165+
assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)
166+
167+
def test_axis_out_of_range(self):
168+
s = (2, 3, 4, 5)
169+
a = xp.empty(s)
170+
with pytest.raises(IndexError, match="out of bounds"):
171+
expand_dims(a, axis=-6, xp=xp)
172+
with pytest.raises(IndexError, match="out of bounds"):
173+
expand_dims(a, axis=5, xp=xp)
174+
175+
a = xp.empty((3, 3, 3))
176+
with pytest.raises(IndexError, match="out of bounds"):
177+
expand_dims(a, axis=(0, -6), xp=xp)
178+
with pytest.raises(IndexError, match="out of bounds"):
179+
expand_dims(a, axis=(0, 5), xp=xp)
180+
181+
def test_repeated_axis(self):
182+
a = xp.empty((3, 3, 3))
183+
with pytest.raises(ValueError, match="Duplicate dimensions"):
184+
expand_dims(a, axis=(1, 1), xp=xp)

0 commit comments

Comments
 (0)