@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
145145
146146 import ndonnx as ndx
147147
148- return isinstance (x , ndx .Array )
148+ return isinstance (x , ndx .Array )
149149
150150def is_dask_array (x ):
151151 """
@@ -340,12 +340,9 @@ def your_function(x, y):
340340 elif use_compat is False :
341341 namespaces .add (np )
342342 else :
343- # numpy 2.0 has __array_namespace__ and is fully array API
343+ # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
344344 # compatible.
345- if hasattr (np .empty (0 ), '__array_namespace__' ):
346- namespaces .add (np .empty (0 ).__array_namespace__ (api_version = api_version ))
347- else :
348- namespaces .add (numpy_namespace )
345+ namespaces .add (numpy_namespace )
349346 elif is_cupy_array (x ):
350347 if _use_compat :
351348 _check_api_version (api_version )
@@ -377,9 +374,13 @@ def your_function(x, y):
377374 elif use_compat is False :
378375 import jax .numpy as jnp
379376 else :
380- # jax.experimental.array_api is already an array namespace. We do
381- # not have a wrapper submodule for it.
382- import jax .experimental .array_api as jnp
377+ # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
378+ # For older JAX versions, it is available via jax.experimental.array_api.
379+ import jax .numpy
380+ if hasattr (jax .numpy , "__array_api_version__" ):
381+ jnp = jax .numpy
382+ else :
383+ import jax .experimental .array_api as jnp
383384 namespaces .add (jnp )
384385 elif is_pydata_sparse_array (x ):
385386 if use_compat is True :
@@ -613,8 +614,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
613614 return x
614615 raise ValueError (f"Unsupported device { device !r} " )
615616 elif is_jax_array (x ):
616- # This import adds to_device to x
617- import jax .experimental .array_api # noqa: F401
617+ if not hasattr (x , "__array_namespace__" ):
618+ # In JAX v0.4.31 and older, this import adds to_device method to x.
619+ import jax .experimental .array_api # noqa: F401
618620 return x .to_device (device , stream = stream )
619621 elif is_pydata_sparse_array (x ) and device == _device (x ):
620622 # Perform trivial check to return the same array if
0 commit comments