@@ -188,11 +188,28 @@ def test_axis0_bug():
188188 assert dpt .all (s == expected )
189189
190190
191+ def _any_complex (dtypes ):
192+ return any (dpt .isdtype (dpt .dtype (dt ), "complex floating" ) for dt in dtypes )
193+
194+
195+ def _skip_on_this_device (sycl_dev ):
196+ device_mask = du .intel_device_info (sycl_dev ).get ("device_id" , 0 ) & 0xFF00
197+ return device_mask in [0x3E00 , 0x9B00 ]
198+
199+
191200@pytest .mark .parametrize ("arg_dtype" , _all_dtypes [1 :])
192201def test_prod_arg_dtype_default_output_dtype_matrix (arg_dtype ):
193202 q = get_queue_or_skip ()
194203 skip_if_dtype_not_supported (arg_dtype , q )
195204
205+ arg_dtype = dpt .dtype (arg_dtype )
206+ if _any_complex ((arg_dtype ,)):
207+ if _skip_on_this_device (q .sycl_device ):
208+ pytest .skip (
209+ "Product reduction for complex output are known "
210+ "to fail for Gen9 with 2024.0 compiler"
211+ )
212+
196213 m = dpt .ones (100 , dtype = arg_dtype )
197214 r = dpt .prod (m )
198215
@@ -245,13 +262,12 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
245262
246263 out_dtype = dpt .dtype (out_dtype )
247264 arg_dtype = dpt .dtype (arg_dtype )
248- if dpt .isdtype (out_dtype , "complex floating" ) and du ._is_gen9 (
249- q .sycl_device
250- ):
251- pytest .skip (
252- "Product reduction for complex output are known "
253- "to fail for Gen9 with 2024.0 compiler"
254- )
265+ if _any_complex ((arg_dtype , out_dtype )):
266+ if _skip_on_this_device (q .sycl_device ):
267+ pytest .skip (
268+ "Product reduction for complex output are known "
269+ "to fail for Gen9 with 2024.0 compiler"
270+ )
255271
256272 m = dpt .ones (100 , dtype = arg_dtype )
257273 r = dpt .prod (m , dtype = out_dtype )
0 commit comments