@@ -58,14 +58,14 @@ def expand_dims(
5858 `axis` in the expanded array shape.
5959
6060 This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
61- Equivalent to ``numpy.expand_dims`` for NumPy arrays.
61+ Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
6262
6363 Parameters
6464 ----------
6565 a : array
6666 axis : int or tuple of ints
6767 Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68- If multiple positions are provided, they should be unique.
68+ If multiple positions are provided, they should be unique and increasing .
6969 Default: ``(0,)``.
7070 xp : array_namespace
7171 The standard-compatible namespace for `a`.
@@ -77,52 +77,54 @@ def expand_dims(
7777
7878 Examples
7979 --------
80- # >>> import numpy as np
81- # >>> x = np.array([1, 2])
82- # >>> x.shape
83- # (2,)
84-
85- # The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
86-
87- # >>> y = np.expand_dims(x, axis=0)
88- # >>> y
89- # array([[1, 2]])
90- # >>> y.shape
91- # (1, 2)
80+ >>> import array_api_strict as xp
81+ >>> import array_api_extra as xpx
82+ >>> x = xp.asarray([1, 2])
83+ >>> x.shape
84+ (2,)
9285
93- # The following is equivalent to ``x[:, np .newaxis]``:
86+ The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp .newaxis]``:
9487
95- # >>> y = np.expand_dims(x, axis=1)
96- # >>> y
97- # array([[1],
98- # [2]])
99- # >>> y.shape
100- # (2, 1)
88+ >>> y = xpx.expand_dims(x, axis=0, xp=xp)
89+ >>> y
90+ Array([[1, 2]], dtype=array_api_strict.int64)
91+ >>> y.shape
92+ (1, 2)
10193
102- # ``axis`` may also be a tuple :
94+ The following is equivalent to ``x[:, xp.newaxis]`` :
10395
104- # >>> y = np.expand_dims(x, axis=(0, 1))
105- # >>> y
106- # array([[[1, 2]]])
96+ >>> y = xpx.expand_dims(x, axis=1, xp=xp)
97+ >>> y
98+ Array([[1],
99+ [2]], dtype=array_api_strict.int64)
100+ >>> y.shape
101+ (2, 1)
107102
108- # >>> y = np.expand_dims(x, axis=(2, 0))
109- # >>> y
110- # array([[[1],
111- # [2]]])
103+ ``axis`` may also be a tuple:
112104
113- # Note that some examples may use ``None`` instead of ``np.newaxis``. These
114- # are the same objects:
105+ >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
106+ >>> y
107+ Array([[[1, 2]]], dtype=array_api_strict.int64)
115108
116- # >>> np.newaxis is None
117- # True
109+ >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
110+ >>> y
111+ Array([[[1],
112+ [2]]], dtype=array_api_strict.int64)
118113
119114 """
120115 if not isinstance (axis , tuple ):
121116 axis = (axis ,)
122117 if len (set (axis )) != len (axis ):
123118 err_msg = "Duplicate dimensions specified in `axis`."
124119 raise ValueError (err_msg )
125- for i in axis :
120+ ndim = a .ndim + len (axis )
121+ if axis != () and (min (axis ) < - ndim or max (axis ) >= ndim ):
122+ err_msg = (
123+ f"a provided axis position is out of bounds for array of dimension { a .ndim } "
124+ )
125+ raise IndexError (err_msg )
126+ axis = tuple (dim % ndim for dim in axis )
127+ for i in sorted (axis ):
126128 a = xp .expand_dims (a , axis = i )
127129 return a
128130
@@ -145,6 +147,7 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
145147 Returns
146148 -------
147149 res : array
150+ The Kronecker product of `a` and `b`.
148151
149152 Notes
150153 -----
@@ -170,30 +173,35 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
170173
171174 Examples
172175 --------
173- # >>> import numpy as np
174- # >>> np.kron([1,10,100], [5,6,7])
175- # array([ 5, 6, 7, ..., 500, 600, 700])
176- # >>> np.kron([5,6,7], [1,10,100])
177- # array([ 5, 50, 500, ..., 7, 70, 700])
178-
179- # >>> np.kron(np.eye(2), np.ones((2,2)))
180- # array([[1., 1., 0., 0.],
181- # [1., 1., 0., 0.],
182- # [0., 0., 1., 1.],
183- # [0., 0., 1., 1.]])
184-
185- # >>> a = np.arange(100).reshape((2,5,2,5))
186- # >>> b = np.arange(24).reshape((2,3,4))
187- # >>> c = np.kron(a,b)
188- # >>> c.shape
189- # (2, 10, 6, 20)
190- # >>> I = (1,3,0,2)
191- # >>> J = (0,2,1)
192- # >>> J1 = (0,) + J # extend to ndim=4
193- # >>> S1 = (1,) + b.shape
194- # >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
195- # >>> c[K] == a[I]*b[J]
196- # True
176+ >>> import array_api_strict as xp
177+ >>> import array_api_extra as xpx
178+ >>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
179+ Array([ 5, 6, 7, 50, 60, 70, 500,
180+ 600, 700], dtype=array_api_strict.int64)
181+
182+ >>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
183+ Array([ 5, 50, 500, 6, 60, 600, 7,
184+ 70, 700], dtype=array_api_strict.int64)
185+
186+ >>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
187+ Array([[1., 1., 0., 0.],
188+ [1., 1., 0., 0.],
189+ [0., 0., 1., 1.],
190+ [0., 0., 1., 1.]], dtype=array_api_strict.float64)
191+
192+
193+ >>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
194+ >>> b = xp.reshape(xp.arange(24), (2, 3, 4))
195+ >>> c = xpx.kron(a, b, xp=xp)
196+ >>> c.shape
197+ (2, 10, 6, 20)
198+ >>> I = (1, 3, 0, 2)
199+ >>> J = (0, 2, 1)
200+ >>> J1 = (0,) + J # extend to ndim=4
201+ >>> S1 = (1,) + b.shape
202+ >>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
203+ >>> c[K] == a[I]*b[J]
204+ Array(True, dtype=array_api_strict.bool)
197205
198206 """
199207
0 commit comments