@@ -72,7 +72,7 @@ import logging
7272
7373
# External declaration of the helper that schedules an asynchronous
# reference-count decrement for an array of Python objects.
# The removed ('-') form returned an int; the added ('+') form returns an
# event reference and takes a trailing `int *` out-parameter. Callers visible
# in this diff pass the address of a local status variable there and treat a
# non-zero value as a failed submission.
# NOTE(review): the implementation lives in the header and is not visible
# here; presumably it submits a SYCL host_task gated on the given events.
7474cdef extern from " _host_task_util.hpp" :
75- int async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t) nogil
75+ DPCTLSyclEventRef async_dec_ref(DPCTLSyclQueueRef, PyObject ** , size_t, DPCTLSyclEventRef * , size_t, int * ) nogil
7676
7777
7878__all__ = [
@@ -703,6 +703,79 @@ cdef class SyclQueue(_SyclQueue):
703703 """
704704 return < size_t> self ._queue_ref
705705
706+
cpdef SyclEvent _submit_keep_args_alive(
    self,
    object args,
    list dEvents
):
    """ SyclQueue._submit_keep_args_alive(args, dEvents)

    Keeps `args` alive until tasks associated with the events in
    `dEvents` complete.

    Args:
        args(object): Python object to keep alive.
            Typically a tuple with arguments to offloaded tasks
        dEvents(List[dpctl.SyclEvent]): Gating events
            The list of events associated with tasks working on
            Python objects collected in `args`. ``None`` or an empty
            list means there are no gating events.
    Returns:
        dpctl.SyclEvent
           The event associated with the submission of host task.
    Raises:
        MemoryError: the array of dependent events could not be
            allocated.
        TypeError: an element of `dEvents` is not a
            :class:`dpctl.SyclEvent`.
        RuntimeError: the host task could not be submitted. The
            reference taken on `args` is released before raising.

    Increments reference count of `args` and schedules asynchronous
    ``host_task`` to decrement the count once dependent events are
    complete.

    N.B.: The `host_task` attempts to acquire Python GIL, and it is
    known to be unsafe during interpreter shutdown sequence. It is
    thus strongly advised to ensure that all submitted `host_task`
    complete before the end of the Python script.
    """
    # A `list`-typed argument still admits None; treat it as "no events",
    # the same way submit() does for its own dEvents argument.
    cdef size_t nDE = len(dEvents) if dEvents is not None else 0
    cdef DPCTLSyclEventRef *depEvents = NULL
    cdef PyObject *args_raw = NULL
    cdef DPCTLSyclEventRef htERef = NULL
    cdef int status = -1

    # Create the array of dependent events if any
    if nDE > 0:
        depEvents = (
            <DPCTLSyclEventRef*>malloc(nDE * sizeof(DPCTLSyclEventRef))
        )
        if not depEvents:
            raise MemoryError()

    try:
        if nDE > 0:
            for idx, de in enumerate(dEvents):
                if not isinstance(de, SyclEvent):
                    raise TypeError(
                        "A sequence of dpctl.SyclEvent is expected"
                    )
                depEvents[idx] = (<SyclEvent>de).get_event_ref()

        # Take the reference only after validation succeeded, so that the
        # error paths above have nothing to undo.
        Py_INCREF(args)

        # schedule decrement
        args_raw = <PyObject *>args
        htERef = async_dec_ref(
            self.get_queue_ref(),
            &args_raw, 1,
            depEvents, nDE, &status
        )
    finally:
        # The event array is only needed for the duration of the call;
        # free(NULL) is a no-op when there were no gating events.
        free(depEvents)

    if status != 0:
        # Submission failed: mirror the failure path of submit() and drop
        # the reference taken above, otherwise `args` would be leaked.
        Py_DECREF(args)
        if htERef is not NULL:
            with nogil:
                DPCTLEvent_Wait(htERef)
            # The event handle is not handed to a SyclEvent on this path,
            # so it must be released here.
            DPCTLEvent_Delete(htERef)
        raise RuntimeError("Could not submit keep_args_alive host_task")

    return SyclEvent._create(htERef)
777+
778+
706779 cpdef SyclEvent submit(
707780 self ,
708781 SyclKernel kernel,
@@ -715,13 +788,14 @@ cdef class SyclQueue(_SyclQueue):
715788 cdef _arg_data_type * kargty = NULL
716789 cdef DPCTLSyclEventRef * depEvents = NULL
717790 cdef DPCTLSyclEventRef Eref = NULL
791+ cdef DPCTLSyclEventRef htEref = NULL
718792 cdef int ret = 0
719793 cdef size_t gRange[3 ]
720794 cdef size_t lRange[3 ]
721795 cdef size_t nGS = len (gS)
722796 cdef size_t nLS = len (lS) if lS is not None else 0
723797 cdef size_t nDE = len (dEvents) if dEvents is not None else 0
724- cdef PyObject ** arg_objects = NULL
798+ cdef PyObject * args_raw = NULL
725799 cdef ssize_t i = 0
726800
727801 # Allocate the arrays to be sent to DPCTLQueue_Submit
@@ -745,7 +819,15 @@ cdef class SyclQueue(_SyclQueue):
745819 raise MemoryError ()
746820 else :
747821 for idx, de in enumerate (dEvents):
748- depEvents[idx] = (< SyclEvent> de).get_event_ref()
822+ if isinstance (de, SyclEvent):
823+ depEvents[idx] = (< SyclEvent> de).get_event_ref()
824+ else :
825+ free(kargs)
826+ free(kargty)
827+ free(depEvents)
828+ raise TypeError (
829+ " A sequence of dpctl.SyclEvent is expected"
830+ )
749831
750832 # populate the args and argstype arrays
751833 ret = self ._populate_args(args, kargs, kargty)
@@ -823,22 +905,23 @@ cdef class SyclQueue(_SyclQueue):
823905 raise SyclKernelSubmitError(
824906 " Kernel submission to Sycl queue failed."
825907 )
826- # increment reference counts to each argument
827- arg_objects = < PyObject ** > malloc(len (args) * sizeof(PyObject * ))
828- for i in range (len (args)):
829- arg_objects[i] = < PyObject * > (args[i])
830- Py_INCREF(< object > arg_objects[i])
908+ # increment reference counts to list of arguments
909+ Py_INCREF(args)
831910
832911 # schedule decrement
833- if async_dec_ref(self .get_queue_ref(), arg_objects, len (args), & Eref, 1 ):
912+ args_raw = < PyObject * > args
913+
914+ ret = - 1
915+ htERef = async_dec_ref(self .get_queue_ref(), & args_raw, 1 , & Eref, 1 , & ret)
916+ if ret:
834917 # async task submission failed, decrement ref counts and wait
835- for i in range ( len ( args)):
836- arg_objects[i] = < PyObject * > (args[i])
837- Py_DECREF( < object > arg_objects[i] )
838- with nogil: DPCTLEvent_Wait(Eref )
918+ Py_DECREF( args)
919+ with nogil:
920+ DPCTLEvent_Wait(Eref )
921+ DPCTLEvent_Wait(htERef )
839922
840- # free memory
841- free(arg_objects )
923+ # we are not returning host-task event at the moment
924+ DPCTLEvent_Delete(htERef )
842925
843926 return SyclEvent._create(Eref)
844927
0 commit comments