@@ -57,12 +57,7 @@ struct FloorDivideFunctor
 
     resT operator()(const argT1 &in1, const argT2 &in2) const
     {
-        if constexpr (std::is_same_v<argT1, bool> &&
-                      std::is_same_v<argT2, bool>) {
-            return (in2) ? static_cast<resT>(in1) : resT(0);
-        }
-        else if constexpr (std::is_integral_v<argT1> ||
-                           std::is_integral_v<argT2>) {
+        if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
             if (in2 == argT2(0)) {
                 return resT(0);
             }
@@ -87,16 +82,7 @@ struct FloorDivideFunctor
     operator()(const sycl::vec<argT1, vec_sz> &in1,
                const sycl::vec<argT2, vec_sz> &in2) const
     {
-        if constexpr (std::is_same_v<argT1, bool> &&
-                      std::is_same_v<argT2, bool>) {
-            sycl::vec<resT, vec_sz> res;
-#pragma unroll
-            for (int i = 0; i < vec_sz; ++i) {
-                res[i] = (in2[i]) ? static_cast<resT>(in1[i]) : resT(0);
-            }
-            return res;
-        }
-        else if constexpr (std::is_integral_v<resT>) {
+        if constexpr (std::is_integral_v<resT>) {
             sycl::vec<resT, vec_sz> res;
 #pragma unroll
             for (int i = 0; i < vec_sz; ++i) {
@@ -165,7 +151,6 @@ template <typename T1, typename T2> struct FloorDivideOutputType
 {
     using value_type = typename std::disjunction< // disjunction is C++17
         // feature, supported by DPC++
-        td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, std::int8_t>,
         td_ns::BinaryTypeMapResultEntry<T1,
                                         std::uint8_t,
                                         T2,
@@ -315,6 +300,183 @@ struct FloorDivideStridedFactory
     }
 };
 
+template <typename argT, typename resT> struct FloorDivideInplaceFunctor
+{
+    using supports_sg_loadstore = std::true_type;
+    using supports_vec = std::true_type;
+
+    void operator()(resT &in1, const argT &in2) const
+    {
+        if constexpr (std::is_integral_v<resT>) {
+            if (in2 == argT(0)) {
+                in1 = 0;
+                return;
+            }
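+            // C++ '/' truncates toward zero; for signed integral types,
+            // step the quotient down by one when the remainder is nonzero
+            // and the operands have opposite signs, matching floor division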
+            if constexpr (std::is_signed_v<resT>) {
+                auto tmp = in1;
+                in1 /= in2;
+                auto mod = tmp % in2;
+                auto corr = (mod != 0 && l_xor(mod < 0, in2 < 0));
+                in1 -= corr;
+            }
+            else {
+                in1 /= in2;
+            }
+        }
+        else {
+            in1 /= in2;
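+            // a quotient of exactly zero is already floored; return early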
+            if (in1 == resT(0)) {
+                return;
+            }
+            in1 = std::floor(in1);
+        }
+    }
+
+    template <int vec_sz>
+    void operator()(sycl::vec<resT, vec_sz> &in1,
+                    const sycl::vec<argT, vec_sz> &in2) const
+    {
+        if constexpr (std::is_integral_v<resT>) {
+#pragma unroll
+            for (int i = 0; i < vec_sz; ++i) {
+                if (in2[i] == argT(0)) {
+                    in1[i] = 0;
+                }
+                else {
+                    if constexpr (std::is_signed_v<resT>) {
+                        auto tmp = in1[i];
+                        in1[i] /= in2[i];
+                        auto mod = tmp % in2[i];
+                        auto corr = (mod != 0 && l_xor(mod < 0, in2[i] < 0));
+                        in1[i] -= corr;
+                    }
+                    else {
+                        in1[i] /= in2[i];
+                    }
+                }
+            }
+        }
+        else {
+            in1 /= in2;
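+            // divide all lanes first, then floor only those whose divisor
+            // is nonzero; lanes divided by zero keep the IEEE result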
+#pragma unroll
+            for (int i = 0; i < vec_sz; ++i) {
+                if (in2[i] != argT(0)) {
+                    in1[i] = std::floor(in1[i]);
+                }
+            }
+        }
+    }
+
+private:
+    bool l_xor(bool b1, bool b2) const
+    {
+        return (b1 != b2);
+    }
+};
+
+template <typename argT,
+          typename resT,
+          unsigned int vec_sz = 4,
+          unsigned int n_vecs = 2>
+using FloorDivideInplaceContigFunctor =
+    elementwise_common::BinaryInplaceContigFunctor<
+        argT,
+        resT,
+        FloorDivideInplaceFunctor<argT, resT>,
+        vec_sz,
+        n_vecs>;
+
+template <typename argT, typename resT, typename IndexerT>
+using FloorDivideInplaceStridedFunctor =
+    elementwise_common::BinaryInplaceStridedFunctor<
+        argT,
+        resT,
+        IndexerT,
+        FloorDivideInplaceFunctor<argT, resT>>;
+
+template <typename argT,
+          typename resT,
+          unsigned int vec_sz,
+          unsigned int n_vecs>
+class floor_divide_inplace_contig_kernel;
+
+template <typename argTy, typename resTy>
+sycl::event
+floor_divide_inplace_contig_impl(sycl::queue &exec_q,
+                                 size_t nelems,
+                                 const char *arg_p,
+                                 py::ssize_t arg_offset,
+                                 char *res_p,
+                                 py::ssize_t res_offset,
+                                 const std::vector<sycl::event> &depends = {})
+{
+    return elementwise_common::binary_inplace_contig_impl<
+        argTy, resTy, FloorDivideInplaceContigFunctor,
+        floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
+                                            res_p, res_offset, depends);
+}
+
+template <typename fnT, typename T1, typename T2>
+struct FloorDivideInplaceContigFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<
+                          typename FloorDivideOutputType<T1, T2>::value_type,
+                          void>)
+        {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT fn = floor_divide_inplace_contig_impl<T1, T2>;
+            return fn;
+        }
+    }
+};
+
+template <typename resT, typename argT, typename IndexerT>
+class floor_divide_inplace_strided_kernel;
+
+template <typename argTy, typename resTy>
+sycl::event floor_divide_inplace_strided_impl(
+    sycl::queue &exec_q,
+    size_t nelems,
+    int nd,
+    const py::ssize_t *shape_and_strides,
+    const char *arg_p,
+    py::ssize_t arg_offset,
+    char *res_p,
+    py::ssize_t res_offset,
+    const std::vector<sycl::event> &depends,
+    const std::vector<sycl::event> &additional_depends)
+{
+    return elementwise_common::binary_inplace_strided_impl<
+        argTy, resTy, FloorDivideInplaceStridedFunctor,
+        floor_divide_inplace_strided_kernel>(
+        exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
+        res_offset, depends, additional_depends);
+}
+
+template <typename fnT, typename T1, typename T2>
+struct FloorDivideInplaceStridedFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<
+                          typename FloorDivideOutputType<T1, T2>::value_type,
+                          void>)
+        {
+            fnT fn = nullptr;
+            return fn;
+        }
+        else {
+            fnT fn = floor_divide_inplace_strided_impl<T1, T2>;
+            return fn;
+        }
+    }
+};
+
479+
 } // namespace floor_divide
 } // namespace kernels
 } // namespace tensor
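
For reference, the quotient correction added above relies on the fact that C++ integer division truncates toward zero, while floor division rounds toward negative infinity. A minimal standalone sketch of the same logic (the `floor_divide` helper name is hypothetical, and the divide-by-zero convention mirrors the functor above):

```cpp
#include <cassert>

// Hypothetical standalone version of the correction performed by
// FloorDivideInplaceFunctor for signed integral types.
template <typename T> T floor_divide(T a, T b)
{
    if (b == T(0)) {
        return T(0); // the kernel defines division by zero as 0 for integers
    }
    T q = a / b; // truncates toward zero
    T r = a % b;
    // step the quotient down when the remainder is nonzero and the
    // operands have opposite signs (the l_xor condition above)
    if (r != 0 && ((r < 0) != (b < 0))) {
        --q;
    }
    return q;
}

int main()
{
    assert(floor_divide(7, 2) == 3);
    assert(floor_divide(-7, 2) == -4); // plain truncation would give -3
    assert(floor_divide(7, -2) == -4);
    assert(floor_divide(-7, -2) == 3);
    assert(floor_divide(7, 0) == 0);
    return 0;
}
```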
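The two factories are presumably consumed by the extension's dispatch tables in the usual way. A hedged registration sketch for the contiguous case, assuming the surrounding file's includes and namespace aliases, dpctl's `td_ns::DispatchTableBuilder` utility, and a function-pointer alias matching the signature of `floor_divide_inplace_contig_impl` (none of which appear in this diff):

```cpp
// Sketch only: the alias, table name, and populate function are
// assumptions modeled on the other elementwise extensions.
using binary_inplace_contig_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue &,
                    size_t,
                    const char *,
                    py::ssize_t,
                    char *,
                    py::ssize_t,
                    const std::vector<sycl::event> &);

// one entry per (lhs type id, rhs type id) pair; nullptr marks
// unsupported combinations (FloorDivideOutputType yields void)
static binary_inplace_contig_impl_fn_ptr_t
    floor_divide_inplace_contig_dispatch_table[td_ns::num_types]
                                              [td_ns::num_types];

void populate_floor_divide_inplace_dispatch_table(void)
{
    td_ns::DispatchTableBuilder<binary_inplace_contig_impl_fn_ptr_t,
                                FloorDivideInplaceContigFactory,
                                td_ns::num_types>
        dtb;
    dtb.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table);
}
```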