1- import inspect
21from functools import singledispatch
32from numbers import Number
3+ import pickle
44from textwrap import indent
5- from typing import Any , Callable , Optional , Union
5+ from typing import Any , Callable , Literal , Optional , Union
6+ import base64
67
78import numba
89import numpy as np
10+ from llvmlite import ir
11+ from numba import TypingError , literal_unroll , types , literally
12+ from numba .core import cgutils
13+ from numba .cpython .unsafe .tuple import tuple_setitem
14+ from numba .np import arrayobj
915from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
1016
1117from pytensor import config
1622 create_numba_signature ,
1723 create_tuple_creator ,
1824 numba_funcify ,
25+ numba_njit ,
1926 use_optimized_cheap_pass ,
2027)
21- from pytensor .link .utils import (
22- compile_function_src ,
23- get_name_for_object ,
24- unique_name_generator ,
25- )
28+ from pytensor .link .numba .dispatch .helpers import check_broadcasting , tuple_mapper
29+ from pytensor .link .numba .dispatch import elemwise_codegen
30+ from pytensor .link .utils import compile_function_src , get_name_for_object
2631from pytensor .scalar .basic import (
2732 AND ,
2833 OR ,
@@ -431,6 +436,170 @@ def axis_apply_fn(x):
431436 return axis_apply_fn
432437
433438
439+ _jit_options = {
440+ "fastmath" : {
441+ "arcp" , # Allow Reciprocal
442+ "contract" , # Allow floating-point contraction
443+ "afn" , # Approximate functions
444+ "reassoc" ,
445+ "nsz" , # TODO Do we want this one?
446+ }
447+ }
448+
449+ @numba .extending .intrinsic (jit_options = _jit_options , prefer_literal = True )
450+ def _vectorized (
451+ typingctx ,
452+ scalar_func ,
453+ input_bc_patterns ,
454+ output_bc_patterns ,
455+ output_dtypes ,
456+ inplace_pattern ,
457+ inputs ,
458+ ):
459+ #if not isinstance(scalar_func, types.Literal):
460+ # raise TypingError("scalar func must be literal.")
461+ #scalar_func = scalar_func.literal_value
462+
463+ arg_types = [
464+ scalar_func ,
465+ input_bc_patterns ,
466+ output_bc_patterns ,
467+ output_dtypes ,
468+ inplace_pattern ,
469+ inputs ,
470+ ]
471+
472+ if not isinstance (input_bc_patterns , types .Literal ):
473+ raise TypingError ("input_bc_patterns must be literal." )
474+ input_bc_patterns = input_bc_patterns .literal_value
475+ input_bc_patterns = pickle .loads (base64 .decodebytes (input_bc_patterns .encode ()))
476+
477+ if not isinstance (output_bc_patterns , types .Literal ):
478+ raise TypeError ("output_bc_patterns must be literal." )
479+ output_bc_patterns = output_bc_patterns .literal_value
480+ output_bc_patterns = pickle .loads (base64 .decodebytes (output_bc_patterns .encode ()))
481+
482+ if not isinstance (output_dtypes , types .Literal ):
483+ raise TypeError ("output_dtypes must be literal." )
484+ output_dtypes = output_dtypes .literal_value
485+ output_dtypes = pickle .loads (base64 .decodebytes (output_dtypes .encode ()))
486+
487+ if not isinstance (inplace_pattern , types .Literal ):
488+ raise TypeError ("inplace_pattern must be literal." )
489+ inplace_pattern = inplace_pattern .literal_value
490+ inplace_pattern = pickle .loads (base64 .decodebytes (inplace_pattern .encode ()))
491+
492+ n_inputs = len (inputs )
493+ n_outputs = len (output_bc_patterns )
494+
495+ if not len (inputs ) > 0 :
496+ raise TypingError ("Empty argument list to elemwise op." )
497+
498+ if not n_outputs > 0 :
499+ raise TypingError ("Empty list of outputs for elemwise op." )
500+
501+ if not all (isinstance (input , types .Array ) for input in inputs ):
502+ raise TypingError ("Inputs to elemwise must be arrays." )
503+ ndim = inputs [0 ].ndim
504+
505+ if not all (input .ndim == ndim for input in inputs ):
506+ raise TypingError ("Inputs to elemwise must have the same rank." )
507+
508+ if not all (len (pattern ) == ndim for pattern in output_bc_patterns ):
509+ raise TypingError ("Invalid output broadcasting pattern." )
510+
511+ scalar_signature = typingctx .resolve_function_type (
512+ scalar_func , [in_type .dtype for in_type in inputs ], {}
513+ )
514+
515+ # So we can access the constant values in codegen...
516+ input_bc_patterns_val = input_bc_patterns
517+ output_bc_patterns_val = output_bc_patterns
518+ output_dtypes_val = output_dtypes
519+ inplace_pattern_val = inplace_pattern
520+ input_types = inputs
521+
522+ #assert not inplace_pattern_val
523+
524+ def codegen (
525+ ctx ,
526+ builder ,
527+ sig ,
528+ args ,
529+ ):
530+
531+ [_ , _ , _ , _ , _ , inputs ] = args
532+ inputs = cgutils .unpack_tuple (builder , inputs )
533+ inputs = [arrayobj .make_array (ty )(ctx , builder , val ) for ty , val in zip (input_types , inputs )]
534+ in_shapes = [cgutils .unpack_tuple (builder , obj .shape ) for obj in inputs ]
535+
536+ iter_shape = elemwise_codegen .compute_itershape (
537+ ctx ,
538+ builder ,
539+ in_shapes ,
540+ input_bc_patterns_val ,
541+ )
542+
543+ outputs , output_types = elemwise_codegen .make_outputs (
544+ ctx ,
545+ builder ,
546+ iter_shape ,
547+ output_bc_patterns_val ,
548+ output_dtypes_val ,
549+ inplace_pattern_val ,
550+ inputs ,
551+ input_types ,
552+ )
553+
554+ def _check_input_shapes (* _ ):
555+ # TODO impl
556+ return
557+
558+ _check_input_shapes (
559+ ctx ,
560+ builder ,
561+ iter_shape ,
562+ inputs ,
563+ input_bc_patterns_val ,
564+ )
565+
566+ elemwise_codegen .make_loop_call (
567+ typingctx ,
568+ ctx ,
569+ builder ,
570+ scalar_func ,
571+ scalar_signature ,
572+ iter_shape ,
573+ inputs ,
574+ outputs ,
575+ input_bc_patterns_val ,
576+ output_bc_patterns_val ,
577+ input_types ,
578+ output_types ,
579+ )
580+
581+ if len (outputs ) == 1 :
582+ if inplace_pattern :
583+ assert inplace_pattern [0 ][0 ] == 0
584+ ctx .nrt .incref (builder , sig .return_type , outputs [0 ]._getvalue ())
585+ return outputs [0 ]._getvalue ()
586+
587+ for inplace_idx in dict (inplace_pattern ):
588+ ctx .nrt .incref (builder , sig .return_type .types [inplace_idx ], outputs [inplace_idx ]._get_value ())
589+ return ctx .make_tuple (builder , sig .return_type , [out ._getvalue () for out in outputs ])
590+
591+ # TODO check inplace_pattern
592+ ret_type = types .Tuple ([
593+ types .Array (numba .from_dtype (np .dtype (dtype )), ndim , "C" )
594+ for dtype in output_dtypes
595+ ])
596+ if len (output_dtypes ) == 1 :
597+ ret_type = ret_type .types [0 ]
598+ sig = ret_type (* arg_types )
599+
600+ return sig , codegen
601+
602+
434603@numba_funcify .register (Elemwise )
435604def numba_funcify_Elemwise (op , node , ** kwargs ):
436605 # Creating a new scalar node is more involved and unnecessary
@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
441610 scalar_inputs = [scalar (dtype = input .dtype ) for input in node .inputs ]
442611 scalar_node = op .scalar_op .make_node (* scalar_inputs )
443612
613+ flags = {
614+ "arcp" , # Allow Reciprocal
615+ "contract" , # Allow floating-point contraction
616+ "afn" , # Approximate functions
617+ "reassoc" ,
618+ "nsz" , # TODO Do we want this one?
619+ }
620+
444621 scalar_op_fn = numba_funcify (
445- op .scalar_op , node = scalar_node , parent_node = node , inline = "always" , ** kwargs
622+ op .scalar_op , node = scalar_node , parent_node = node , fastmath = flags , ** kwargs
446623 )
447- elemwise_fn = create_vectorize_func (scalar_op_fn , node , use_signature = False )
448- elemwise_fn_name = elemwise_fn .__name__
449-
450- if op .inplace_pattern :
451- input_idx = op .inplace_pattern [0 ]
452- sign_obj = inspect .signature (elemwise_fn .py_scalar_func )
453- input_names = list (sign_obj .parameters .keys ())
454-
455- unique_names = unique_name_generator ([elemwise_fn_name , "np" ], suffix_sep = "_" )
456- input_names = [unique_names (i , force_unique = True ) for i in input_names ]
457624
458- updated_input_name = input_names [input_idx ]
459-
460- inplace_global_env = {elemwise_fn_name : elemwise_fn , "np" : np }
461-
462- inplace_elemwise_fn_name = f"{ elemwise_fn_name } _inplace"
463-
464- input_signature_str = ", " .join (input_names )
465-
466- if node .inputs [input_idx ].ndim > 0 :
467- inplace_elemwise_src = f"""
468- def { inplace_elemwise_fn_name } ({ input_signature_str } ):
469- return { elemwise_fn_name } ({ input_signature_str + ", " + updated_input_name } )
470- """
471- else :
472- # We can't perform in-place updates on Numba scalars, so we need to
473- # convert them to NumPy scalars.
474- # TODO: We should really prevent the rewrites from creating
475- # in-place updates on scalars when the Numba mode is selected (or
476- # in general?).
477- inplace_elemwise_src = f"""
478- def { inplace_elemwise_fn_name } ({ input_signature_str } ):
479- { updated_input_name } _scalar = np.asarray({ updated_input_name } )
480- return { elemwise_fn_name } ({ input_signature_str + ", " + updated_input_name } _scalar).item()
481- """
482-
483- inplace_elemwise_fn = compile_function_src (
484- inplace_elemwise_src ,
485- inplace_elemwise_fn_name ,
486- {** globals (), ** inplace_global_env },
487- )
488- return numba_basic .numba_njit (inline = "always" , fastmath = config .numba__fastmath )(
489- inplace_elemwise_fn
625+ ndim = node .outputs [0 ].ndim
626+ output_bc_patterns = tuple ([(False ,) * ndim for _ in node .outputs ])
627+ input_bc_patterns = tuple ([input_var .broadcastable for input_var in node .inputs ])
628+ output_dtypes = tuple (variable .dtype for variable in node .outputs )
629+ inplace_pattern = tuple (op .inplace_pattern .items ())
630+
631+ # numba doesn't support nested literals right now...
632+ input_bc_patterns = base64 .encodebytes (pickle .dumps (input_bc_patterns )).decode ()
633+ output_bc_patterns = base64 .encodebytes (pickle .dumps (output_bc_patterns )).decode ()
634+ output_dtypes = base64 .encodebytes (pickle .dumps (output_dtypes )).decode ()
635+ inplace_pattern = base64 .encodebytes (pickle .dumps (inplace_pattern )).decode ()
636+
637+ @numba_njit
638+ def elemwise_wrapper (* inputs ):
639+ return _vectorized (
640+ scalar_op_fn ,
641+ input_bc_patterns ,
642+ output_bc_patterns ,
643+ output_dtypes ,
644+ inplace_pattern ,
645+ inputs ,
490646 )
491647
492- return elemwise_fn
648+ return elemwise_wrapper
493649
494650
495651@numba_funcify .register (CAReduce )
0 commit comments