@@ -45,6 +45,7 @@ from ._backend cimport ( # noqa: E211
4545 DPCTLQueue_IsInOrder,
4646 DPCTLQueue_MemAdvise,
4747 DPCTLQueue_Memcpy,
48+ DPCTLQueue_MemcpyWithEvents,
4849 DPCTLQueue_Prefetch,
4950 DPCTLQueue_SubmitBarrierForEvents,
5051 DPCTLQueue_SubmitNDRange,
@@ -64,6 +65,7 @@ import ctypes
6465from .enum_types import backend_type
6566
6667from cpython cimport pycapsule
68+ from cpython.buffer cimport PyObject_CheckBuffer
6769from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
6870from libc.stdlib cimport free, malloc
6971
@@ -160,6 +162,62 @@ cdef void _queue_capsule_deleter(object o) noexcept:
160162 DPCTLQueue_Delete(QRef)
161163
162164
165+ cdef bint _is_buffer(object o):
166+ return PyObject_CheckBuffer(o)
167+
168+
169+ cdef DPCTLSyclEventRef _memcpy_impl(
170+ SyclQueue q,
171+ object dst,
172+ object src,
173+ size_t byte_count,
174+ DPCTLSyclEventRef * dep_events,
175+ size_t dep_events_count
176+ ):
177+ cdef void * c_dst_ptr = NULL
178+ cdef void * c_src_ptr = NULL
179+ cdef DPCTLSyclEventRef ERef = NULL
180+ cdef const unsigned char [::1 ] src_host_buf = None
181+ cdef unsigned char [::1 ] dst_host_buf = None
182+
183+ if isinstance (src, _Memory):
184+ c_src_ptr = < void * > (< _Memory> src).memory_ptr
185+ elif _is_buffer(src):
186+ src_host_buf = src
187+ c_src_ptr = < void * > & src_host_buf[0 ]
188+ else :
189+ raise TypeError (
190+ " Parameter `src` should have either type "
191+ " `dpctl.memory._Memory` or a type that "
192+ " supports Python buffer protocol"
193+ )
194+
195+ if isinstance (dst, _Memory):
196+ c_dst_ptr = < void * > (< _Memory> dst).memory_ptr
197+ elif _is_buffer(dst):
198+ dst_host_buf = dst
199+ c_dst_ptr = < void * > & dst_host_buf[0 ]
200+ else :
201+ raise TypeError (
202+ " Parameter `dst` should have either type "
203+ " `dpctl.memory._Memory` or a type that "
204+ " supports Python buffer protocol"
205+ )
206+
207+ if dep_events_count == 0 or dep_events is NULL :
208+ ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
209+ else :
210+ ERef = DPCTLQueue_MemcpyWithEvents(
211+ q._queue_ref,
212+ c_dst_ptr,
213+ c_src_ptr,
214+ byte_count,
215+ dep_events,
216+ dep_events_count
217+ )
218+ return ERef
219+
220+
163221cdef class _SyclQueue:
164222 """ Barebone data owner class used by SyclQueue.
165223 """
@@ -925,44 +983,44 @@ cdef class SyclQueue(_SyclQueue):
925983 with nogil: DPCTLQueue_Wait(self ._queue_ref)
926984
927985 cpdef memcpy(self , dest, src, size_t count):
928- cdef void * c_dest
929- cdef void * c_src
986+ """ Copy memory from `src` to `dst`"""
930987 cdef DPCTLSyclEventRef ERef = NULL
931988
932- if isinstance (dest, _Memory):
933- c_dest = < void * > (< _Memory> dest).memory_ptr
934- else :
935- raise TypeError (" Parameter `dest` should have type _Memory." )
936-
937- if isinstance (src, _Memory):
938- c_src = < void * > (< _Memory> src).memory_ptr
939- else :
940- raise TypeError (" Parameter `src` should have type _Memory." )
941-
942- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
989+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
943990 if (ERef is NULL ):
944991 raise RuntimeError (
945992 " SyclQueue.memcpy operation encountered an error"
946993 )
947994 with nogil: DPCTLEvent_Wait(ERef)
948995 DPCTLEvent_Delete(ERef)
949996
950- cpdef SyclEvent memcpy_async(self , dest, src, size_t count):
951- cdef void * c_dest
952- cdef void * c_src
997+ cpdef SyclEvent memcpy_async(self , dest, src, size_t count, list dEvents = None ):
998+ """ Copy memory from `src` to `dst`"""
953999 cdef DPCTLSyclEventRef ERef = NULL
1000+ cdef DPCTLSyclEventRef * depEvents = NULL
1001+ cdef size_t nDE = 0
9541002
955- if isinstance (dest, _Memory):
956- c_dest = < void * > (< _Memory> dest).memory_ptr
957- else :
958- raise TypeError (" Parameter `dest` should have type _Memory." )
959-
960- if isinstance (src, _Memory):
961- c_src = < void * > (< _Memory> src).memory_ptr
1003+ if dEvents is None :
1004+ ERef = _memcpy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
9621005 else :
963- raise TypeError (" Parameter `src` should have type _Memory." )
1006+ nDE = len (dEvents)
1007+ depEvents = (
1008+ < DPCTLSyclEventRef* > malloc(nDE* sizeof(DPCTLSyclEventRef))
1009+ )
1010+ if depEvents is NULL :
1011+ raise MemoryError ()
1012+ else :
1013+ for idx, de in enumerate (dEvents):
1014+ if isinstance (de, SyclEvent):
1015+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
1016+ else :
1017+ free(depEvents)
1018+ raise TypeError (
1019+ " A sequence of dpctl.SyclEvent is expected"
1020+ )
1021+ ERef = _memcpy_impl(self , dest, src, count, depEvents, nDE)
1022+ free(depEvents)
9641023
965- ERef = DPCTLQueue_Memcpy(self ._queue_ref, c_dest, c_src, count)
9661024 if (ERef is NULL ):
9671025 raise RuntimeError (
9681026 " SyclQueue.memcpy operation encountered an error"
0 commit comments