From fe8d392a1a305621400e0095bff845041b40749e Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 10 Apr 2025 11:58:10 -0700 Subject: [PATCH] handle nested comprehension containing accuVar errors to always pop --- .bazelrc | 4 ++ runtime/BUILD.bazel | 5 ++ .../src/main/java/dev/cel/runtime/BUILD.bazel | 1 - .../dev/cel/runtime/DefaultInterpreter.java | 53 +++++++++++++--- .../src/test/java/dev/cel/runtime/BUILD.bazel | 3 + .../cel/runtime/DefaultInterpreterTest.java | 62 +++++++++++++++++++ 6 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java diff --git a/.bazelrc b/.bazelrc index 4e4a0184c..750977061 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,3 +5,7 @@ build --java_language_version=11 # Hide Java 8 deprecation warnings. common --javacopt=-Xlint:-options + +# MacOS Fix https://github.com/protocolbuffers/protobuf/issues/16944 +build --host_cxxopt=-std=c++14 +build --cxxopt=-std=c++14 diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 1ca1e59c9..6061b1ac9 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -105,6 +105,11 @@ java_library( exports = ["//runtime/src/main/java/dev/cel/runtime:interpreter"], ) +java_library( + name = "interpretable", + exports = ["//runtime/src/main/java/dev/cel/runtime:interpretable"], +) + java_library( name = "runtime_helpers", visibility = ["//:internal"], diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index e4c4dbbc5..607ee4b53 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -506,7 +506,6 @@ java_library( java_library( name = "interpretable", srcs = INTERPRABLE_SOURCES, - visibility = ["//visibility:private"], deps = [ ":evaluation_exception", ":evaluation_listener", diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 14ae160fa..fdf98bebf 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -17,6 +17,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -111,7 +112,8 @@ public Interpretable createInterpretable(CelAbstractSyntaxTree ast) { } @Immutable - private static final class DefaultInterpretable + @VisibleForTesting + static final class DefaultInterpretable implements Interpretable, UnknownTrackingInterpretable { private final TypeResolver typeResolver; private final RuntimeTypeProvider typeProvider; @@ -165,12 +167,38 @@ public Object evalTrackingUnknowns( Optional functionResolver, CelEvaluationListener listener) throws CelEvaluationException { + ExecutionFrame frame = newExecutionFrame(resolver, functionResolver, listener); + IntermediateResult internalResult = evalInternal(frame, ast.getExpr()); + return internalResult.value(); + } + + /** + * Evaluates this interpretable and returns the resulting execution frame populated with evaluation state. + * This method is specifically designed for testing the interpreter's internal invariants. + * + *

Do not expose to public. This method is strictly for internal testing purposes only. + */ + @VisibleForTesting + ExecutionFrame populateExecutionFrame(ExecutionFrame frame) throws CelEvaluationException { + evalInternal(frame, ast.getExpr()); + + return frame; + } + + @VisibleForTesting + ExecutionFrame newTestExecutionFrame(GlobalResolver resolver) throws CelEvaluationException { + return newExecutionFrame(RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), CelEvaluationListener.noOpListener()); + } + + + private ExecutionFrame newExecutionFrame( + RuntimeUnknownResolver resolver, + Optional functionResolver, + CelEvaluationListener listener) throws CelEvaluationException { int comprehensionMaxIterations = celOptions.enableComprehension() ? celOptions.comprehensionMaxIterations() : 0; - ExecutionFrame frame = + return new ExecutionFrame(listener, resolver, functionResolver, comprehensionMaxIterations); - IntermediateResult internalResult = evalInternal(frame, ast.getExpr()); - return internalResult.value(); } private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr) @@ -916,6 +944,7 @@ private IntermediateResult evalComprehension( iterVar, IntermediateResult.create(iterAttr, RuntimeHelpers.maybeAdaptPrimitive(elem))); loopVars.put(accuVar, accuValue); + System.out.println("Push scope"); frame.pushScope(Collections.unmodifiableMap(loopVars)); IntermediateResult evalObject = evalBooleanStrict(frame, compre.loopCondition()); if (!isUnknownValue(evalObject.value()) && !(boolean) evalObject.value()) { @@ -927,8 +956,13 @@ private IntermediateResult evalComprehension( } frame.pushScope(Collections.singletonMap(accuVar, accuValue)); - IntermediateResult result = evalInternal(frame, compre.result()); - frame.popScope(); + IntermediateResult result; + try { + result = evalInternal(frame, compre.result()); + } + finally { + frame.popScope(); + } return result; } @@ -975,14 +1009,17 @@ private LazyExpression(CelExpr celExpr) { } } + /** This class tracks the state meaningful to a single evaluation pass. */ - private static class ExecutionFrame { + static class ExecutionFrame { private final CelEvaluationListener evaluationListener; private final int maxIterations; private final ArrayDeque resolvers; private final Optional lateBoundFunctionResolver; private RuntimeUnknownResolver currentResolver; private int iterations; + @VisibleForTesting + int scopeLevel; private ExecutionFrame( CelEvaluationListener evaluationListener, @@ -1040,12 +1077,14 @@ private void cacheLazilyEvaluatedResult( /** Note: we utilize a HashMap instead of ImmutableMap to make lookups faster on string keys. */ private void pushScope(Map scope) { + scopeLevel++; RuntimeUnknownResolver scopedResolver = currentResolver.withScope(scope); currentResolver = scopedResolver; resolvers.addLast(scopedResolver); } private void popScope() { + scopeLevel--; if (resolvers.isEmpty()) { throw new IllegalStateException("Execution frame error: more scopes popped than pushed"); } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 0f568b720..db7bb62bf 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -72,6 +72,7 @@ java_library( "//common:cel_descriptors", "//common:cel_exception", "//common:cel_source", + "//common:compiler_common", "//common:error_codes", "//common:options", "//common:proto_ast", @@ -100,6 +101,8 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_binding", "//runtime:function_overload_impl", + "//runtime:interpretable", + "//runtime:interpreter", "//runtime:interpreter_util", "//runtime:lite_runtime", "//runtime:lite_runtime_factory", diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java new file mode 100644 index 000000000..0203c0c80 --- /dev/null +++ b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java @@ -0,0 +1,62 @@ +package dev.cel.runtime; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOptions; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.types.SimpleType; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.DefaultInterpreter.DefaultInterpretable; +import dev.cel.runtime.DefaultInterpreter.ExecutionFrame; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class DefaultInterpreterTest { + + @Test + public void nestedComprehensions_accuVarContainsErrors_scopeLevelInvariantNotViolated() throws Exception { + CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() + .addFunctionDeclarations(CelFunctionDecl.newFunctionDeclaration( + "error", CelOverloadDecl.newGlobalOverload("error_overload", SimpleType.DYN) + )) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS).build(); + RuntimeTypeProvider emptyProvider = new RuntimeTypeProvider() { + @Override + public Object createMessage(String messageName, Map values) { + return null; + } + @Override + public Object selectField(Object message, String fieldName) { + return null; + } + @Override + public Object hasField(Object message, String fieldName) { + return null; + } + @Override + public Object adapt(Object message) { + return message; + } + }; + CelAbstractSyntaxTree ast = + celCompiler.compile("[1].all(x, [2].all(y, error()))").getAst(); + DefaultDispatcher dispatcher = DefaultDispatcher.create(); + dispatcher.add("error", long.class, (args) -> new IllegalArgumentException("Always throws")); + DefaultInterpreter defaultInterpreter = new DefaultInterpreter( + new TypeResolver(), emptyProvider, dispatcher, CelOptions.DEFAULT); + DefaultInterpretable interpretable = (DefaultInterpretable) defaultInterpreter.createInterpretable(ast); + + ExecutionFrame frame = interpretable.newTestExecutionFrame(GlobalResolver.EMPTY); + + assertThrows(CelEvaluationException.class, () -> interpretable.populateExecutionFrame(frame)); + assertThat(frame.scopeLevel).isEqualTo(0); + } +}