@@ -251,26 +251,52 @@ def ihfft(
251251 return res
252252
253253@requires_extension ('fft' )
254- def fftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
254+ def fftfreq (
255+ n : int ,
256+ / ,
257+ * ,
258+ d : float = 1.0 ,
259+ dtype : Optional [dtype ] = None ,
260+ device : Optional [Device ] = None
261+ ) -> Array :
255262 """
256263 Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
257264
258265 See its docstring for more information.
259266 """
260267 if device is not None and device not in ALL_DEVICES :
261268 raise ValueError (f"Unsupported device { device !r} " )
262- return Array ._new (np .fft .fftfreq (n , d = d ), device = device )
269+ if dtype and not dtype in _real_floating_dtypes :
270+ raise ValueError (f"`dtype` must be a real floating-point type. Got { dtype = } ." )
271+
272+ np_result = np .fft .fftfreq (n , d = d )
273+ if dtype :
274+ np_result = np_result .astype (dtype ._np_dtype )
275+ return Array ._new (np_result , device = device )
263276
264277@requires_extension ('fft' )
265- def rfftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
278+ def rfftfreq (
279+ n : int ,
280+ / ,
281+ * ,
282+ d : float = 1.0 ,
283+ dtype : Optional [dtype ] = None ,
284+ device : Optional [Device ] = None
285+ ) -> Array :
266286 """
267287 Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
268288
269289 See its docstring for more information.
270290 """
271291 if device is not None and device not in ALL_DEVICES :
272292 raise ValueError (f"Unsupported device { device !r} " )
273- return Array ._new (np .fft .rfftfreq (n , d = d ), device = device )
293+ if dtype and not dtype in _real_floating_dtypes :
294+ raise ValueError (f"`dtype` must be a real floating-point type. Got { dtype = } ." )
295+
296+ np_result = np .fft .rfftfreq (n , d = d )
297+ if dtype :
298+ np_result = np_result .astype (dtype ._np_dtype )
299+ return Array ._new (np_result , device = device )
274300
275301@requires_extension ('fft' )
276302def fftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
0 commit comments