From f4cb635037ac979b0ecac6907c98a841df65849f Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 15 Jul 2025 14:26:36 -0700 Subject: [PATCH] Accumulated unknowns --- .../dev/cel/runtime/AccumulatedUnknowns.java | 47 ++++++++++++++++ .../src/main/java/dev/cel/runtime/BUILD.bazel | 18 +++++++ .../dev/cel/runtime/CallArgumentChecker.java | 24 ++++----- .../cel/runtime/CelEvaluationListener.java | 5 -- .../java/dev/cel/runtime/CelUnknownSet.java | 2 +- .../dev/cel/runtime/DefaultInterpreter.java | 54 ++++++++++++------- .../java/dev/cel/runtime/InterpreterUtil.java | 26 +++++---- .../java/dev/cel/runtime/LiteProgramImpl.java | 3 +- .../java/dev/cel/runtime/ProgramImpl.java | 32 ++++++----- .../cel/runtime/RuntimeUnknownResolver.java | 29 ++++++++-- .../runtime/UnknownTrackingInterpretable.java | 2 +- .../java/dev/cel/runtime/CelRuntimeTest.java | 49 +++++++++++++++++ 12 files changed, 225 insertions(+), 66 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java diff --git a/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java b/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java new file mode 100644 index 000000000..5647c06f8 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java @@ -0,0 +1,47 @@ +package dev.cel.runtime; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * An internal representation used for fast accumulation of unknown expr IDs and attributes. + * For safety, this object should never be returned as an evaluated result and instead be adapted into an immutable CelUnknownSet. + */ +final class AccumulatedUnknowns { + + private final List exprIds; + private final List attributes; + + List exprIds() { + return exprIds; + } + + List attributes() { + return attributes; + } + + AccumulatedUnknowns merge(AccumulatedUnknowns arg) { + this.exprIds.addAll(arg.exprIds); + this.attributes.addAll(arg.attributes); + return this; + } + + static AccumulatedUnknowns create(Long... ids) { + return create(Arrays.asList(ids)); + } + + static AccumulatedUnknowns create(Collection ids) { + return create(ids, new ArrayList<>()); + } + + static AccumulatedUnknowns create(Collection exprIds, Collection attributes) { + return new AccumulatedUnknowns(new ArrayList<>(exprIds), new ArrayList<>(attributes)); + } + + private AccumulatedUnknowns(List exprIds, List attributes) { + this.exprIds = exprIds; + this.attributes = attributes; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 4588bd2d5..0e381341d 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -271,6 +271,7 @@ java_library( ], exports = [":base"], deps = [ + ":accumulated_unknowns", ":base", ":concatenated_list_view", ":dispatcher", @@ -306,6 +307,7 @@ cel_android_library( srcs = INTERPRETER_SOURCES, visibility = ["//visibility:private"], deps = [ + ":accumulated_unknowns_android", ":base_android", ":concatenated_list_view", ":dispatcher_android", @@ -1034,6 +1036,7 @@ java_library( tags = [ ], deps = [ + ":accumulated_unknowns", ":evaluation_exception", ":unknown_attributes", "//common/annotations", @@ -1047,6 +1050,7 @@ cel_android_library( srcs = ["InterpreterUtil.java"], visibility = ["//visibility:private"], deps = [ + ":accumulated_unknowns_android", ":evaluation_exception", ":unknown_attributes_android", "//common/annotations", @@ -1105,3 +1109,17 @@ java_library( # used_by_android visibility = ["//visibility:private"], ) + +java_library( + name = "accumulated_unknowns", + srcs = ["AccumulatedUnknowns.java"], + visibility = ["//visibility:private"], + deps = [":unknown_attributes"], +) + +cel_android_library( + name = "accumulated_unknowns_android", + srcs = ["AccumulatedUnknowns.java"], + visibility = ["//visibility:private"], + deps = [":unknown_attributes_android"], +) diff --git a/runtime/src/main/java/dev/cel/runtime/CallArgumentChecker.java b/runtime/src/main/java/dev/cel/runtime/CallArgumentChecker.java index b3096fc7a..49c86f5d2 100644 --- a/runtime/src/main/java/dev/cel/runtime/CallArgumentChecker.java +++ b/runtime/src/main/java/dev/cel/runtime/CallArgumentChecker.java @@ -32,7 +32,7 @@ class CallArgumentChecker { private final ArrayList exprIds; private final RuntimeUnknownResolver resolver; private final boolean acceptPartial; - private Optional unknowns; + private Optional unknowns; private CallArgumentChecker(RuntimeUnknownResolver resolver, boolean acceptPartial) { this.exprIds = new ArrayList<>(); @@ -61,29 +61,29 @@ static CallArgumentChecker createAcceptingPartial(RuntimeUnknownResolver resolve return new CallArgumentChecker(resolver, true); } - private static Optional mergeOptionalUnknowns( - Optional lhs, Optional rhs) { + private static Optional mergeOptionalUnknowns( + Optional lhs, Optional rhs) { return lhs.isPresent() ? rhs.isPresent() ? Optional.of(lhs.get().merge(rhs.get())) : lhs : rhs; } /** Determine if the call argument is unknown and accumulate if so. */ void checkArg(DefaultInterpreter.IntermediateResult arg) { // Handle attribute tracked unknowns. - Optional argUnknowns = maybeUnknownFromArg(arg); + Optional argUnknowns = maybeUnknownFromArg(arg); unknowns = mergeOptionalUnknowns(unknowns, argUnknowns); // support for ExprValue unknowns. - if (InterpreterUtil.isUnknown(arg.value())) { - CelUnknownSet unknownSet = (CelUnknownSet) arg.value(); - exprIds.addAll(unknownSet.unknownExprIds()); + if (InterpreterUtil.isAccumulatedUnknowns(arg.value())) { + AccumulatedUnknowns unknownSet = (AccumulatedUnknowns) arg.value(); + exprIds.addAll(unknownSet.exprIds()); } } - private Optional maybeUnknownFromArg(DefaultInterpreter.IntermediateResult arg) { - if (arg.value() instanceof CelUnknownSet) { - CelUnknownSet celUnknownSet = (CelUnknownSet) arg.value(); + private Optional maybeUnknownFromArg(DefaultInterpreter.IntermediateResult arg) { + if (arg.value() instanceof AccumulatedUnknowns) { + AccumulatedUnknowns celUnknownSet = (AccumulatedUnknowns) arg.value(); if (!celUnknownSet.attributes().isEmpty()) { - return Optional.of((CelUnknownSet) arg.value()); + return Optional.of((AccumulatedUnknowns) arg.value()); } } if (!acceptPartial) { @@ -99,7 +99,7 @@ Optional maybeUnknowns() { } if (!exprIds.isEmpty()) { - return Optional.of(CelUnknownSet.create(exprIds)); + return Optional.of(AccumulatedUnknowns.create(exprIds)); } return Optional.empty(); diff --git a/runtime/src/main/java/dev/cel/runtime/CelEvaluationListener.java b/runtime/src/main/java/dev/cel/runtime/CelEvaluationListener.java index 12f11ea5e..1cb9e8d97 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelEvaluationListener.java +++ b/runtime/src/main/java/dev/cel/runtime/CelEvaluationListener.java @@ -33,9 +33,4 @@ public interface CelEvaluationListener { * @param evaluatedResult Evaluated result. */ void callback(CelExpr expr, Object evaluatedResult); - - /** Construct a listener that does nothing. */ - static CelEvaluationListener noOpListener() { - return (arg1, arg2) -> {}; - } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java b/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java index 05201d1e5..c7f1d0c91 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java +++ b/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java @@ -59,7 +59,7 @@ static CelUnknownSet create(Iterable unknownExprIds) { return create(ImmutableSet.of(), ImmutableSet.copyOf(unknownExprIds)); } - private static CelUnknownSet create( + static CelUnknownSet create( ImmutableSet attributes, ImmutableSet unknownExprIds) { return new AutoValue_CelUnknownSet(attributes, unknownExprIds); } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 9db72cdf2..aa7d1ace9 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -21,6 +21,7 @@ import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.Immutable; import javax.annotation.concurrent.ThreadSafe; @@ -137,20 +138,22 @@ static final class DefaultInterpretable implements Interpretable, UnknownTrackin @Override public Object eval(GlobalResolver resolver) throws CelEvaluationException { // Result is already unwrapped from IntermediateResult. - return eval(resolver, CelEvaluationListener.noOpListener()); + return evalTrackingUnknowns( + RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), Optional.empty()); } @Override public Object eval(GlobalResolver resolver, CelEvaluationListener listener) throws CelEvaluationException { return evalTrackingUnknowns( - RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), listener); + RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), Optional.of(listener)); } @Override public Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionResolver) throws CelEvaluationException { - return eval(resolver, lateBoundFunctionResolver, CelEvaluationListener.noOpListener()); + return evalTrackingUnknowns( + RuntimeUnknownResolver.fromResolver(resolver), Optional.of(lateBoundFunctionResolver), Optional.empty()); } @Override @@ -162,19 +165,31 @@ public Object eval( return evalTrackingUnknowns( RuntimeUnknownResolver.fromResolver(resolver), Optional.of(lateBoundFunctionResolver), - listener); + Optional.of(listener)); } @Override public Object evalTrackingUnknowns( RuntimeUnknownResolver resolver, Optional functionResolver, - CelEvaluationListener listener) + Optional listener) throws CelEvaluationException { ExecutionFrame frame = newExecutionFrame(resolver, functionResolver, listener); IntermediateResult internalResult = evalInternal(frame, ast.getExpr()); - return internalResult.value(); + Object underlyingValue = internalResult.value(); + + return maybeAdaptToCelUnknownSet(underlyingValue); + } + + private static Object maybeAdaptToCelUnknownSet(Object val) { + if (!(val instanceof AccumulatedUnknowns)) { + return val; + } + + AccumulatedUnknowns unknowns = (AccumulatedUnknowns) val; + + return CelUnknownSet.create(ImmutableSet.copyOf(unknowns.attributes()), ImmutableSet.copyOf(unknowns.exprIds())); } /** @@ -198,13 +213,13 @@ ExecutionFrame newTestExecutionFrame(GlobalResolver resolver) { return newExecutionFrame( RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), - CelEvaluationListener.noOpListener()); + Optional.empty()); } private ExecutionFrame newExecutionFrame( RuntimeUnknownResolver resolver, Optional functionResolver, - CelEvaluationListener listener) { + Optional listener) { int comprehensionMaxIterations = celOptions.enableComprehension() ? celOptions.comprehensionMaxIterations() : 0; return new ExecutionFrame(listener, resolver, functionResolver, comprehensionMaxIterations); @@ -244,7 +259,8 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr) throw new IllegalStateException( "unexpected expression kind: " + expr.exprKind().getKind()); } - frame.getEvaluationListener().callback(expr, result.value()); + + frame.getEvaluationListener().ifPresent(listener -> listener.callback(expr, maybeAdaptToCelUnknownSet(result.value()))); return result; } catch (CelRuntimeException e) { throw CelEvaluationExceptionBuilder.newBuilder(e).setMetadata(metadata, expr.id()).build(); @@ -257,7 +273,7 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr) } private static boolean isUnknownValue(Object value) { - return value instanceof CelUnknownSet || InterpreterUtil.isUnknown(value); + return InterpreterUtil.isAccumulatedUnknowns(value); } private static boolean isUnknownOrError(Object value) { @@ -552,18 +568,20 @@ private IntermediateResult mergeBooleanUnknowns(IntermediateResult lhs, Intermed throws CelEvaluationException { // TODO: migrate clients to a common type that reports both expr-id unknowns // and attribute sets. - if (lhs.value() instanceof CelUnknownSet && rhs.value() instanceof CelUnknownSet) { + Object lhsVal = lhs.value(); + Object rhsVal = rhs.value(); + if (lhsVal instanceof AccumulatedUnknowns && rhsVal instanceof AccumulatedUnknowns) { return IntermediateResult.create( - ((CelUnknownSet) lhs.value()).merge((CelUnknownSet) rhs.value())); - } else if (lhs.value() instanceof CelUnknownSet) { + ((AccumulatedUnknowns) lhsVal).merge((AccumulatedUnknowns) rhsVal)); + } else if (lhsVal instanceof AccumulatedUnknowns) { return lhs; - } else if (rhs.value() instanceof CelUnknownSet) { + } else if (rhsVal instanceof AccumulatedUnknowns) { return rhs; } // Otherwise fallback to normal impl return IntermediateResult.create( - InterpreterUtil.shortcircuitUnknownOrThrowable(lhs.value(), rhs.value())); + InterpreterUtil.shortcircuitUnknownOrThrowable(lhsVal, rhsVal)); } private enum ShortCircuitableOperators { @@ -1050,7 +1068,7 @@ private LazyExpression(CelExpr celExpr) { /** This class tracks the state meaningful to a single evaluation pass. */ static class ExecutionFrame { - private final CelEvaluationListener evaluationListener; + private final Optional evaluationListener; private final int maxIterations; private final ArrayDeque resolvers; private final Optional lateBoundFunctionResolver; @@ -1059,7 +1077,7 @@ static class ExecutionFrame { @VisibleForTesting int scopeLevel; private ExecutionFrame( - CelEvaluationListener evaluationListener, + Optional evaluationListener, RuntimeUnknownResolver resolver, Optional lateBoundFunctionResolver, int maxIterations) { @@ -1071,7 +1089,7 @@ private ExecutionFrame( this.maxIterations = maxIterations; } - private CelEvaluationListener getEvaluationListener() { + private Optional getEvaluationListener() { return evaluationListener; } diff --git a/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java b/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java index cb9353dd7..d1f3bd5c5 100644 --- a/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java +++ b/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java @@ -48,7 +48,7 @@ public static Object strict(Object valueOrThrowable) throws CelEvaluationExcepti } /** - * Check if raw object is ExprValue object and has UnknownSet + * Check if raw object is {@link CelUnknownSet}. * * @param obj Object to check. * @return boolean value if object is unknown. @@ -57,15 +57,19 @@ public static boolean isUnknown(Object obj) { return obj instanceof CelUnknownSet; } - static CelUnknownSet combineUnknownExprValue(Object... objs) { + static boolean isAccumulatedUnknowns(Object obj) { + return obj instanceof AccumulatedUnknowns; + } + + static AccumulatedUnknowns combineUnknownExprValue(Object... objs) { Set ids = new LinkedHashSet<>(); for (Object object : objs) { - if (isUnknown(object)) { - ids.addAll(((CelUnknownSet) object).unknownExprIds()); + if (isAccumulatedUnknowns(object)) { + ids.addAll(((AccumulatedUnknowns) object).exprIds()); } } - return CelUnknownSet.create(ids); + return AccumulatedUnknowns.create(ids); } /** @@ -79,17 +83,17 @@ static CelUnknownSet combineUnknownExprValue(Object... objs) { public static Object shortcircuitUnknownOrThrowable(Object left, Object right) throws CelEvaluationException { // unknown unknown ==> unknown combined - if (InterpreterUtil.isUnknown(left) && InterpreterUtil.isUnknown(right)) { + if (InterpreterUtil.isAccumulatedUnknowns(left) && InterpreterUtil.isAccumulatedUnknowns(right)) { return InterpreterUtil.combineUnknownExprValue(left, right); } // unknown ==> unknown // unknown t|f ==> unknown - if (InterpreterUtil.isUnknown(left)) { + if (InterpreterUtil.isAccumulatedUnknowns(left)) { return left; } // unknown ==> unknown // t|f unknown ==> unknown - if (InterpreterUtil.isUnknown(right)) { + if (InterpreterUtil.isAccumulatedUnknowns(right)) { return right; } // Throw left or right side exception for now, should combine them into ErrorSet. @@ -106,12 +110,12 @@ public static Object shortcircuitUnknownOrThrowable(Object left, Object right) public static Object valueOrUnknown(@Nullable Object valueOrThrowable, Long id) { // Handle the unknown value case. - if (isUnknown(valueOrThrowable)) { - return CelUnknownSet.create(id); + if (isAccumulatedUnknowns(valueOrThrowable)) { + return AccumulatedUnknowns.create(id); } // Handle the null value case. if (valueOrThrowable == null) { - return CelUnknownSet.create(id); + return AccumulatedUnknowns.create(id); } return valueOrThrowable; } diff --git a/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java b/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java index f7de52aa9..6b6dc13f1 100644 --- a/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java @@ -40,8 +40,7 @@ public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctio return interpretable() .eval( Activation.copyOf(mapValue), - lateBoundFunctionResolver, - CelEvaluationListener.noOpListener()); + lateBoundFunctionResolver); } static CelLiteRuntime.Program plan(Interpretable interpretable) { diff --git a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java index 6f36027ab..d9d63e653 100644 --- a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java @@ -53,8 +53,7 @@ public Object eval(CelVariableResolver resolver, CelFunctionResolver lateBoundFu throws CelEvaluationException { return evalInternal( (name) -> resolver.find(name).orElse(null), - lateBoundFunctionResolver, - CelEvaluationListener.noOpListener()); + lateBoundFunctionResolver); } @Override @@ -62,8 +61,7 @@ public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctio throws CelEvaluationException { return evalInternal( Activation.copyOf(mapValue), - lateBoundFunctionResolver, - CelEvaluationListener.noOpListener()); + lateBoundFunctionResolver); } @Override @@ -110,17 +108,22 @@ public Object trace( @Override public Object advanceEvaluation(UnknownContext context) throws CelEvaluationException { - return evalInternal(context, Optional.empty(), CelEvaluationListener.noOpListener()); + return evalInternal(context, Optional.empty(), Optional.empty()); } private Object evalInternal(GlobalResolver resolver) throws CelEvaluationException { return evalInternal( - UnknownContext.create(resolver), Optional.empty(), CelEvaluationListener.noOpListener()); + UnknownContext.create(resolver), Optional.empty(), Optional.empty()); } private Object evalInternal(GlobalResolver resolver, CelEvaluationListener listener) throws CelEvaluationException { - return evalInternal(UnknownContext.create(resolver), Optional.empty(), listener); + return evalInternal(UnknownContext.create(resolver), Optional.empty(), Optional.of(listener)); + } + + private Object evalInternal(GlobalResolver resolver, CelFunctionResolver functionResolver) + throws CelEvaluationException { + return evalInternal(UnknownContext.create(resolver), Optional.of(functionResolver), Optional.empty()); } private Object evalInternal( @@ -129,7 +132,7 @@ private Object evalInternal( CelEvaluationListener listener) throws CelEvaluationException { return evalInternal( - UnknownContext.create(resolver), Optional.of(lateBoundFunctionResolver), listener); + UnknownContext.create(resolver), Optional.of(lateBoundFunctionResolver), Optional.of(listener)); } /** @@ -139,7 +142,7 @@ private Object evalInternal( private Object evalInternal( UnknownContext context, Optional lateBoundFunctionResolver, - CelEvaluationListener listener) + Optional listener) throws CelEvaluationException { Interpretable impl = getInterpretable(); if (getOptions().enableUnknownTracking()) { @@ -157,10 +160,15 @@ private Object evalInternal( lateBoundFunctionResolver, listener); } else { - if (lateBoundFunctionResolver.isPresent()) { - return impl.eval(context.variableResolver(), lateBoundFunctionResolver.get(), listener); + if (lateBoundFunctionResolver.isPresent() && listener.isPresent()) { + return impl.eval(context.variableResolver(), lateBoundFunctionResolver.get(), listener.get()); + } else if (lateBoundFunctionResolver.isPresent()) { + return impl.eval(context.variableResolver(), lateBoundFunctionResolver.get()); + } else if (listener.isPresent()) { + return impl.eval(context.variableResolver(), listener.get()); } - return impl.eval(context.variableResolver(), listener); + + return impl.eval(context.variableResolver()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java index 804cdb1fd..09f20cdd4 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java @@ -91,8 +91,12 @@ public RuntimeUnknownResolver build() { * Return a single element unknown set if the attribute is partially unknown based on the defined * patterns. */ - Optional maybePartialUnknown(CelAttribute attribute) { - return attributeResolver.maybePartialUnknown(attribute); + Optional maybePartialUnknown(CelAttribute attribute) { + CelUnknownSet unknownSet = attributeResolver.maybePartialUnknown(attribute).orElse(null); + if (unknownSet == null) { + return Optional.empty(); + } + return Optional.of(adaptToAccumulatedUnknowns(unknownSet)); } /** Resolve a simple name to a value. */ @@ -102,7 +106,7 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId if (attributeTrackingEnabled) { attr = CelAttribute.fromQualifiedIdentifier(name); - Optional result = attributeResolver.resolve(attr); + Optional result = resolveAttribute(attr); if (result.isPresent()) { return DefaultInterpreter.IntermediateResult.create(attr, result.get()); } @@ -123,7 +127,24 @@ void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResu * resolved values behind field accesses and index operations. */ Optional resolveAttribute(CelAttribute attr) { - return attributeResolver.resolve(attr); + Object resolved = attributeResolver.resolve(attr).orElse(null); + if (resolved == null) { + return Optional.empty(); + } + + return Optional.of(maybeAdaptToAccumulatedUnknowns(resolved)); + } + + private static Object maybeAdaptToAccumulatedUnknowns(Object val) { + if (!(val instanceof CelUnknownSet)) { + return val; + } + + return adaptToAccumulatedUnknowns((CelUnknownSet) val) ; + } + + private static AccumulatedUnknowns adaptToAccumulatedUnknowns(CelUnknownSet unknowns) { + return AccumulatedUnknowns.create(unknowns.unknownExprIds(), unknowns.attributes()); } ScopedResolver withScope(Map vars) { diff --git a/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java b/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java index 4eaf08e65..9a1a4964a 100644 --- a/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java +++ b/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java @@ -36,6 +36,6 @@ public interface UnknownTrackingInterpretable { Object evalTrackingUnknowns( RuntimeUnknownResolver resolver, Optional lateBoundFunctionResolver, - CelEvaluationListener listener) + Optional listener) throws CelEvaluationException; } diff --git a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java index 13199cbf3..39b0efc02 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java @@ -20,6 +20,7 @@ import com.google.api.expr.v1alpha1.Constant; import com.google.api.expr.v1alpha1.Expr; import com.google.api.expr.v1alpha1.Type.PrimitiveType; +import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; @@ -41,6 +42,8 @@ import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.types.CelV1AlphaTypes; +import dev.cel.common.types.ListType; +import dev.cel.common.types.MapType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; import dev.cel.compiler.CelCompiler; @@ -48,6 +51,7 @@ import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; +import dev.cel.runtime.CelAttribute.Qualifier; import java.util.List; import java.util.Map; import java.util.Optional; @@ -685,4 +689,49 @@ public void standardEnvironmentDisabledForRuntime_throws() throws Exception { .hasMessageThat() .contains("No matching overload for function 'size'. Overload candidates: size_string"); } + + @Test + @TestParameters("{size: 10000}") + @TestParameters("{size: 20000}") + @TestParameters("{size: 40000}") + @TestParameters("{size: 80000}") + @TestParameters("{size: 160000}") + public void benchmark_lotsOfUnknownMerges(int size) throws Exception { + CelOptions celOptions = CelOptions.current().enableUnknownTracking(true).build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addVar("list", ListType.create(SimpleType.INT)) + .addVar("unk", MapType.create(SimpleType.INT, SimpleType.INT)) + .setOptions(celOptions) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS).build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(celOptions) + .build(); + CelAbstractSyntaxTree ast = celCompiler.compile("list.map(y, unk[y])").getAst(); + + ImmutableList.Builder listBuilder = ImmutableList.builder(); + for (long i = 1; i <= size; i++) { + listBuilder.add(i); + } + ImmutableList list = listBuilder.build(); + CelRuntime.Program program = celRuntime.createProgram(ast); + UnknownContext ctx = UnknownContext.create(name -> { + if (name.equals("list")) { + return Optional.of(list); + } + return Optional.empty(); + }, ImmutableList.of( + CelAttributePattern.fromQualifiedIdentifier("unk") + .qualify(Qualifier.ofWildCard()))); + + // warmup cache + CelUnknownSet result = (CelUnknownSet) program.advanceEvaluation(ctx); + Stopwatch sw = Stopwatch.createStarted(); + result = (CelUnknownSet) program.advanceEvaluation(ctx); + sw.stop(); + System.err.println("Elapsed: " + sw); + + assertThat(result.attributes()).hasSize(size); + } }