File tree Expand file tree Collapse file tree 2 files changed +17
-22
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +17
-22
lines changed Original file line number Diff line number Diff line change 4949from pytensor .tensor .shape import Shape_i
5050from pytensor .tensor .sort import TopKOp
5151from pytensor .tensor .type import DenseTensorType , TensorType
52- from pytensor .tensor .var import TensorConstant
52+ from pytensor .tensor .var import TensorConstant , TensorVariable
5353from pytensor .utils import NoDuplicateOptWarningFilter
5454
5555
6161_logger .addFilter (NoDuplicateOptWarningFilter ())
6262
6363
64- def encompasses_broadcastable ( b1 , b2 ) :
65- """
64+ def broadcasted_by ( x : TensorVariable , y : TensorVariable ) -> bool :
65+ """Check whether x would be broadcasted by y in an Elemwise operation
6666
6767 Parameters
6868 ----------
69- b1
70- The broadcastable attribute of a tensor type.
71- b2
72- The broadcastable attribute of a tensor type.
69+ x: TensorVariable
70+ The variable that may be broadcasted by y
71+ y: TensorVariable
72+ The variable that may broadcast x
7373
7474 Returns
7575 -------
76- bool
77- True if the broadcastable patterns b1 and b2 are such that b2 is
78- broadcasted to b1's shape and not the opposite.
79-
76+ broadcasted_by: bool
8077 """
81- if len (b1 ) < len (b2 ):
82- return False
83- b1 = b1 [- len (b2 ) :]
84- return not any (v1 and not v2 for v1 , v2 in zip (b1 , b2 ))
78+ bx = x .type .broadcastable
79+ by = y .type .broadcastable
80+ if len (bx ) < len (by ):
81+ return True
82+ bx = bx [- len (by ) :]
83+ return any (bx_dim and not by_dim for bx_dim , by_dim in zip (bx , by ))
8584
8685
8786def merge_broadcastables (broadcastables ):
Original file line number Diff line number Diff line change 8585from pytensor .tensor .math import true_div
8686from pytensor .tensor .rewriting .basic import (
8787 broadcast_like ,
88- encompasses_broadcastable ,
88+ broadcasted_by ,
8989 local_fill_sink ,
9090 register_canonicalize ,
9191 register_specialize ,
@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node):
20492049 xsym = node .inputs [0 ]
20502050 ysym = node .inputs [1 ]
20512051 y = get_constant (ysym )
2052- if (y is not None ) and encompasses_broadcastable (
2053- xsym .type .broadcastable , ysym .type .broadcastable
2054- ):
2052+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
20552053 rval = None
20562054
20572055 if np .all (y == 2 ):
@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node):
21072105 y = y [0 ]
21082106 except IndexError :
21092107 pass
2110- if (y is not None ) and encompasses_broadcastable (
2111- xsym .type .broadcastable , ysym .type .broadcastable
2112- ):
2108+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
21132109 rval = None
21142110 # 512 is too small for the cpu and too big for some gpu!
21152111 if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
You can’t perform that action at this time.
0 commit comments