2323"""
2424
2525import logging
26- from typing import TYPE_CHECKING , Optional , Union
26+ from typing import Union
2727
2828import numpy as np
2929
6565)
6666from pytensor .tensor .elemwise import DimShuffle , Elemwise
6767from pytensor .tensor .exceptions import NotScalarConstantError
68- from pytensor .tensor .extra_ops import broadcast_shape , broadcast_to
68+ from pytensor .tensor .extra_ops import broadcast_arrays
6969from pytensor .tensor .math import Sum , add
7070from pytensor .tensor .math import all as at_all
7171from pytensor .tensor .math import eq
72- from pytensor .tensor .shape import Shape_i
72+ from pytensor .tensor .shape import Shape_i , shape_padleft
7373from pytensor .tensor .sort import TopKOp
7474from pytensor .tensor .type import DenseTensorType , TensorType
7575from pytensor .tensor .var import TensorConstant , TensorVariable
7676from pytensor .utils import NoDuplicateOptWarningFilter
7777
7878
79- if TYPE_CHECKING :
80- from pytensor .tensor .rewriting .shape import ShapeFeature
81-
82-
8379_logger = logging .getLogger ("pytensor.tensor.rewriting.basic" )
8480_logger .addFilter (NoDuplicateOptWarningFilter ())
8581
@@ -261,31 +257,16 @@ def local_scalar_tensor_scalar(fgraph, node):
 def local_elemwise_alloc(fgraph, node):
     r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.

-    `Alloc`\s are effectively a type of `Elemwise` operation
-    (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
-    this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
-    `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
-    broadcasts).
-
-    In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
-    `Alloc`\s.
-
     The rewrite essentially performs the following replacement:
-    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
-    when ``y.shape`` for some input ``y`` (or the combined shapes of the
-    non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
+    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``

-    In it's current form, it also explicitly accounts for `DimShuffle`\s of
+    In its current form, it also explicitly accounts for `DimShuffle`\s of
     `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
     introduces them as a canonicalization of `Alloc`'s with leading
     broadcastable dimensions.
     """
-    # Rewrite is only applicable when there are at least two inputs
     if len(node.inputs) == 1:
-        return False
-
-    if len(node.outputs) > 1:
-        return False
+        return None

     def dimshuffled_alloc(i):
         return (
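As an illustration only (not part of this commit), a minimal sketch of the replacement described in the docstring above, assuming the standard public PyTensor API (`pt.alloc`, `pytensor.function`):

    import pytensor
    import pytensor.tensor as pt

    x = pt.scalar("x")
    y = pt.matrix("y")
    # Elemwise{add}(Alloc(x, s), y): the Alloc only broadcasts `x` up to
    # `y`'s shape, which the Elemwise would do anyway, so it is redundant.
    out = pt.add(pt.alloc(x, y.shape[0], y.shape[1]), y)
    fn = pytensor.function([x, y], out)
    # After rewriting, the compiled graph computes Elemwise{add}(x, y)
    # directly, with no Alloc left.
    pytensor.printing.debugprint(fn)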
@@ -305,76 +286,40 @@ def dimshuffled_alloc(i):
     if len(alloc_idxs) == 0:
         return False

-    # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
-    # baseline for the dimensions.
-    ref_var_idx = None
-    for idx, i in enumerate(node.inputs):
-        if i.type.broadcastable == node.outputs[0].type.broadcastable:
-            # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
-            # `Alloc`, so that all `Alloc`s can be rewritten.
-            if idx not in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
-    if ref_var_idx is None:
-        for idx, i in enumerate(node.inputs):
-            # XXX: This broadcastable comparison doesn't work
-            if (
-                i.type.broadcastable == node.outputs[0].type.broadcastable
-            ) and idx in alloc_idxs:
-                ref_var_idx = idx
-                break
-
-    if not hasattr(fgraph, "shape_feature"):
-        return False
-
-    input_shapes = [
-        tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
-        for i in node.inputs
-    ]
-    bcasted_shape = broadcast_shape(
-        *input_shapes,
-        arrays_are_shapes=True,
-    )
-
     new_inputs = list(node.inputs)
     for idx in alloc_idxs:
         i = node.inputs[idx]

-        # Remove `Alloc`
+        # Remove simple `Alloc`
         if isinstance(i.owner.op, Alloc):
-            new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
+            new_inp = i.owner.inputs[0]

-        # TODO FIXME: This shouldn't be handled here.
-        # `DimShuffle`s should be lifted through `Alloc`s
-        # by other, more general rewrites.
-        # Remove `Alloc` in `DimShuffle`
+        # Remove `Dimshuffle(Alloc)`
         elif isinstance(i.owner.op, DimShuffle):
             old_alloc = i.owner.inputs[0]
-            new_alloc = old_alloc.owner.inputs[0]
+            old_alloc_inp = old_alloc.owner.inputs[0]
+            missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim
+            if missing_ndims > 0:
+                # The `Alloc` added new dimensions to the left.
+                # We replace those cases with a `DimShuffle` here.
+                # Nested dimshuffles will be merged later by other rewrites.
+                old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims)
             # We need to keep the old `DimShuffle`. It could swap axes or
             # add dimensions anywhere.
-            if new_alloc.ndim != old_alloc.ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
-                new_alloc = new_alloc.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
-                )
-            new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
+            new_inp = i.owner.op(old_alloc_inp)

-        copy_stack_trace(i, new_alloc)
-        new_inputs[idx] = new_alloc
+        copy_stack_trace(i, new_inp)
+        new_inputs[idx] = new_inp

-    # If this assert is triggered, it means we are recreating an equivalent graph
-    # which would result in cyclical merge rewrites.
-    if all(new is old for new, old in zip(new_inputs, node.inputs)):
-        return
+    new_outs = node.op(*new_inputs, return_list=True)

-    ret = node.op(*new_inputs, return_list=True)
-    copy_stack_trace(node.outputs, ret)
-    return ret
+    if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
+        new_outs = [
+            alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs
+        ]
+
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs


 @register_canonicalize("shape_unsafe")
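A hedged sketch (not from the commit) of the `DimShuffle(Alloc)` branch above: when the `Alloc` added leading dimensions, the value is padded with `shape_padleft` so the original `DimShuffle` can still be applied to it.

    import pytensor.tensor as pt
    from pytensor.tensor.shape import shape_padleft

    x = pt.vector("x")                   # shape (3,)
    y = pt.matrix("y")                   # shape (3, 2)
    a = pt.alloc(x, 2, 3)                # Alloc adds one leading dimension
    out = y + a.dimshuffle(1, 0)         # Elemwise with a DimShuffle(Alloc) input
    # Inside the Elemwise, the rewrite drops the Alloc but keeps the DimShuffle:
    # `x` is padded to the Alloc's ndim and the same DimShuffle is applied to it.
    x_padded = shape_padleft(x, 1)       # broadcastable pattern (1, 3)
    new_input = x_padded.dimshuffle(1, 0)  # roughly what replaces a.dimshuffle(1, 0)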
@@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node):

     # The newly created node c doesn't have 'clients',
     # so this iteration takes place with node.outputs[0]
+    # TODO: This should just be a WalkingGraphRewrite!
     replacements = {node.outputs[0]: c}
     for client, cl_idx in fgraph.clients[node.outputs[0]]:
         if (
@@ -438,23 +384,15 @@ def local_fill_to_alloc(fgraph, node):
     with their dependencies on those tensors' shapes, and sometimes those
     shapes can be computed without needing to compute the tensors themselves.

-    XXX: This rewrite can produce inconsistent results, so do *not* consider
-    making it a canonicalization until those inconsistencies are
-    resolved/justified.
+    Like `local_fill_sink`, this rewrite assumes non-broadcastable shapes are equivalent,
+    which could mask shape errors.
     """
     shape_ref, values_ref = node.inputs
     out_type = node.outputs[0].type

     if values_ref.type.broadcastable == out_type.broadcastable:
         # The assumption here is that `values_ref` already has the same shape
         # as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
-
-        # XXX FIXME TODO: The only way this can be determined is if one
-        # absolutely knows that the shapes of `shape_ref` and `values_ref` are
-        # equal.
-        # This is an old rewrite, and it's only a
-        # "specialization/stabilization", so we're going to leave it be for
-        # now.
         return [values_ref]

     if shape_ref.type.broadcastable == out_type.broadcastable:
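A hedged illustration (not part of the commit) of the two branches handled here, assuming the public `pt.fill` helper:

    import pytensor.tensor as pt

    x = pt.matrix("x")
    x2 = pt.matrix("x2")
    y = pt.scalar("y")

    # values_ref already has the output's broadcast pattern: the fill is
    # assumed redundant and replaced by the value itself.
    z1 = pt.fill(x, x2)   # -> x2 (assumes x and x2 have equal shapes)

    # shape_ref has the output's broadcast pattern: the fill becomes an
    # explicit Alloc of the value at shape_ref's shape.
    z2 = pt.fill(x, y)    # -> Alloc(y, x.shape[0], x.shape[1])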
@@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node):
         copy_stack_trace(node.outputs[0], o)
         return [o]

+    # The case that is not covered is when `shape_ref` is broadcasted by `values_ref`
+    # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape))
+
     return


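For completeness, a hedged example of a shape combination that neither branch above covers, which the TODO would address:

    import pytensor.tensor as pt

    shape_ref = pt.row("shape_ref")    # broadcastable (True, False)
    values_ref = pt.col("values_ref")  # broadcastable (False, True)
    # The output broadcasts both inputs (broadcastable (False, False)), so
    # neither branch applies and the fill is currently left in place.
    z = pt.fill(shape_ref, values_ref)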
@@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node):
     return [element_sum]


-@register_useless("local_remove_switch_const_cond")
-@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
-@register_specialize
-@node_rewriter([Elemwise])
+@register_useless("shape_unsafe")
+@register_canonicalize("fast_compile", "shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([switch])
 def local_useless_switch(fgraph, node):
     """
     This rewrite makes the following changes in a graph:

-    at.switch(cond, left, right) ->
-        if cond is constant and cond == 0: right
-        if cond is constant and cond != 0: left
-        if left is right -> left
+    switch(cond, left, right) ->
+        if cond is constant and cond == 0: right
+        if cond is constant and cond != 0: left
+        if left is right -> left

     and

-    at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
+    switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)

     """
-    if not isinstance(node.op.scalar_op, aes.Switch):
-        return False
-
-    shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None)
-
-    if shape_feature is None:
-        return False

     left = node.inputs[1]
     right = node.inputs[2]
     cond_var = node.inputs[0]
     cond = extract_constant(cond_var, only_process_constants=True)
+    out_bcast = node.outputs[0].type.broadcastable

     if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
         cond, (np.number, np.bool_)
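A hedged illustration (not part of the commit) of the constant-condition and equal-branch cases listed in the docstring:

    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")

    s1 = pt.switch(0, x, y)      # constant false condition -> rewritten to y
    s2 = pt.switch(1, x, y)      # constant true condition  -> rewritten to x
    s3 = pt.switch(x > 0, y, y)  # identical branches       -> rewritten to y

Because the registrations are now tagged "shape_unsafe", a compilation mode can opt out of them, e.g. via something like `mode.excluding("shape_unsafe")` (assuming the standard `Mode.excluding` API).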
@@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node):
         else:
             out = correct_out

-        input_shapes = [
-            tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
-            for inp in node.inputs
-        ]
-
-        out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
-
-        out = alloc(out, *out_shape)
+        if out.type.broadcastable != out_bcast:
+            out = broadcast_arrays(out, *node.inputs)[0]

         # Copy over stacktrace from selected output to new output
         copy_stack_trace(node.outputs + correct_out, out)
@@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node):
     if left == right:
         # Note: No need to copy over stacktrace, because the input node
         # already has its own stacktrace
-        if cond.type.is_super(left.type):
+        if left.type.broadcastable == out_bcast:
             return [left]

-        ret = fill(cond, left)
+        ret = broadcast_arrays(left, cond)[0]

         # Copy over stacktrace from switch output and correct branch
         copy_stack_trace(node.outputs + left, ret)
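When the branches are identical but their broadcast pattern differs from the output's, the replacement is broadcast against the condition rather than `fill`-ed; a hedged sketch of that situation:

    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import broadcast_arrays

    cond = pt.matrix("cond")   # the condition broadcasts the branches
    left = pt.vector("left")

    out = pt.switch(cond > 0, left, left)
    # Both branches are the same variable, but the output is matrix-shaped
    # because of `cond`, so the rewrite returns the branch broadcast against
    # the condition rather than the branch alone:
    replacement = broadcast_arrays(left, cond)[0]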