Skip to content

Commit 3072693

Browse files
l46kokcopybara-github
authored andcommitted
Persist lazily bound variables in the correct scoped resolver
PiperOrigin-RevId: 840535220
1 parent 741ad14 commit 3072693

File tree

9 files changed

+825
-542
lines changed

9 files changed

+825
-542
lines changed

extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import dev.cel.runtime.CelRuntime;
3939
import dev.cel.runtime.CelRuntimeFactory;
4040
import java.util.Arrays;
41+
import java.util.List;
4142
import java.util.concurrent.atomic.AtomicInteger;
4243
import org.junit.Test;
4344
import org.junit.runner.RunWith;
@@ -243,4 +244,76 @@ public void lazyBinding_withNestedBinds() throws Exception {
243244
assertThat(result).isTrue();
244245
assertThat(invocation.get()).isEqualTo(2);
245246
}
247+
248+
@Test
249+
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
250+
public void lazyBinding_boundAttributeInComprehension() throws Exception {
251+
CelCompiler celCompiler =
252+
CelCompilerFactory.standardCelCompilerBuilder()
253+
.setStandardMacros(CelStandardMacro.MAP)
254+
.addLibraries(CelExtensions.bindings())
255+
.addFunctionDeclarations(
256+
CelFunctionDecl.newFunctionDeclaration(
257+
"get_true",
258+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
259+
.build();
260+
AtomicInteger invocation = new AtomicInteger();
261+
CelRuntime celRuntime =
262+
CelRuntimeFactory.standardCelRuntimeBuilder()
263+
.addFunctionBindings(
264+
CelFunctionBinding.from(
265+
"get_true_overload",
266+
ImmutableList.of(),
267+
arg -> {
268+
invocation.getAndIncrement();
269+
return true;
270+
}))
271+
.build();
272+
273+
CelAbstractSyntaxTree ast =
274+
celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst();
275+
276+
List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();
277+
278+
assertThat(result).containsExactly(true, true, true);
279+
assertThat(invocation.get()).isEqualTo(1);
280+
}
281+
282+
@Test
283+
@SuppressWarnings({"Immutable"}) // Test only
284+
public void lazyBinding_boundAttributeInNestedComprehension() throws Exception {
285+
CelCompiler celCompiler =
286+
CelCompilerFactory.standardCelCompilerBuilder()
287+
.setStandardMacros(CelStandardMacro.EXISTS)
288+
.addLibraries(CelExtensions.bindings())
289+
.addFunctionDeclarations(
290+
CelFunctionDecl.newFunctionDeclaration(
291+
"get_true",
292+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
293+
.build();
294+
AtomicInteger invocation = new AtomicInteger();
295+
CelRuntime celRuntime =
296+
CelRuntimeFactory.standardCelRuntimeBuilder()
297+
.addFunctionBindings(
298+
CelFunctionBinding.from(
299+
"get_true_overload",
300+
ImmutableList.of(),
301+
arg -> {
302+
invocation.getAndIncrement();
303+
return true;
304+
}))
305+
.build();
306+
307+
CelAbstractSyntaxTree ast =
308+
celCompiler
309+
.compile(
310+
"cel.bind(x, get_true(), [1,2,3].exists(unused, x && "
311+
+ "['a','b','c'].exists(unused_2, x)))")
312+
.getAst();
313+
314+
boolean result = (boolean) celRuntime.createProgram(ast).eval();
315+
316+
assertThat(result).isTrue();
317+
assertThat(invocation.get()).isEqualTo(1);
318+
}
246319
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818
import static com.google.common.collect.ImmutableList.toImmutableList;
19+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
1920
import static java.util.stream.Collectors.toCollection;
2021

2122
import com.google.auto.value.AutoValue;
2223
import com.google.common.annotations.VisibleForTesting;
2324
import com.google.common.base.Preconditions;
25+
import com.google.common.base.Strings;
2426
import com.google.common.base.Verify;
2527
import com.google.common.collect.ImmutableList;
2628
import com.google.common.collect.ImmutableSet;
@@ -41,8 +43,11 @@
4143
import dev.cel.common.CelVarDecl;
4244
import dev.cel.common.ast.CelExpr;
4345
import dev.cel.common.ast.CelExpr.CelCall;
46+
import dev.cel.common.ast.CelExpr.CelComprehension;
47+
import dev.cel.common.ast.CelExpr.CelList;
4448
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
4549
import dev.cel.common.ast.CelMutableExpr;
50+
import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension;
4651
import dev.cel.common.ast.CelMutableExprConverter;
4752
import dev.cel.common.navigation.CelNavigableExpr;
4853
import dev.cel.common.navigation.CelNavigableMutableAst;
@@ -60,6 +65,7 @@
6065
import java.util.HashSet;
6166
import java.util.List;
6267
import java.util.Set;
68+
import java.util.stream.Stream;
6369

6470
/**
6571
* Performs Common Subexpression Elimination.
@@ -90,14 +96,15 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9096
private static final SubexpressionOptimizer INSTANCE =
9197
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9298
private static final String BIND_IDENTIFIER_PREFIX = "@r";
93-
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94-
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
95-
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
9699
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
97100
private static final String BLOCK_INDEX_PREFIX = "@index";
98101
private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
99102
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);
100103

104+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
105+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
106+
@VisibleForTesting static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
107+
101108
private final SubexpressionOptimizerOptions cseOptions;
102109
private final AstMutator astMutator;
103110
private final ImmutableSet<String> cseEliminableFunctions;
@@ -269,6 +276,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
269276
Verify.verify(
270277
resultHasAtLeastOneBlockIndex,
271278
"Expected at least one reference of index in cel.block result");
279+
280+
verifyNoInvalidScopedMangledVariables(celBlockExpr);
272281
}
273282

274283
private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
@@ -289,6 +298,67 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
289298
celExpr);
290299
}
291300

301+
private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
302+
CelCall celBlockCall = celExpr.call();
303+
CelExpr blockBody = celBlockCall.args().get(1);
304+
305+
ImmutableSet<String> allMangledVariablesInBlockBody =
306+
CelNavigableExpr.fromExpr(blockBody)
307+
.allNodes()
308+
.map(CelNavigableExpr::expr)
309+
.flatMap(SubexpressionOptimizer::extractMangledNames)
310+
.collect(toImmutableSet());
311+
312+
CelList blockIndices = celBlockCall.args().get(0).list();
313+
for (CelExpr blockIndex : blockIndices.elements()) {
314+
ImmutableSet<String> indexDeclaredCompVariables =
315+
CelNavigableExpr.fromExpr(blockIndex)
316+
.allNodes()
317+
.map(CelNavigableExpr::expr)
318+
.filter(expr -> expr.getKind() == Kind.COMPREHENSION)
319+
.map(CelExpr::comprehension)
320+
.flatMap(comp -> Stream.of(comp.iterVar(), comp.iterVar2()))
321+
.filter(iter -> !Strings.isNullOrEmpty(iter))
322+
.collect(toImmutableSet());
323+
324+
boolean containsIllegalDeclaration =
325+
CelNavigableExpr.fromExpr(blockIndex)
326+
.allNodes()
327+
.map(CelNavigableExpr::expr)
328+
.filter(expr -> expr.getKind() == Kind.IDENT)
329+
.map(expr -> expr.ident().name())
330+
.filter(SubexpressionOptimizer::isMangled)
331+
.anyMatch(
332+
ident ->
333+
!indexDeclaredCompVariables.contains(ident)
334+
&& allMangledVariablesInBlockBody.contains(ident));
335+
336+
Verify.verify(
337+
!containsIllegalDeclaration,
338+
"Illegal declared reference to a comprehension variable found in block indices. Expr: %s",
339+
celExpr);
340+
}
341+
}
342+
343+
private static Stream<String> extractMangledNames(CelExpr expr) {
344+
if (expr.getKind().equals(Kind.IDENT)) {
345+
String name = expr.ident().name();
346+
return isMangled(name) ? Stream.of(name) : Stream.empty();
347+
}
348+
if (expr.getKind().equals(Kind.COMPREHENSION)) {
349+
CelComprehension comp = expr.comprehension();
350+
return Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar())
351+
.filter(x -> !Strings.isNullOrEmpty(x))
352+
.filter(SubexpressionOptimizer::isMangled);
353+
}
354+
return Stream.empty();
355+
}
356+
357+
private static boolean isMangled(String name) {
358+
return name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
359+
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX);
360+
}
361+
292362
private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
293363
// Tag the extension
294364
CelSource.Builder celSourceBuilder =
@@ -355,8 +425,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
355425
navAst
356426
.getRoot()
357427
.descendants(TraversalOrder.PRE_ORDER)
358-
.filter(node -> canEliminate(node, ineligibleExprs))
359428
.filter(node -> node.height() <= recursionLimit)
429+
.filter(node -> canEliminate(node, ineligibleExprs))
360430
.sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed())
361431
.collect(toImmutableList());
362432
if (descendants.isEmpty()) {
@@ -441,7 +511,45 @@ private boolean canEliminate(
441511
&& navigableExpr.expr().list().elements().isEmpty())
442512
&& containsEliminableFunctionOnly(navigableExpr)
443513
&& !ineligibleExprs.contains(navigableExpr.expr())
444-
&& containsComprehensionIdentInSubexpr(navigableExpr);
514+
&& containsComprehensionIdentInSubexpr(navigableExpr)
515+
&& containsProperScopedComprehensionIdents(navigableExpr);
516+
}
517+
518+
private boolean containsProperScopedComprehensionIdents(CelNavigableMutableExpr navExpr) {
519+
if (!navExpr.getKind().equals(Kind.COMPREHENSION)) {
520+
return true;
521+
}
522+
523+
// For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner
524+
// comprehension [2].exists(y, x == y)
525+
// should not be extracted out into a block index, as it causes issues with scoping.
526+
ImmutableSet<String> mangledIterVars =
527+
navExpr
528+
.descendants()
529+
.filter(x -> x.getKind().equals(Kind.IDENT))
530+
.map(x -> x.expr().ident().name())
531+
.filter(
532+
name ->
533+
name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
534+
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX))
535+
.collect(toImmutableSet());
536+
537+
CelNavigableMutableExpr parent = navExpr.parent().orElse(null);
538+
while (parent != null) {
539+
if (parent.getKind().equals(Kind.COMPREHENSION)) {
540+
CelMutableComprehension comp = parent.expr().comprehension();
541+
boolean containsParentIterReferences =
542+
mangledIterVars.contains(comp.iterVar()) || mangledIterVars.contains(comp.iterVar2());
543+
544+
if (containsParentIterReferences) {
545+
return false;
546+
}
547+
}
548+
549+
parent = parent.parent().orElse(null);
550+
}
551+
552+
return true;
445553
}
446554

447555
private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navExpr) {

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import dev.cel.runtime.CelFunctionBinding;
5656
import dev.cel.runtime.CelRuntime;
5757
import dev.cel.runtime.CelRuntimeFactory;
58+
import java.util.List;
5859
import java.util.concurrent.atomic.AtomicInteger;
5960
import org.junit.Test;
6061
import org.junit.runner.RunWith;
@@ -381,6 +382,31 @@ public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception {
381382
assertThat(invocation.get()).isEqualTo(1);
382383
}
383384

385+
@Test
386+
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
387+
public void lazyEval_withinComprehension_blockIndexEvaluatedOnlyOnce() throws Exception {
388+
AtomicInteger invocation = new AtomicInteger();
389+
CelRuntime celRuntime =
390+
CelRuntimeFactory.standardCelRuntimeBuilder()
391+
.addMessageTypes(TestAllTypes.getDescriptor())
392+
.addFunctionBindings(
393+
CelFunctionBinding.from(
394+
"get_true_overload",
395+
ImmutableList.of(),
396+
arg -> {
397+
invocation.getAndIncrement();
398+
return true;
399+
}))
400+
.build();
401+
CelAbstractSyntaxTree ast =
402+
compileUsingInternalFunctions("cel.block([get_true()], [1,2,3].map(x, x < 0 || index0))");
403+
404+
List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();
405+
406+
assertThat(result).containsExactly(true, true, true);
407+
assertThat(invocation.get()).isEqualTo(1);
408+
}
409+
384410
@Test
385411
@SuppressWarnings("Immutable") // Test only
386412
public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception {
@@ -452,9 +478,9 @@ public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws
452478
// Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
453479
CelAbstractSyntaxTree ast =
454480
compileUsingInternalFunctions(
455-
"cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
456-
+ " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
457-
+ " true, true]]]");
481+
"cel.block([true, false, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [c0,"
482+
+ " c1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true, true,"
483+
+ " true]]]");
458484

459485
boolean result = (boolean) celRuntime.createProgram(ast).eval();
460486

0 commit comments

Comments
 (0)