3030 cast ,
3131 constant ,
3232 extract_constant ,
33- fill ,
3433 get_underlying_scalar_constant_value ,
3534 ones_like ,
3635 switch ,
@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node):
20412040@register_specialize
20422041@node_rewriter ([at_pow ])
20432042def local_pow_specialize (fgraph , node ):
2044- # here, we are past the point of canonicalization, so we don't want
2045- # to put in un-necessary fills.
20462043 if node .op == at_pow :
20472044 # the idea here is that we have pow(x, y)
20482045 odtype = node .outputs [0 ].dtype
@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node):
20572054 if np .all (y == 1 ):
20582055 rval = [xsym ]
20592056 if np .all (y == 0 ):
2060- rval = [fill ( xsym , np . asarray ( 1 , dtype = odtype ) )]
2057+ rval = [alloc_like ( 1 , xsym , fgraph )]
20612058 if np .all (y == 0.5 ):
20622059 rval = [sqrt (xsym )]
20632060 if np .all (y == - 0.5 ):
@@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node):
21582155 mul(-1, x, y) -/-> neg(mul(x, y))
21592156
21602157 """
2161- # here, we are past the point of canonicalization, so we don't
2162- # want to put in un-necessary fills.
2163- #
2158+
21642159 # at this point [post canonicalize], mul() may have many inputs.
21652160 if node .op == mul :
21662161 # the idea here is that we have pow(x, y)
@@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node):
22212216
22222217@register_specialize
22232218@node_rewriter ([add ])
2224- def local_add_specialize (fgraph , node ):
2225- """Remove zeros from ``add``s.
2226-
2227- TODO: This should be a canonicalization, no?
2228- """
2229- # here, we are past the point of canonicalization, so we don't want
2230- # to put in un-necessary fills.
2231- if node .op != add :
2232- return False
2233-
2219+ def local_add_remove_zeros (fgraph , node ):
22342220 new_inputs = []
22352221 for inp in node .inputs :
22362222 try :
@@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node):
22532239 # Reuse call to constant for cache()
22542240 cst = constant (np .zeros ((1 ,) * ndim , dtype = dtype ))
22552241 assert cst .type .broadcastable == (True ,) * ndim
2256- return [broadcast_arrays (cst , * node . inputs )[ 0 ] ]
2242+ return [alloc_like (cst , node_output , fgraph ) ]
22572243
22582244 if len (new_inputs ) == 1 :
2259- ret = [broadcast_arrays (new_inputs [0 ], * node . inputs )[ 0 ] ]
2245+ ret = [alloc_like (new_inputs [0 ], node_output , fgraph ) ]
22602246 else :
2261- ret = [broadcast_arrays (add (* new_inputs ), * node . inputs )[ 0 ] ]
2247+ ret = [alloc_like (add (* new_inputs ), node_output , fgraph ) ]
22622248
22632249 # The dtype should not be changed. It can happen if the input
22642250 # that was forcing upcasting was equal to 0.
@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node):
23762362 ninp = nonconsts [0 ]
23772363 if ninp .dtype != log_arg .type .dtype :
23782364 ninp = ninp .astype (node .outputs [0 ].dtype )
2379- return [broadcast_arrays (log1p (ninp ), * scalar_inputs ) [0 ]]
2365+ return [alloc_like (log1p (ninp ), node . outputs [0 ], fgraph ) ]
23802366
23812367 elif log_arg .owner and log_arg .owner .op == sub :
23822368 one = extract_constant (log_arg .owner .inputs [0 ], only_process_constants = True )
@@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node):
35723558 if nonconsts [0 ].owner and nonconsts [0 ].owner .op == exp :
35733559 if scalars_ and np .allclose (np .sum (scalars_ ), 1 ):
35743560 out = [
3575- broadcast_arrays (
3561+ alloc_like (
35763562 sigmoid (neg (nonconsts [0 ].owner .inputs [0 ])),
3577- * scalar_inputs ,
3578- )[0 ]
3563+ node .outputs [0 ],
3564+ fgraph ,
3565+ )
35793566 ]
35803567 # keep combined stack traces of
35813568 # exp(x): nonconsts[0],
0 commit comments