 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1103,6 +1111,71 @@ def sanitize(x):
     return at.as_tensor_variable(x)


+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurrent outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort at any point after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node2 in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node2.op, Subtensor):
+            continue
+
+        node1 = node2.inputs[0].owner
+        if not (node1 and isinstance(node1.op, Subtensor)):
+            continue
+
+        x = node1.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(node1.inputs, node1.op.idx_list)
+        slice2 = get_idx_list(node2.inputs, node2.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node2.outputs[0], node2.inputs[0]], out)
+            subtensor_merge_replacements[node2.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
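For context, here is a minimal sketch (not part of the diff) of the user-level pattern the new rewrite targets. It assumes the standard scan/until API exposed in pytensor.scan.basic and pytensor.scan.utils; the variable names and the doubling loop are purely illustrative.

import pytensor
import pytensor.tensor as at
from pytensor.scan.basic import scan
from pytensor.scan.utils import until

x0 = at.scalar("x0")

# A do-while loop: keep doubling x until it exceeds 100, for at most 16 steps.
ys, _ = scan(
    fn=lambda x: (2 * x, until(2 * x > 100)),
    outputs_info=[x0],
    n_steps=16,
)

# The user-level `ys` is the raw while-scan buffer sliced as raw_out[1:]
# (the initial state is stripped), so `ys[-1]` builds exactly the
# while_scan_out[abs(min(tap)):][-1] pattern described in the docstring,
# which the rewrite replaces by while_scan_out[-1] guarded by Assert(n_steps > 0).
f = pytensor.function([x0], ys[-1])
print(f(3.0))  # 192.0
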
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
     that SITSOT output. Only the most recently computed timestep ever needs to
     be kept in memory.

+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurrent output is saved in an empty input tensor x, with the initial
+       state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
+       positions determine how many intermediate results should be stored.
+       This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+       input that determines how many steps should be saved in the perform method.
+       This rewrite reduces this number to the smallest possible value.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
     """
     if not isinstance(node.op, Scan):
         return False
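To illustrate the second buffer-control mechanism described above (nit-sot outputs), here is a small sketch, not part of the diff; the names are illustrative and it assumes the usual pytensor.function compilation pipeline in which this rewrite runs as part of fast_run.

import pytensor
import pytensor.tensor as at
from pytensor.scan.basic import scan

n = at.iscalar("n")

# A nit-sot output: no recurrence, one fresh value per step.
ys, _ = scan(fn=lambda i: i**2, sequences=[at.arange(n)])

# Only the last two entries are ever used. The save-memory rewrite can then
# lower the scalar "number of steps to store" input associated with this
# nit-sot output, so the buffer holds two values instead of all n of them.
f = pytensor.function([n], ys[-2:])
print(f(10))  # [64 81]
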
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
     # index(step) for any output scan actually needs to compute
     # In other words n_steps should be equal to this maximal !
     # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
     # change the number of steps in that case. To do this we set
     # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity, while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
     assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
                     else:
                         # there is a **gotcha** here ! Namely, scan returns an
                         # array that contains the initial state of the output
-                        # as well. Which means that if have a initial state of
+                        # as well. Which means that if y has an initial state of
                         # length 3, and you look for 5 steps you get an output
                         # y of length 8. If you only use y[:5], this does not
                         # mean that you only need to loop for 5 steps but
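The arithmetic in the comment above can be checked with a short sketch (illustrative, not part of the diff): with an initial state of length 3 (taps=[-3]) and n_steps=5, the raw Scan output buffer has 3 + 5 = 8 entries, of which the user-level output exposes the last 5.

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.scan.basic import scan

y0 = at.vector("y0")  # initial state of length 3, i.e. taps = [-3]

ys, _ = scan(
    fn=lambda ym3: ym3 + 1,
    outputs_info=[{"initial": y0, "taps": [-3]}],
    n_steps=5,
)

# Raw buffer: [y0[0], y0[1], y0[2], s1, s2, s3, s4, s5] -> length 8.
# `ys` is that buffer sliced as raw_out[3:], hence length 5.
f = pytensor.function([y0], ys)
print(f(np.zeros(3, dtype=y0.dtype)))  # [1. 1. 1. 2. 2.]

# Using only ys[:2] means just the first 2 computed steps matter, which is
# what can let this rewrite shrink n_steps (here from 5 down to 2) once the
# nested subtensors have been merged.
g = pytensor.function([y0], ys[:2])
print(g(np.zeros(3, dtype=y0.dtype)))  # [1. 1.]
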
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):

     # 2.3. Analyze global_nsteps to figure out for how many steps scan
     # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
         nw_steps = node.inputs[0]
-
+    else:
         # there are some symbolic tensors that limit the number of
         # steps
         if len(global_nsteps["sym"]) == 0:
@@ -1303,16 +1390,14 @@ def save_mem_new_scan(fgraph, node):
             real_steps = None
         nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
         # Make sure the ScanSaveMem optimization never makes the new
         # number of steps to be 0 (this could happen, for instance, if
         # the optimization detects that the outputs of the Scan go through
         # subtensor nodes that end up taking no elements) because Scan with
         # 0 iterations are not supported. Make sure the new number of steps
         # is at least 1.
         nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None

     # 2.4 Loop over the clients again now looking just to see how many
     # intermediate steps to store
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
                     store_steps[i] = 0
                     break

-                if i > op_info.n_mit_mot:
-                    length = node.inputs[0] + init_l[i]
+                # Special case for recurrent outputs where only the last result
+                # is requested. This is needed for this rewrite to apply to
+                # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+                # the `else` branch would reintroduce a shape dependency on the
+                # original Scan which would lead this rewrite to abort in the end.
+                if (
+                    i <= op.info.n_mit_mot
+                    and isinstance(this_slice[0], ScalarConstant)
+                    and this_slice[0].value == -1
+                ):
+                    start = nw_steps - 1
                 else:
-                    try:
-                        length = shape_of[out][0]
-                    except KeyError:
-                        length = out.shape[0]
-                cf_slice = get_canonical_form_slice(this_slice[0], length)
+                    if i <= op.info.n_mit_mot:
+                        try:
+                            length = shape_of[out][0]
+                        except KeyError:
+                            length = out.shape[0]
+                    else:
+                        length = node.inputs[0] + init_l[i]
+
+                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                    if isinstance(cf_slice[0], slice):
+                        start = at.extract_constant(cf_slice[0].start)
+                    else:
+                        start = at.extract_constant(cf_slice[0])

-                if isinstance(cf_slice[0], slice):
-                    start = at.extract_constant(cf_slice[0].start)
-                else:
-                    start = at.extract_constant(cf_slice[0])
                 if start == 0 or store_steps[i] == 0:
                     store_steps[i] = 0
                 else:
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
                     nw_input = expand_empty(_nw_input, nw_steps)
                     nw_inputs[in_idx] = nw_input
                 else:
+                    # FIXME: This is never used
                     nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

             elif (
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
                     )
                 else:
                     fslice = sanitize(cnf_slice[0])
-
                 nw_slice = (fslice,) + tuple(old_slices[1:])
+
                 nw_pos = inv_compress_map[idx]

                 subtens = Subtensor(nw_slice)
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
                     ) + tuple(old_slices[1:])

                 else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    # Special case when only last value is requested
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and old_slices[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )

                     nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                     subtens = Subtensor(nw_slice)
@@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node):
     position=5,
 )

+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)

 scan_eqopt2.register(
     "constant_folding_for_scan2",