@@ -46,3 +46,176 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4646 x = xp .expand_dims (x , axis = 0 )
4747 x = atleast_nd (x , ndim = ndim , xp = xp )
4848 return x
49+
50+
51+ def expand_dims (a : Array , * , axis : tuple [int ] = (0 ,), xp : ModuleType ):
52+ """
53+ Expand the shape of an array.
54+
55+ Insert a new axis that will appear at the `axis` position in the expanded
56+ array shape.
57+
58+ This is ``xp.expand_dims`` for ``axis`` an int *or a tuple of ints*.
59+ Equivalent to ``numpy.expand_dims`` for NumPy arrays.
60+
61+ Parameters
62+ ----------
63+ a : array
64+ axis : int or tuple of ints
65+ Position(s) in the expanded axes where the new axis (or axes) is/are placed.
66+ xp : array_namespace
67+ The standard-compatible namespace for `a`.
68+
69+ Returns
70+ -------
71+ res : array
72+ `a` with an expanded shape.
73+
74+ Examples
75+ --------
76+ # >>> import numpy as np
77+ # >>> x = np.array([1, 2])
78+ # >>> x.shape
79+ # (2,)
80+
81+ # The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
82+
83+ # >>> y = np.expand_dims(x, axis=0)
84+ # >>> y
85+ # array([[1, 2]])
86+ # >>> y.shape
87+ # (1, 2)
88+
89+ # The following is equivalent to ``x[:, np.newaxis]``:
90+
91+ # >>> y = np.expand_dims(x, axis=1)
92+ # >>> y
93+ # array([[1],
94+ # [2]])
95+ # >>> y.shape
96+ # (2, 1)
97+
98+ # ``axis`` may also be a tuple:
99+
100+ # >>> y = np.expand_dims(x, axis=(0, 1))
101+ # >>> y
102+ # array([[[1, 2]]])
103+
104+ # >>> y = np.expand_dims(x, axis=(2, 0))
105+ # >>> y
106+ # array([[[1],
107+ # [2]]])
108+
109+ # Note that some examples may use ``None`` instead of ``np.newaxis``. These
110+ # are the same objects:
111+
112+ # >>> np.newaxis is None
113+ # True
114+
115+ """
116+ if not isinstance (axis , tuple ):
117+ axis = (axis ,)
118+ for i in axis :
119+ a = xp .expand_dims (a , axis = i , xp = xp )
120+ return a
121+
122+
123+ def kron (a : Array , b : Array , * , xp : ModuleType ):
124+ """
125+ Kronecker product of two arrays.
126+
127+ Computes the Kronecker product, a composite array made of blocks of the
128+ second array scaled by the first.
129+
130+ Equivalent to ``numpy.kron`` for NumPy arrays.
131+
132+ Parameters
133+ ----------
134+ a, b : array
135+ xp : array_namespace
136+ The standard-compatible namespace for `a` and `b`.
137+
138+ Returns
139+ -------
140+ res : array
141+
142+ Notes
143+ -----
144+ The function assumes that the number of dimensions of `a` and `b`
145+ are the same, if necessary prepending the smallest with ones.
146+ If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
147+ the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
148+ The elements are products of elements from `a` and `b`, organized
149+ explicitly by::
150+
151+ kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
152+
153+ where::
154+
155+ kt = it * st + jt, t = 0,...,N
156+
157+ In the common 2-D case (N=1), the block structure can be visualized::
158+
159+ [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
160+ [ ... ... ],
161+ [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
162+
163+
164+ Examples
165+ --------
166+ # >>> import numpy as np
167+ # >>> np.kron([1,10,100], [5,6,7])
168+ # array([ 5, 6, 7, ..., 500, 600, 700])
169+ # >>> np.kron([5,6,7], [1,10,100])
170+ # array([ 5, 50, 500, ..., 7, 70, 700])
171+
172+ # >>> np.kron(np.eye(2), np.ones((2,2)))
173+ # array([[1., 1., 0., 0.],
174+ # [1., 1., 0., 0.],
175+ # [0., 0., 1., 1.],
176+ # [0., 0., 1., 1.]])
177+
178+ # >>> a = np.arange(100).reshape((2,5,2,5))
179+ # >>> b = np.arange(24).reshape((2,3,4))
180+ # >>> c = np.kron(a,b)
181+ # >>> c.shape
182+ # (2, 10, 6, 20)
183+ # >>> I = (1,3,0,2)
184+ # >>> J = (0,2,1)
185+ # >>> J1 = (0,) + J # extend to ndim=4
186+ # >>> S1 = (1,) + b.shape
187+ # >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
188+ # >>> c[K] == a[I]*b[J]
189+ # True
190+
191+ """
192+
193+ b = xp .asarray (b )
194+ singletons = (1 ,) * (b .ndim - a .ndim )
195+ a = xp .broadcast_to (xp .asarray (a ), singletons + a .shape )
196+
197+ nd_b , nd_a = b .ndim , a .ndim
198+ nd_max = max (nd_b , nd_a )
199+ if nd_a == 0 or nd_b == 0 :
200+ return xp .multiply (a , b )
201+
202+ a_shape = a .shape
203+ b_shape = b .shape
204+
205+ # Equalise the shapes by prepending smaller one with 1s
206+ a_shape = (1 ,) * max (0 , nd_b - nd_a ) + a_shape
207+ b_shape = (1 ,) * max (0 , nd_a - nd_b ) + b_shape
208+
209+ # Insert empty dimensions
210+ a_arr = expand_dims (a , axis = tuple (range (nd_b - nd_a )), xp = xp )
211+ b_arr = expand_dims (b , axis = tuple (range (nd_a - nd_b )), xp = xp )
212+
213+ # Compute the product
214+ a_arr = expand_dims (a_arr , axis = tuple (range (1 , nd_max * 2 , 2 )), xp = xp )
215+ b_arr = expand_dims (b_arr , axis = tuple (range (0 , nd_max * 2 , 2 )), xp = xp )
216+ result = xp .multiply (a_arr , b_arr )
217+
218+ # Reshape back and return
219+ a_shape = xp .asarray (a_shape )
220+ b_shape = xp .asarray (b_shape )
221+ return xp .reshape (result , tuple (xp .multiply (a_shape , b_shape )))
0 commit comments