From a6bb96649c5937321e92a4ac275f2f33b695fc3b Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 3 Dec 2025 12:02:27 -0800 Subject: [PATCH 1/2] Fix --- .../extensions/CelBindingsExtensionsTest.java | 37 +++++++++++++++++++ .../SubexpressionOptimizerTest.java | 26 +++++++++++++ .../dev/cel/runtime/DefaultInterpreter.java | 22 +++++++++-- .../cel/runtime/RuntimeUnknownResolver.java | 17 ++++++++- 4 files changed, 96 insertions(+), 6 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index 33dd9db39..b70c58bf0 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -22,6 +22,8 @@ import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -38,6 +40,7 @@ import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; import java.util.Arrays; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -243,4 +246,38 @@ public void lazyBinding_withNestedBinds() throws Exception { assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(2); } + + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyBinding_boundAttributeInComprehension() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.MAP) + .addLibraries(CelExtensions.bindings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .build(); + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + + CelAbstractSyntaxTree ast = celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst(); + + List result = (List) celRuntime.createProgram(ast).eval(); + + assertThat(result).containsExactly(true, true, true); + assertThat(invocation.get()).isEqualTo(1); + } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index c6999e46b..79177932c 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -55,6 +55,7 @@ import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -381,6 +382,31 @@ public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception { assertThat(invocation.get()).isEqualTo(1); } + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_withinComprehension_blockIndexEvaluatedOnlyOnce() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions("cel.block([get_true()], [1,2,3].map(x, x < 0 || index0))"); + + List result = (List) celRuntime.createProgram(ast).eval(); + + assertThat(result).containsExactly(true, true, true); + assertThat(invocation.get()).isEqualTo(1); + } + @Test @SuppressWarnings("Immutable") // Test only public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception { diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 546290a4e..3a0021fe1 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -982,7 +982,8 @@ private IntermediateResult evalComprehension( .build(); } IntermediateResult accuValue; - if (LazyExpression.isLazilyEvaluable(compre)) { + boolean isLazilyEvaluable = LazyExpression.isLazilyEvaluable(compre); + if (isLazilyEvaluable) { accuValue = IntermediateResult.create(new LazyExpression(compre.accuInit())); } else { accuValue = evalNonstrictly(frame, compre.accuInit()); @@ -1035,7 +1036,12 @@ private IntermediateResult evalComprehension( accuValue = maybeAdaptViewToList(accuValue); - frame.pushScope(Collections.singletonMap(accuVar, accuValue)); + Map scopedAttributes = Collections.singletonMap(accuVar, accuValue); + if (isLazilyEvaluable) { + frame.pushLazyScope(scopedAttributes); + } else { + frame.pushScope(scopedAttributes); + } IntermediateResult result; try { result = evalInternal(frame, compre.result()); @@ -1051,11 +1057,12 @@ private IntermediateResult evalCelBlock( Map blockList = new HashMap<>(); for (int index = 0; index < exprList.elements().size(); index++) { // Register the block indices as lazily evaluated expressions stored as unique identifiers. + String indexKey = "@index" + index; blockList.put( - "@index" + index, + indexKey, IntermediateResult.create(new LazyExpression(exprList.elements().get(index)))); } - frame.pushScope(Collections.unmodifiableMap(blockList)); + frame.pushLazyScope(Collections.unmodifiableMap(blockList)); return evalInternal(frame, blockCall.args().get(1)); } @@ -1167,6 +1174,13 @@ private void cacheLazilyEvaluatedResult( currentResolver.cacheLazilyEvaluatedResult(name, result); } + private void pushLazyScope(Map scope) { + pushScope(scope); + for (String lazyAttribute : scope.keySet()) { + currentResolver.declareLazyAttribute(lazyAttribute); + } + } + /** Note: we utilize a HashMap instead of ImmutableMap to make lookups faster on string keys. */ private void pushScope(Map scope) { scopeLevel++; diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java index e9fb9d052..7cdd25989 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java @@ -116,7 +116,11 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId } void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) { - // no-op. Caching is handled in ScopedResolver. + throw new IllegalStateException("Internal error: Lazy attributes can only be cached in ScopedResolver."); + } + + void declareLazyAttribute(String attrName) { + throw new IllegalStateException("Internal error: Lazy attributes can only be declared in ScopedResolver."); } /** @@ -161,7 +165,16 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId @Override void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) { - lazyEvalResultCache.put(name, copyIfMutable(result)); + if (!lazyEvalResultCache.containsKey(name)) { + parent.cacheLazilyEvaluatedResult(name, result); + } else { + lazyEvalResultCache.put(name, copyIfMutable(result)); + } + } + + @Override + void declareLazyAttribute(String attrName) { + lazyEvalResultCache.put(attrName, null); } /** From 4089a6021caf30ba751449ffd7048d7102e5029d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 3 Dec 2025 13:28:26 -0800 Subject: [PATCH 2/2] foo --- .../extensions/CelBindingsExtensionsTest.java | 35 +++++ .../optimizers/SubexpressionOptimizer.java | 121 +++++++++++++++++- .../dev/cel/optimizer/optimizers/BUILD.bazel | 2 + .../SubexpressionOptimizerTest.java | 71 +++++++++- .../cel/runtime/RuntimeUnknownResolver.java | 4 +- 5 files changed, 221 insertions(+), 12 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index b70c58bf0..9b59f080d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -280,4 +280,39 @@ public void lazyBinding_boundAttributeInComprehension() throws Exception { assertThat(result).containsExactly(true, true, true); assertThat(invocation.get()).isEqualTo(1); } + + @Test + @SuppressWarnings("Immutable") // Test only + public void foo() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.MAP) + .addLibraries(CelExtensions.bindings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .build(); + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + + CelAbstractSyntaxTree ast = celCompiler.compile("cel.bind(x, get_true(), [x, false].map(c0, [c0].map(c1, [c0, x])))").getAst(); + + Object foo = celRuntime.createProgram(ast).eval(); + System.out.println(foo); + List result = (List) celRuntime.createProgram(ast).eval(); + + assertThat(result).containsExactly(true, true, true); + assertThat(invocation.get()).isEqualTo(1); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index eceb0bbe1..5442358f9 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -16,11 +16,13 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.stream.Collectors.toCollection; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -41,8 +43,11 @@ import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelComprehension; +import dev.cel.common.ast.CelExpr.CelList; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelMutableExpr; +import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension; import dev.cel.common.ast.CelMutableExprConverter; import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.navigation.CelNavigableMutableAst; @@ -59,7 +64,9 @@ import java.util.Comparator; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; +import java.util.stream.Stream; /** * Performs Common Subexpression Elimination. @@ -90,14 +97,18 @@ public class SubexpressionOptimizer implements CelAstOptimizer { private static final SubexpressionOptimizer INSTANCE = new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; - private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it"; - private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2"; - private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac"; private static final String CEL_BLOCK_FUNCTION = "cel.@block"; private static final String BLOCK_INDEX_PREFIX = "@index"; private static final Extension CEL_BLOCK_AST_EXTENSION_TAG = Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME); + @VisibleForTesting + static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it"; + @VisibleForTesting + static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2"; + @VisibleForTesting + static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac"; + private final SubexpressionOptimizerOptions cseOptions; private final AstMutator astMutator; private final ImmutableSet cseEliminableFunctions; @@ -269,6 +280,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) { Verify.verify( resultHasAtLeastOneBlockIndex, "Expected at least one reference of index in cel.block result"); + + verifyNoInvalidScopedMangledVariables(celBlockExpr); } private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) { @@ -289,6 +302,69 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) { celExpr); } + private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) { + CelCall celBlockCall = celExpr.call(); + CelExpr blockBody = celBlockCall.args().get(1); + + ImmutableSet allMangledVariablesInBlockBody = + CelNavigableExpr.fromExpr(blockBody) + .allNodes() + .map(CelNavigableExpr::expr) + .flatMap(SubexpressionOptimizer::extractMangledNames) + .collect(toImmutableSet()); + + CelList blockIndices = celBlockCall.args().get(0).list(); + for (CelExpr blockIndex : blockIndices.elements()) { + ImmutableSet indexDeclaredCompVariables = + CelNavigableExpr.fromExpr(blockIndex) + .allNodes() + .map(CelNavigableExpr::expr) + .filter(expr -> expr.getKind() == Kind.COMPREHENSION) + .map(CelExpr::comprehension) + .flatMap(comp -> Stream.of( + comp.iterVar(), + comp.iterVar2() + )) + .filter(iter -> !Strings.isNullOrEmpty(iter)) + .collect(toImmutableSet()); + + boolean containsIllegalDeclaration = + CelNavigableExpr.fromExpr(blockIndex) + .allNodes() + .map(CelNavigableExpr::expr) + .filter(expr -> expr.getKind() == Kind.IDENT) + .map(expr -> expr.ident().name()) + .filter(SubexpressionOptimizer::isMangled) + .anyMatch(ident -> + !indexDeclaredCompVariables.contains(ident) && + allMangledVariablesInBlockBody.contains(ident)); + + Verify.verify( + !containsIllegalDeclaration, + "Illegal declared reference to a comprehension variable found in block indices. Expr: %s", + celExpr); + } + } + + private static Stream extractMangledNames(CelExpr expr) { + if (expr.getKind() == Kind.IDENT) { + String name = expr.ident().name(); + return isMangled(name) ? Stream.of(name) : Stream.empty(); + } + if (expr.getKind() == Kind.COMPREHENSION) { + CelComprehension comp = expr.comprehension(); + return Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar()) + .filter(Objects::nonNull) // Handle potential null/empty iterVar2 + .filter(SubexpressionOptimizer::isMangled); + } + return Stream.empty(); + } + + private static boolean isMangled(String name) { + return name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX) + || name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX); + } + private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) { // Tag the extension CelSource.Builder celSourceBuilder = @@ -355,8 +431,8 @@ private List getCseCandidatesWithRecursionDepth( navAst .getRoot() .descendants(TraversalOrder.PRE_ORDER) - .filter(node -> canEliminate(node, ineligibleExprs)) .filter(node -> node.height() <= recursionLimit) + .filter(node -> canEliminate(node, ineligibleExprs)) .sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed()) .collect(toImmutableList()); if (descendants.isEmpty()) { @@ -441,9 +517,44 @@ private boolean canEliminate( && navigableExpr.expr().list().elements().isEmpty()) && containsEliminableFunctionOnly(navigableExpr) && !ineligibleExprs.contains(navigableExpr.expr()) - && containsComprehensionIdentInSubexpr(navigableExpr); + && containsComprehensionIdentInSubexpr(navigableExpr) + && containsProperScopedComprehensionIdents(navigableExpr); + } + + private boolean containsProperScopedComprehensionIdents(CelNavigableMutableExpr navExpr) { + if (!navExpr.getKind().equals(Kind.COMPREHENSION)) { + return true; + } + + // For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner comprehension [2].exists(y, x == y) + // should not be extracted out into a block index, as it causes issues with scoping. + ImmutableSet mangledIterVars = navExpr.descendants() + .filter(x -> x.getKind().equals(Kind.IDENT)) + .map(x -> x.expr().ident().name()) + .filter(name -> + name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX) || + name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX) + ).collect(toImmutableSet()); + + CelNavigableMutableExpr parent = navExpr.parent().orElse(null); + while (parent != null) { + if (parent.getKind().equals(Kind.COMPREHENSION)) { + CelMutableComprehension comp = parent.expr().comprehension(); + boolean containsParentIterReferences = + mangledIterVars.contains(comp.iterVar()) || mangledIterVars.contains(comp.iterVar2()); + + if (containsParentIterReferences) { + return false; + } + } + + parent = parent.parent().orElse(null); + } + + return true; } + private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navExpr) { if (navExpr.getKind().equals(Kind.COMPREHENSION)) { return true; diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index acc8a4c3f..41d9abeb6 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -18,11 +18,13 @@ java_library( "//common:mutable_ast", "//common:options", "//common/ast", + "//common/ast:mutable_expr", "//common/navigation:mutable_navigation", "//common/types", "//extensions", "//extensions:optional_library", "//optimizer", + "//optimizer:mutable_ast", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", "//optimizer/optimizers:common_subexpression_elimination", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 79177932c..6cc59fd4a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -38,6 +38,7 @@ import dev.cel.common.CelValidationException; import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.ast.CelMutableExpr; import dev.cel.common.navigation.CelNavigableMutableAst; import dev.cel.common.navigation.CelNavigableMutableExpr; import dev.cel.common.types.ListType; @@ -45,6 +46,9 @@ import dev.cel.common.types.StructTypeReference; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.extensions.CelExtensions; +import dev.cel.optimizer.AstMutator; +import dev.cel.optimizer.AstMutator.MangledComprehensionAst; +import dev.cel.optimizer.AstMutator.MangledComprehensionName; import dev.cel.optimizer.CelOptimizationException; import dev.cel.optimizer.CelOptimizer; import dev.cel.optimizer.CelOptimizerFactory; @@ -91,9 +95,11 @@ public class SubexpressionOptimizerTest { CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), + CelVarDecl.newVarDeclaration("it", SimpleType.DYN), CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) + CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@it:0:0", SimpleType.DYN)) .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) .build(); @@ -285,6 +291,27 @@ public void iterationLimitReached_throws() throws Exception { assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached."); } + @Test + public void foo() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[1].map(y, [1, 2].filter(x, x == y))").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder() + .subexpressionMaxRecursionDepth(4) + .populateMacroCalls(true).build())) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + Object result = CEL.createProgram(ast).eval(); + System.out.println(result); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("foo"); + } + @Test public void celBlock_astExtensionTagged() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); @@ -478,9 +505,7 @@ public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws // Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true])) CelAbstractSyntaxTree ast = compileUsingInternalFunctions( - "cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0," - + " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true," - + " true, true]]]"); + "cel.block([true, false, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [c0, c1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true, true, true]]]"); boolean result = (boolean) celRuntime.createProgram(ast).eval(); @@ -547,6 +572,18 @@ public void verifyOptimizedAstCorrectness_blockContainsNoIndexResult_throws() th .isEqualTo("Expected at least one reference of index in cel.block result"); } + @Test + public void verifyOptimizedAstCorrectness_containsForwardReferenceFromComprehensionVar_throws() throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([it], [1].exists(it, it > 0 && index0 > 0))"); + + VerifyException e = + assertThrows( + VerifyException.class, () -> SubexpressionOptimizer.verifyOptimizedAstCorrectness(ast)); + assertThat(e) + .hasMessageThat() + .startsWith("Illegal declared reference to a comprehension variable found in block indices."); + } + @Test @TestParameters("{source: 'cel.block([], index0)'}") @TestParameters("{source: 'cel.block([1, 2], index2)'}") @@ -600,13 +637,37 @@ private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expres .allNodes() .filter(node -> node.getKind().equals(Kind.IDENT)) .map(CelNavigableMutableExpr::expr) - .filter(expr -> expr.ident().name().startsWith("index")) + .filter(expr -> + expr.ident().name().startsWith("index") + ) .forEach( indexExpr -> { String internalIdentName = "@" + indexExpr.ident().name(); indexExpr.ident().setName(internalIdentName); }); + MangledComprehensionAst mangledComprehensionAst = AstMutator.newInstance(10000).mangleComprehensionIdentifierNames( + mutableAst, + SubexpressionOptimizer.MANGLED_COMPREHENSION_ITER_VAR_PREFIX, + SubexpressionOptimizer.MANGLED_COMPREHENSION_ITER_VAR2_PREFIX, + SubexpressionOptimizer.MANGLED_COMPREHENSION_ACCU_VAR_PREFIX + ); + mutableAst = mangledComprehensionAst.mutableAst(); + + CelNavigableMutableAst.fromAst(mutableAst) + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) + .map(CelNavigableMutableExpr::expr) + .filter(expr -> + expr.ident().name().equals("it") + ) + .forEach( + indexExpr -> { + indexExpr.ident().setName("@it:0:0"); + }); + + return CEL_FOR_EVALUATING_BLOCK.check(mutableAst.toParsedAst()).getAst(); } } diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java index 7cdd25989..ed6ae8ca5 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java @@ -116,11 +116,11 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId } void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) { - throw new IllegalStateException("Internal error: Lazy attributes can only be cached in ScopedResolver."); + // throw new IllegalStateException("Internal error: Lazy attributes can only be cached in ScopedResolver."); } void declareLazyAttribute(String attrName) { - throw new IllegalStateException("Internal error: Lazy attributes can only be declared in ScopedResolver."); + // throw new IllegalStateException("Internal error: Lazy attributes can only be declared in ScopedResolver."); } /**