1+ from typing import Any , List , Optional , Tuple
2+
13import numba
24import numpy as np
35from llvmlite import ir
1012def compute_itershape (
1113 ctx : BaseContext ,
1214 builder : ir .IRBuilder ,
13- in_shapes ,
14- broadcast_pattern ,
15+ in_shapes : Tuple [ ir . Instruction , ...] ,
16+ broadcast_pattern : Tuple [ Tuple [ bool , ...], ...] ,
1517):
1618 one = ir .IntType (64 )(1 )
1719 ndim = len (in_shapes [0 ])
@@ -59,16 +61,23 @@ def compute_itershape(
5961
6062
6163def make_outputs (
62- ctx , builder : ir .IRBuilder , iter_shape , out_bc , dtypes , inplace , inputs , input_types
64+ ctx : numba .core .base .BaseContext ,
65+ builder : ir .IRBuilder ,
66+ iter_shape : Tuple [ir .Instruction , ...],
67+ out_bc : Tuple [Tuple [bool , ...], ...],
68+ dtypes : Tuple [Any , ...],
69+ inplace : Tuple [Tuple [int , int ], ...],
70+ inputs : Tuple [Any , ...],
71+ input_types : Tuple [Any , ...],
6372):
6473 arrays = []
6574 ar_types : list [types .Array ] = []
6675 one = ir .IntType (64 )(1 )
67- inplace = dict (inplace )
76+ inplace_dict = dict (inplace )
6877 for i , (bc , dtype ) in enumerate (zip (out_bc , dtypes )):
69- if i in inplace :
70- arrays .append (inputs [inplace [i ]])
71- ar_types .append (input_types [inplace [i ]])
78+ if i in inplace_dict :
79+ arrays .append (inputs [inplace_dict [i ]])
80+ ar_types .append (input_types [inplace_dict [i ]])
7281 # We need to incref once we return the inplace objects
7382 continue
7483 dtype = numba .from_dtype (np .dtype (dtype ))
@@ -95,15 +104,15 @@ def make_loop_call(
95104 typingctx ,
96105 context : numba .core .base .BaseContext ,
97106 builder : ir .IRBuilder ,
98- scalar_func ,
99- scalar_signature ,
100- iter_shape ,
101- inputs ,
102- outputs ,
103- input_bc ,
104- output_bc ,
105- input_types ,
106- output_types ,
107+ scalar_func : Any ,
108+ scalar_signature : types . FunctionType ,
109+ iter_shape : Tuple [ ir . Instruction , ...] ,
110+ inputs : Tuple [ ir . Instruction , ...] ,
111+ outputs : Tuple [ ir . Instruction , ...] ,
112+ input_bc : Tuple [ Tuple [ bool , ...], ...] ,
113+ output_bc : Tuple [ Tuple [ bool , ...], ...] ,
114+ input_types : Tuple [ Any , ...] ,
115+ output_types : Tuple [ Any , ...] ,
107116):
108117 safe = (False , False )
109118 n_outputs = len (outputs )
@@ -142,23 +151,25 @@ def extract_array(aryty, obj):
142151 # input_scope_set = mod.add_metadata([input_scope, output_scope])
143152 # output_scope_set = mod.add_metadata([input_scope, output_scope])
144153
145- inputs = [
154+ inputs = tuple (
146155 extract_array (aryty , ary )
147156 for aryty , ary in zip (input_types , inputs , strict = True )
148- ]
157+ )
149158
150- outputs = [
159+ outputs = tuple (
151160 extract_array (aryty , ary )
152161 for aryty , ary in zip (output_types , outputs , strict = True )
153- ]
162+ )
154163
155164 zero = ir .Constant (ir .IntType (64 ), 0 )
156165
157166 # Setup loops and initialize accumulators for outputs
158167 # This part corresponds to opening the loops
159168 loop_stack = []
160169 loops = []
161- output_accumulator = [(None , None )] * n_outputs
170+ output_accumulator : List [Tuple [Optional [Any ], Optional [int ]]] = [
171+ (None , None )
172+ ] * n_outputs
162173 for dim , length in enumerate (iter_shape ):
163174 # Find outputs that only have accumulations left
164175 for output in range (n_outputs ):
0 commit comments