@@ -19,9 +19,7 @@ def compute_itershape(
1919 ndim = len (in_shapes [0 ])
2020 shape = [None ] * ndim
2121 for i in range (ndim ):
22- for j , (bc , in_shape ) in enumerate (
23- zip (broadcast_pattern , in_shapes , strict = True )
24- ):
22+ for j , (bc , in_shape ) in enumerate (zip (broadcast_pattern , in_shapes )):
2523 length = in_shape [i ]
2624 if bc [i ]:
2725 with builder .if_then (
@@ -151,14 +149,10 @@ def extract_array(aryty, obj):
151149 # input_scope_set = mod.add_metadata([input_scope, output_scope])
152150 # output_scope_set = mod.add_metadata([input_scope, output_scope])
153151
154- inputs = tuple (
155- extract_array (aryty , ary )
156- for aryty , ary in zip (input_types , inputs , strict = True )
157- )
152+ inputs = tuple (extract_array (aryty , ary ) for aryty , ary in zip (input_types , inputs ))
158153
159154 outputs = tuple (
160- extract_array (aryty , ary )
161- for aryty , ary in zip (output_types , outputs , strict = True )
155+ extract_array (aryty , ary ) for aryty , ary in zip (output_types , outputs )
162156 )
163157
164158 zero = ir .Constant (ir .IntType (64 ), 0 )
@@ -189,8 +183,8 @@ def extract_array(aryty, obj):
189183
190184 # Load values from input arrays
191185 input_vals = []
192- for array_info , bc in zip (inputs , input_bc , strict = True ):
193- idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc , strict = True )]
186+ for array_info , bc in zip (inputs , input_bc ):
187+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )]
194188 ptr = cgutils .get_item_pointer2 (context , builder , * array_info , idxs_bc , * safe )
195189 val = builder .load (ptr )
196190 # val.set_metadata("alias.scope", input_scope_set)
@@ -210,9 +204,7 @@ def extract_array(aryty, obj):
210204 output_values = [output_values ]
211205
212206 # Update output value or accumulators respectively
213- for i , ((accu , _ ), value ) in enumerate (
214- zip (output_accumulator , output_values , strict = True )
215- ):
207+ for i , ((accu , _ ), value ) in enumerate (zip (output_accumulator , output_values )):
216208 if accu is not None :
217209 load = builder .load (accu )
218210 # load.set_metadata("alias.scope", output_scope_set)
@@ -223,9 +215,7 @@ def extract_array(aryty, obj):
223215 # store.set_metadata("alias.scope", output_scope_set)
224216 # store.set_metadata("noalias", input_scope_set)
225217 else :
226- idxs_bc = [
227- zero if bc else idx for idx , bc in zip (idxs , output_bc [i ], strict = True )
228- ]
218+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , output_bc [i ])]
229219 ptr = cgutils .get_item_pointer2 (context , builder , * outputs [i ], idxs_bc )
230220 # store = builder.store(value, ptr)
231221 arrayobj .store_item (context , builder , output_types [i ], value , ptr )
@@ -237,8 +227,7 @@ def extract_array(aryty, obj):
237227 for output , (accu , accu_depth ) in enumerate (output_accumulator ):
238228 if accu_depth == depth :
239229 idxs_bc = [
240- zero if bc else idx
241- for idx , bc in zip (idxs , output_bc [output ], strict = True )
230+ zero if bc else idx for idx , bc in zip (idxs , output_bc [output ])
242231 ]
243232 ptr = cgutils .get_item_pointer2 (
244233 context , builder , * outputs [output ], idxs_bc
0 commit comments