@@ -958,6 +958,25 @@ def local_sum_make_vector(fgraph, node):
958958 return [element_sum ]
959959
960960
961+ def equivalent_up_to_constant_casting (a , b ) -> bool :
962+ """Return True if a and b are equivalent up to constant casting."""
963+ if a == b :
964+ return True
965+ # Return equivalence based on data values, ignoring dtype
966+ if (
967+ isinstance (a , TensorConstant )
968+ and isinstance (b , TensorConstant )
969+ and a .type .shape == b .type .shape
970+ # We don't want to spend a lot of time comparing large constant arrays
971+ # First, check if dtype matches, otherwise a == b would be true if they hold the same values
972+ and a .type .dtype != b .type .dtype
973+ # Check property sum() that is cached for TensorConstants, to filter down candidates even more
974+ and a .signature ().sum == b .signature ().sum
975+ ):
976+ return np .array_equal (a .data , b .data )
977+ return False
978+
979+
961980@register_useless ("shape_unsafe" )
962981@register_canonicalize ("fast_compile" , "shape_unsafe" )
963982@register_specialize ("shape_unsafe" )
@@ -1004,17 +1023,19 @@ def local_useless_switch(fgraph, node):
10041023 return [out ]
10051024
10061025 # if left is right -> left
1007- if left == right :
1008- # Note: No need to copy over stacktrace, because the input node
1009- # already has its own stacktrace
1026+ if equivalent_up_to_constant_casting (left , right ):
10101027 if left .type .broadcastable == out_bcast :
1028+ out_dtype = node .outputs [0 ].type .dtype
1029+ if left .type .dtype != out_dtype :
1030+ left = cast (left , out_dtype )
1031+ copy_stack_trace (node .outputs + left , left )
1032+ # When not casting, the other inputs of the switch aren't needed in the traceback
10111033 return [left ]
10121034
1013- ret = broadcast_arrays (left , cond )[0 ]
1014-
1015- # Copy over stacktrace from switch output and correct branch
1016- copy_stack_trace (node .outputs + left , ret )
1017- return [ret ]
1035+ else :
1036+ ret = broadcast_arrays (left , cond )[0 ]
1037+ copy_stack_trace (node .outputs + left , ret )
1038+ return [ret ]
10181039
10191040 # This case happens with scan.
10201041 # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
0 commit comments