1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818import static com .google .common .collect .ImmutableList .toImmutableList ;
19+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
1920import static java .util .stream .Collectors .toCollection ;
2021
2122import com .google .auto .value .AutoValue ;
2223import com .google .common .annotations .VisibleForTesting ;
2324import com .google .common .base .Preconditions ;
25+ import com .google .common .base .Strings ;
2426import com .google .common .base .Verify ;
2527import com .google .common .collect .ImmutableList ;
2628import com .google .common .collect .ImmutableSet ;
4143import dev .cel .common .CelVarDecl ;
4244import dev .cel .common .ast .CelExpr ;
4345import dev .cel .common .ast .CelExpr .CelCall ;
46+ import dev .cel .common .ast .CelExpr .CelComprehension ;
47+ import dev .cel .common .ast .CelExpr .CelList ;
4448import dev .cel .common .ast .CelExpr .ExprKind .Kind ;
4549import dev .cel .common .ast .CelMutableExpr ;
50+ import dev .cel .common .ast .CelMutableExpr .CelMutableComprehension ;
4651import dev .cel .common .ast .CelMutableExprConverter ;
4752import dev .cel .common .navigation .CelNavigableExpr ;
4853import dev .cel .common .navigation .CelNavigableMutableAst ;
6065import java .util .HashSet ;
6166import java .util .List ;
6267import java .util .Set ;
68+ import java .util .stream .Stream ;
6369
6470/**
6571 * Performs Common Subexpression Elimination.
@@ -90,14 +96,15 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9096 private static final SubexpressionOptimizer INSTANCE =
9197 new SubexpressionOptimizer (SubexpressionOptimizerOptions .newBuilder ().build ());
9298 private static final String BIND_IDENTIFIER_PREFIX = "@r" ;
93- private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
94- private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
95- private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
9699 private static final String CEL_BLOCK_FUNCTION = "cel.@block" ;
97100 private static final String BLOCK_INDEX_PREFIX = "@index" ;
98101 private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99102 Extension .create ("cel_block" , Version .of (1L , 1L ), Component .COMPONENT_RUNTIME );
100103
104+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it" ;
105+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2" ;
106+ @ VisibleForTesting static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac" ;
107+
101108 private final SubexpressionOptimizerOptions cseOptions ;
102109 private final AstMutator astMutator ;
103110 private final ImmutableSet <String > cseEliminableFunctions ;
@@ -269,6 +276,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269276 Verify .verify (
270277 resultHasAtLeastOneBlockIndex ,
271278 "Expected at least one reference of index in cel.block result" );
279+
280+ verifyNoInvalidScopedMangledVariables (celBlockExpr );
272281 }
273282
274283 private static void verifyBlockIndex (CelExpr celExpr , int maxIndexValue ) {
@@ -289,6 +298,67 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289298 celExpr );
290299 }
291300
301+ private static void verifyNoInvalidScopedMangledVariables (CelExpr celExpr ) {
302+ CelCall celBlockCall = celExpr .call ();
303+ CelExpr blockBody = celBlockCall .args ().get (1 );
304+
305+ ImmutableSet <String > allMangledVariablesInBlockBody =
306+ CelNavigableExpr .fromExpr (blockBody )
307+ .allNodes ()
308+ .map (CelNavigableExpr ::expr )
309+ .flatMap (SubexpressionOptimizer ::extractMangledNames )
310+ .collect (toImmutableSet ());
311+
312+ CelList blockIndices = celBlockCall .args ().get (0 ).list ();
313+ for (CelExpr blockIndex : blockIndices .elements ()) {
314+ ImmutableSet <String > indexDeclaredCompVariables =
315+ CelNavigableExpr .fromExpr (blockIndex )
316+ .allNodes ()
317+ .map (CelNavigableExpr ::expr )
318+ .filter (expr -> expr .getKind () == Kind .COMPREHENSION )
319+ .map (CelExpr ::comprehension )
320+ .flatMap (comp -> Stream .of (comp .iterVar (), comp .iterVar2 ()))
321+ .filter (iter -> !Strings .isNullOrEmpty (iter ))
322+ .collect (toImmutableSet ());
323+
324+ boolean containsIllegalDeclaration =
325+ CelNavigableExpr .fromExpr (blockIndex )
326+ .allNodes ()
327+ .map (CelNavigableExpr ::expr )
328+ .filter (expr -> expr .getKind () == Kind .IDENT )
329+ .map (expr -> expr .ident ().name ())
330+ .filter (SubexpressionOptimizer ::isMangled )
331+ .anyMatch (
332+ ident ->
333+ !indexDeclaredCompVariables .contains (ident )
334+ && allMangledVariablesInBlockBody .contains (ident ));
335+
336+ Verify .verify (
337+ !containsIllegalDeclaration ,
338+ "Illegal declared reference to a comprehension variable found in block indices. Expr: %s" ,
339+ celExpr );
340+ }
341+ }
342+
343+ private static Stream <String > extractMangledNames (CelExpr expr ) {
344+ if (expr .getKind ().equals (Kind .IDENT )) {
345+ String name = expr .ident ().name ();
346+ return isMangled (name ) ? Stream .of (name ) : Stream .empty ();
347+ }
348+ if (expr .getKind ().equals (Kind .COMPREHENSION )) {
349+ CelComprehension comp = expr .comprehension ();
350+ return Stream .of (comp .iterVar (), comp .iterVar2 (), comp .accuVar ())
351+ .filter (x -> !Strings .isNullOrEmpty (x ))
352+ .filter (SubexpressionOptimizer ::isMangled );
353+ }
354+ return Stream .empty ();
355+ }
356+
357+ private static boolean isMangled (String name ) {
358+ return name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX )
359+ || name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX );
360+ }
361+
292362 private static CelAbstractSyntaxTree tagAstExtension (CelAbstractSyntaxTree ast ) {
293363 // Tag the extension
294364 CelSource .Builder celSourceBuilder =
@@ -355,8 +425,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355425 navAst
356426 .getRoot ()
357427 .descendants (TraversalOrder .PRE_ORDER )
358- .filter (node -> canEliminate (node , ineligibleExprs ))
359428 .filter (node -> node .height () <= recursionLimit )
429+ .filter (node -> canEliminate (node , ineligibleExprs ))
360430 .sorted (Comparator .comparingInt (CelNavigableMutableExpr ::height ).reversed ())
361431 .collect (toImmutableList ());
362432 if (descendants .isEmpty ()) {
@@ -441,7 +511,45 @@ private boolean canEliminate(
441511 && navigableExpr .expr ().list ().elements ().isEmpty ())
442512 && containsEliminableFunctionOnly (navigableExpr )
443513 && !ineligibleExprs .contains (navigableExpr .expr ())
444- && containsComprehensionIdentInSubexpr (navigableExpr );
514+ && containsComprehensionIdentInSubexpr (navigableExpr )
515+ && containsProperScopedComprehensionIdents (navigableExpr );
516+ }
517+
518+ private boolean containsProperScopedComprehensionIdents (CelNavigableMutableExpr navExpr ) {
519+ if (!navExpr .getKind ().equals (Kind .COMPREHENSION )) {
520+ return true ;
521+ }
522+
523+ // For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner
524+ // comprehension [2].exists(y, x == y)
525+ // should not be extracted out into a block index, as it causes issues with scoping.
526+ ImmutableSet <String > mangledIterVars =
527+ navExpr
528+ .descendants ()
529+ .filter (x -> x .getKind ().equals (Kind .IDENT ))
530+ .map (x -> x .expr ().ident ().name ())
531+ .filter (
532+ name ->
533+ name .startsWith (MANGLED_COMPREHENSION_ITER_VAR_PREFIX )
534+ || name .startsWith (MANGLED_COMPREHENSION_ITER_VAR2_PREFIX ))
535+ .collect (toImmutableSet ());
536+
537+ CelNavigableMutableExpr parent = navExpr .parent ().orElse (null );
538+ while (parent != null ) {
539+ if (parent .getKind ().equals (Kind .COMPREHENSION )) {
540+ CelMutableComprehension comp = parent .expr ().comprehension ();
541+ boolean containsParentIterReferences =
542+ mangledIterVars .contains (comp .iterVar ()) || mangledIterVars .contains (comp .iterVar2 ());
543+
544+ if (containsParentIterReferences ) {
545+ return false ;
546+ }
547+ }
548+
549+ parent = parent .parent ().orElse (null );
550+ }
551+
552+ return true ;
445553 }
446554
447555 private boolean containsComprehensionIdentInSubexpr (CelNavigableMutableExpr navExpr ) {
0 commit comments