 import numpy as np
 from numba import TypingError, types
 from numba.core import cgutils
+from numba.core.extending import overload
 from numba.np import arrayobj
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

@@ -174,6 +175,7 @@ def create_axis_reducer(
     ndim: int,
     dtype: numba.types.Type,
     keepdims: bool = False,
+    return_scalar: bool = False,
 ) -> numba.core.dispatcher.Dispatcher:
     r"""Create Python function that performs a NumPy-like reduction on a given axis.

@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
         inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)

         return_expr = "res" if keepdims else "res.item()"
+        if not return_scalar:
+            return_expr = f"np.asarray({return_expr})"
         reduce_elemwise_def_src = f"""
 def {reduce_elemwise_fn_name}(x):

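For context on why the return expression gets wrapped: with the new `return_scalar=False` default the generated single-axis reducer hands back a 0-d array rather than a Python/NumPy scalar. A tiny stand-alone illustration of the difference (not the actual generated source):

```python
import numpy as np

# Mimic the generated ndim == 1 reducer body: accumulate into a length-1 buffer
res = np.full(1, -np.inf, dtype=np.float64)
for x in np.array([1.0, 3.0, 2.0]):
    res[0] = max(res[0], x)

as_scalar = res.item()             # "res.item()"             -> return_scalar=True
as_array = np.asarray(res.item())  # "np.asarray(res.item())" -> return_scalar=False (default)

print(type(as_scalar), as_array.ndim)  # <class 'float'> 0
```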
@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):


 def create_multiaxis_reducer(
-    scalar_op, identity, axes, ndim, dtype, input_name="input"
+    scalar_op,
+    identity,
+    axes,
+    ndim,
+    dtype,
+    input_name="input",
+    return_scalar=False,
 ):
     r"""Construct a function that reduces multiple axes.

@@ -336,6 +346,8 @@ def careduce_maximum(input):
         The number of dimensions of the result.
     dtype:
         The data type of the result.
+    return_scalar:
+        If True, return a scalar; otherwise, return an array.

     Returns
     =======
@@ -370,10 +382,17 @@ def careduce_maximum(input):
         )

     careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
+    if not return_scalar:
+        pre_result = "np.asarray"
+        post_result = ""
+    else:
+        pre_result = "np.asarray"
+        post_result = ".item()"
+
     careduce_def_src = f"""
 def {careduce_fn_name}({input_name}):
 {careduce_assign_lines}
-    return np.asarray({var_name})
+    return {pre_result}({var_name}){post_result}
     """

     careduce_fn = compile_function_src(
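Put together, the generated multi-axis reducer now ends in `{pre_result}({var_name}){post_result}`. A rough sketch of what that expands to for a maximum over two axes (the `np.maximum.reduce` calls stand in for the generated per-axis reducers; identifiers are illustrative):

```python
import numpy as np

def careduce_maximum(input):
    # one generated single-axis reduction per reduced axis
    careduce_0 = np.maximum.reduce(input, axis=0)
    careduce_1 = np.maximum.reduce(careduce_0, axis=0)
    # return_scalar=False (the default): keep an array result
    return np.asarray(careduce_1)
    # return_scalar=True would instead generate:
    # return np.asarray(careduce_1).item()

print(careduce_maximum(np.arange(6).reshape(2, 3)))  # 5
```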
@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
     return careduce_fn


-def jit_compile_reducer(node, fn, **kwds):
+def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
     """Compile Python source for reduction loops using additional optimizations.

     Parameters
@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
     A :func:`numba.njit`-compiled function.

     """
-    signature = create_numba_signature(node, reduce_to_scalar=True)
+    signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)

     # Eagerly compile the function using increased optimizations. This should
     # help improve nested loop reductions.
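Threading `reduce_to_scalar` through matters because it decides the return type Numba compiles for: a scalar versus a 0-d array. A stand-alone sketch of that distinction (it leaves out the eager-signature part of `jit_compile_reducer`; function names are illustrative):

```python
import numba
import numpy as np

@numba.njit(fastmath=True)
def sum_reduce_to_scalar(x):
    return x.sum()              # scalar result, as with reduce_to_scalar=True

@numba.njit(fastmath=True)
def sum_reduce_to_array(x):
    return np.asarray(x.sum())  # 0-d array result, as with reduce_to_scalar=False

x = np.arange(3.0)
print(numba.typeof(sum_reduce_to_scalar(x)))  # float64
print(numba.typeof(sum_reduce_to_array(x)))   # array(float64, 0d, ...)
```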
@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     inplace_pattern = tuple(op.inplace_pattern.items())

     # numba doesn't support nested literals right now...
-    input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
-    output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode()
-    output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
-    inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
+    input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
+    output_bc_patterns_enc = base64.encodebytes(
+        pickle.dumps(output_bc_patterns)
+    ).decode()
+    output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
+    inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()

-    @numba_njit
     def elemwise_wrapper(*inputs):
         return _vectorized(
             scalar_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
+            input_bc_patterns_enc,
+            output_bc_patterns_enc,
+            output_dtypes_enc,
+            inplace_pattern_enc,
             inputs,
         )

-    return elemwise_wrapper
+    # Pure Python implementation that will be used in tests
+    def elemwise(*inputs):
+        inputs = [np.asarray(input) for input in inputs]
+        inputs_bc = np.broadcast_arrays(*inputs)
+        shape = inputs_bc[0].shape  # broadcast (output) shape
+        for input, bc in zip(inputs, input_bc_patterns):
+            for length, allow_bc, iter_length in zip(input.shape, bc, shape):
+                if length == 1 and shape and iter_length != 1 and not allow_bc:
+                    raise ValueError("Broadcast not allowed.")
+
+        outputs = []
+        for dtype in output_dtypes:
+            outputs.append(np.empty(shape, dtype=dtype))
+
+        for idx in np.ndindex(shape):
+            vals = [input[idx] for input in inputs_bc]
+            outs = scalar_op_fn(*vals)
+            if not isinstance(outs, tuple):
+                outs = (outs,)
+            for out, out_val in zip(outputs, outs):
+                out[idx] = out_val
+
+        outputs_summed = []
+        for output, bc in zip(outputs, output_bc_patterns):
+            axes = tuple(np.nonzero(bc)[0])
+            outputs_summed.append(output.sum(axes, keepdims=True))
+        if len(outputs_summed) != 1:
+            return tuple(outputs_summed)
+        return outputs_summed[0]
+
+    @overload(elemwise)
+    def ov_elemwise(*inputs):
+        return elemwise_wrapper
+
+    return elemwise


 @numba_funcify.register(Sum)
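The `@overload` registration above is what lets compiled callers pick up the `_vectorized`-based wrapper while tests and other plain-Python callers run the reference implementation. The same pattern in a self-contained toy example (the function and names here are made up for illustration):

```python
import numba
import numpy as np
from numba.core.extending import overload

def add_one(x):
    # Reference implementation: used when called from ordinary Python code
    return np.asarray(x) + 1

@overload(add_one)
def ov_add_one(x):
    # Implementation Numba compiles when add_one is called in nopython code
    def impl(x):
        return x + 1
    return impl

@numba.njit
def caller(x):
    return add_one(x)

print(add_one(np.arange(3)))   # pure-Python path
print(caller(np.arange(3.0)))  # Numba path via the overload
```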
@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
     if axes is None:
         axes = list(range(node.inputs[0].ndim))

-    axes = list(axes)
+    axes = tuple(axes)

     ndim_input = node.inputs[0].ndim

@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):

         @numba_njit(fastmath=True)
         def impl_sum(array):
-            # TODO The accumulation itself should happen in acc_dtype...
-            return np.asarray(array.sum()).astype(np_acc_dtype)
+            return np.asarray(array.sum(), dtype=np_acc_dtype)

-    else:
+    elif len(axes) == 0:

         @numba_njit(fastmath=True)
         def impl_sum(array):
-            # TODO The accumulation itself should happen in acc_dtype...
-            return array.sum(axes).astype(np_acc_dtype)
+            return array
+
+    else:
+        impl_sum = numba_funcify_CAReduce(op, node, **kwargs)

     return impl_sum

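The three branches now are: reduce every axis (accumulate straight into `np_acc_dtype`), reduce no axes (return the input unchanged), and anything in between (fall back to the generic `numba_funcify_CAReduce` loop). In plain NumPy terms, roughly:

```python
import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
np_acc_dtype = np.dtype("float64")

full_reduction = np.asarray(x.sum(), dtype=np_acc_dtype)  # len(axes) == ndim_input
no_reduction = x                                          # len(axes) == 0
partial = x.sum(axis=(0,))                                # everything else -> CAReduce path

print(full_reduction, no_reduction.shape, partial)
```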
@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
         input_name=input_name,
     )

-    careduce_fn = jit_compile_reducer(node, careduce_py_fn)
+    careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
     return careduce_fn


@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
-            scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+            scalar_maximum,
+            -np.inf,
+            axis,
+            x_at.ndim,
+            x_dtype,
+            keepdims=True,
         )
         reduce_sum_py = create_axis_reducer(
             add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
@@ -935,10 +995,17 @@ def maxandargmax(x):
         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)

         reduce_max_py_fn = create_multiaxis_reducer(
-            scalar_maximum, -np.inf, axes, x_ndim, x_dtype
+            scalar_maximum,
+            -np.inf,
+            axes,
+            x_ndim,
+            x_dtype,
+            return_scalar=False,
         )
         reduce_max = jit_compile_reducer(
-            Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn
+            Apply(node.op, node.inputs, [node.outputs[0].clone()]),
+            reduce_max_py_fn,
+            reduce_to_scalar=False,
         )

         reduced_x_ndim = x_ndim - len(axes) + 1