8484from pytensor .tensor .math import sum as at_sum
8585from pytensor .tensor .math import true_div
8686from pytensor .tensor .rewriting .basic import (
87- broadcast_like ,
87+ alloc_like ,
8888 broadcasted_by ,
8989 local_fill_sink ,
9090 register_canonicalize ,
@@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node):
19731973 new_out = cast (new_out , dtype = out .dtype )
19741974 # The ones could have forced a specific length
19751975 if not out .type .is_super (new_out .type ):
1976- new_out = broadcast_like (new_out , out , fgraph )
1976+ new_out = alloc_like (new_out , out , fgraph )
19771977 return [new_out ]
19781978 else :
19791979 return False
@@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node):
19941994 if node .op == at_pow :
19951995 cst = get_constant (node .inputs [1 ])
19961996 if cst == 0 :
1997- return [broadcast_like (1 , node .outputs [0 ], fgraph )]
1997+ return [alloc_like (1 , node .outputs [0 ], fgraph )]
19981998 if cst == 1 :
1999- return [broadcast_like (node .inputs [0 ], node .outputs [0 ], fgraph )]
1999+ return [alloc_like (node .inputs [0 ], node .outputs [0 ], fgraph )]
20002000 else :
20012001 return False
20022002
@@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node):
20332033 node .op .scalar_op , (aes .IntDiv , aes .TrueDiv )
20342034 ):
20352035 if get_constant (node .inputs [0 ]) == 0 :
2036- ret = broadcast_like (0 , node .outputs [0 ], fgraph )
2036+ ret = alloc_like (0 , node .outputs [0 ], fgraph )
20372037 ret .tag .values_eq_approx = values_eq_approx_remove_nan
20382038 return [ret ]
20392039
@@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node):
21842184 has_neg ^= True # toggles
21852185 elif y == 0.0 :
21862186 # if we find any zero, we just return right away
2187- return [broadcast_like (0 , node .outputs [0 ], fgraph )]
2187+ return [alloc_like (0 , node .outputs [0 ], fgraph )]
21882188 else :
21892189 new_inputs .append (inp )
21902190
@@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node):
22092209 new_inputs = [m1 ] + new_inputs
22102210 rval = mul (* new_inputs )
22112211
2212- return [broadcast_like (rval , node .outputs [0 ], fgraph )]
2212+ return [alloc_like (rval , node .outputs [0 ], fgraph )]
22132213 else :
22142214 # there are no variable inputs to mul
22152215 # N.B. this could have been constant-folded...
22162216 if has_neg :
2217- return [broadcast_like (- 1 , node .outputs [0 ], fgraph )]
2217+ return [alloc_like (- 1 , node .outputs [0 ], fgraph )]
22182218 else :
2219- return [broadcast_like (1 , node .outputs [0 ], fgraph )]
2219+ return [alloc_like (1 , node .outputs [0 ], fgraph )]
22202220
22212221
22222222@register_specialize
0 commit comments