@@ -51,6 +51,8 @@ def __repr__(self):
5151
5252CPU_DEVICE = _cpu_device ()
5353
54+ _default = object ()
55+
5456class Array :
5557 """
5658 n-d array object for the array API namespace.
@@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
525527 res = self ._array .__complex__ ()
526528 return res
527529
528- def __dlpack__ (self : Array , / , * , stream : None = None ) -> PyCapsule :
530+ def __dlpack__ (
531+ self : Array ,
532+ / ,
533+ * ,
534+ stream : Optional [Union [int , Any ]] = None ,
535+ max_version : Optional [tuple [int , int ]] = _default ,
536+ dl_device : Optional [tuple [IntEnum , int ]] = _default ,
537+ copy : Optional [bool ] = _default ,
538+ ) -> PyCapsule :
529539 """
530540 Performs the operation __dlpack__.
531541 """
542+ if get_array_api_strict_flags ()['api_version' ] < '2023.12' :
543+ if max_version is not _default :
544+ raise ValueError ("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API" )
545+ if dl_device is not _default :
546+ raise ValueError ("The device argument to __dlpack__ requires at least version 2023.12 of the array API" )
547+ if copy is not _default :
548+ raise ValueError ("The copy argument to __dlpack__ requires at least version 2023.12 of the array API" )
549+
550+ # Going to wait for upstream numpy support
551+ if max_version not in [_default , None ]:
552+ raise NotImplementedError ("The max_version argument to __dlpack__ is not yet implemented" )
553+ if dl_device not in [_default , None ]:
554+ raise NotImplementedError ("The device argument to __dlpack__ is not yet implemented" )
555+ if copy not in [_default , None ]:
556+ raise NotImplementedError ("The copy argument to __dlpack__ is not yet implemented" )
557+
532558 return self ._array .__dlpack__ (stream = stream )
533559
534560 def __dlpack_device__ (self : Array , / ) -> Tuple [IntEnum , int ]:
0 commit comments