Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package dev.cel.runtime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
* An internal representation used for fast accumulation of unknown expr IDs and attributes.
* For safety, this object should never be returned as an evaluated result and instead be adapted into an immutable CelUnknownSet.
*/
final class AccumulatedUnknowns {

private final List<Long> exprIds;
private final List<CelAttribute> attributes;

List<Long> exprIds() {
return exprIds;
}

List<CelAttribute> attributes() {
return attributes;
}

AccumulatedUnknowns merge(AccumulatedUnknowns arg) {
this.exprIds.addAll(arg.exprIds);
this.attributes.addAll(arg.attributes);
return this;
}

static AccumulatedUnknowns create(Long... ids) {
return create(Arrays.asList(ids));
}

static AccumulatedUnknowns create(Collection<Long> ids) {
return create(ids, new ArrayList<>());
}

static AccumulatedUnknowns create(Collection<Long> exprIds, Collection<CelAttribute> attributes) {
return new AccumulatedUnknowns(new ArrayList<>(exprIds), new ArrayList<>(attributes));
}

private AccumulatedUnknowns(List<Long> exprIds, List<CelAttribute> attributes) {
this.exprIds = exprIds;
this.attributes = attributes;
}
}
18 changes: 18 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ java_library(
],
exports = [":base"],
deps = [
":accumulated_unknowns",
":base",
":concatenated_list_view",
":dispatcher",
Expand Down Expand Up @@ -306,6 +307,7 @@ cel_android_library(
srcs = INTERPRETER_SOURCES,
visibility = ["//visibility:private"],
deps = [
":accumulated_unknowns_android",
":base_android",
":concatenated_list_view",
":dispatcher_android",
Expand Down Expand Up @@ -1034,6 +1036,7 @@ java_library(
tags = [
],
deps = [
":accumulated_unknowns",
":evaluation_exception",
":unknown_attributes",
"//common/annotations",
Expand All @@ -1047,6 +1050,7 @@ cel_android_library(
srcs = ["InterpreterUtil.java"],
visibility = ["//visibility:private"],
deps = [
":accumulated_unknowns_android",
":evaluation_exception",
":unknown_attributes_android",
"//common/annotations",
Expand Down Expand Up @@ -1105,3 +1109,17 @@ java_library(
# used_by_android
visibility = ["//visibility:private"],
)

java_library(
name = "accumulated_unknowns",
srcs = ["AccumulatedUnknowns.java"],
visibility = ["//visibility:private"],
deps = [":unknown_attributes"],
)

cel_android_library(
name = "accumulated_unknowns_android",
srcs = ["AccumulatedUnknowns.java"],
visibility = ["//visibility:private"],
deps = [":unknown_attributes_android"],
)
24 changes: 12 additions & 12 deletions runtime/src/main/java/dev/cel/runtime/CallArgumentChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CallArgumentChecker {
private final ArrayList<Long> exprIds;
private final RuntimeUnknownResolver resolver;
private final boolean acceptPartial;
private Optional<CelUnknownSet> unknowns;
private Optional<AccumulatedUnknowns> unknowns;

private CallArgumentChecker(RuntimeUnknownResolver resolver, boolean acceptPartial) {
this.exprIds = new ArrayList<>();
Expand Down Expand Up @@ -61,29 +61,29 @@ static CallArgumentChecker createAcceptingPartial(RuntimeUnknownResolver resolve
return new CallArgumentChecker(resolver, true);
}

private static Optional<CelUnknownSet> mergeOptionalUnknowns(
Optional<CelUnknownSet> lhs, Optional<CelUnknownSet> rhs) {
private static Optional<AccumulatedUnknowns> mergeOptionalUnknowns(
Optional<AccumulatedUnknowns> lhs, Optional<AccumulatedUnknowns> rhs) {
return lhs.isPresent() ? rhs.isPresent() ? Optional.of(lhs.get().merge(rhs.get())) : lhs : rhs;
}

/** Determine if the call argument is unknown and accumulate if so. */
void checkArg(DefaultInterpreter.IntermediateResult arg) {
// Handle attribute tracked unknowns.
Optional<CelUnknownSet> argUnknowns = maybeUnknownFromArg(arg);
Optional<AccumulatedUnknowns> argUnknowns = maybeUnknownFromArg(arg);
unknowns = mergeOptionalUnknowns(unknowns, argUnknowns);

// support for ExprValue unknowns.
if (InterpreterUtil.isUnknown(arg.value())) {
CelUnknownSet unknownSet = (CelUnknownSet) arg.value();
exprIds.addAll(unknownSet.unknownExprIds());
if (InterpreterUtil.isAccumulatedUnknowns(arg.value())) {
AccumulatedUnknowns unknownSet = (AccumulatedUnknowns) arg.value();
exprIds.addAll(unknownSet.exprIds());
}
}

private Optional<CelUnknownSet> maybeUnknownFromArg(DefaultInterpreter.IntermediateResult arg) {
if (arg.value() instanceof CelUnknownSet) {
CelUnknownSet celUnknownSet = (CelUnknownSet) arg.value();
private Optional<AccumulatedUnknowns> maybeUnknownFromArg(DefaultInterpreter.IntermediateResult arg) {
if (arg.value() instanceof AccumulatedUnknowns) {
AccumulatedUnknowns celUnknownSet = (AccumulatedUnknowns) arg.value();
if (!celUnknownSet.attributes().isEmpty()) {
return Optional.of((CelUnknownSet) arg.value());
return Optional.of((AccumulatedUnknowns) arg.value());
}
}
if (!acceptPartial) {
Expand All @@ -99,7 +99,7 @@ Optional<Object> maybeUnknowns() {
}

if (!exprIds.isEmpty()) {
return Optional.of(CelUnknownSet.create(exprIds));
return Optional.of(AccumulatedUnknowns.create(exprIds));
}

return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,4 @@ public interface CelEvaluationListener {
* @param evaluatedResult Evaluated result.
*/
void callback(CelExpr expr, Object evaluatedResult);

/** Construct a listener that does nothing. */
static CelEvaluationListener noOpListener() {
return (arg1, arg2) -> {};
}
}
2 changes: 1 addition & 1 deletion runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ static CelUnknownSet create(Iterable<Long> unknownExprIds) {
return create(ImmutableSet.of(), ImmutableSet.copyOf(unknownExprIds));
}

private static CelUnknownSet create(
static CelUnknownSet create(
ImmutableSet<CelAttribute> attributes, ImmutableSet<Long> unknownExprIds) {
return new AutoValue_CelUnknownSet(attributes, unknownExprIds);
}
Expand Down
54 changes: 36 additions & 18 deletions runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.Immutable;
import javax.annotation.concurrent.ThreadSafe;
Expand Down Expand Up @@ -137,20 +138,22 @@ static final class DefaultInterpretable implements Interpretable, UnknownTrackin
@Override
public Object eval(GlobalResolver resolver) throws CelEvaluationException {
// Result is already unwrapped from IntermediateResult.
return eval(resolver, CelEvaluationListener.noOpListener());
return evalTrackingUnknowns(
RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), Optional.empty());
}

@Override
public Object eval(GlobalResolver resolver, CelEvaluationListener listener)
throws CelEvaluationException {
return evalTrackingUnknowns(
RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), listener);
RuntimeUnknownResolver.fromResolver(resolver), Optional.empty(), Optional.of(listener));
}

@Override
public Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionResolver)
throws CelEvaluationException {
return eval(resolver, lateBoundFunctionResolver, CelEvaluationListener.noOpListener());
return evalTrackingUnknowns(
RuntimeUnknownResolver.fromResolver(resolver), Optional.of(lateBoundFunctionResolver), Optional.empty());
}

@Override
Expand All @@ -162,19 +165,31 @@ public Object eval(
return evalTrackingUnknowns(
RuntimeUnknownResolver.fromResolver(resolver),
Optional.of(lateBoundFunctionResolver),
listener);
Optional.of(listener));
}

@Override
public Object evalTrackingUnknowns(
RuntimeUnknownResolver resolver,
Optional<? extends FunctionResolver> functionResolver,
CelEvaluationListener listener)
Optional<CelEvaluationListener> listener)
throws CelEvaluationException {
ExecutionFrame frame = newExecutionFrame(resolver, functionResolver, listener);
IntermediateResult internalResult = evalInternal(frame, ast.getExpr());

return internalResult.value();
Object underlyingValue = internalResult.value();

return maybeAdaptToCelUnknownSet(underlyingValue);
}

private static Object maybeAdaptToCelUnknownSet(Object val) {
if (!(val instanceof AccumulatedUnknowns)) {
return val;
}

AccumulatedUnknowns unknowns = (AccumulatedUnknowns) val;

return CelUnknownSet.create(ImmutableSet.copyOf(unknowns.attributes()), ImmutableSet.copyOf(unknowns.exprIds()));
}

/**
Expand All @@ -198,13 +213,13 @@ ExecutionFrame newTestExecutionFrame(GlobalResolver resolver) {
return newExecutionFrame(
RuntimeUnknownResolver.fromResolver(resolver),
Optional.empty(),
CelEvaluationListener.noOpListener());
Optional.empty());
}

private ExecutionFrame newExecutionFrame(
RuntimeUnknownResolver resolver,
Optional<? extends FunctionResolver> functionResolver,
CelEvaluationListener listener) {
Optional<CelEvaluationListener> listener) {
int comprehensionMaxIterations =
celOptions.enableComprehension() ? celOptions.comprehensionMaxIterations() : 0;
return new ExecutionFrame(listener, resolver, functionResolver, comprehensionMaxIterations);
Expand Down Expand Up @@ -244,7 +259,8 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr)
throw new IllegalStateException(
"unexpected expression kind: " + expr.exprKind().getKind());
}
frame.getEvaluationListener().callback(expr, result.value());

frame.getEvaluationListener().ifPresent(listener -> listener.callback(expr, maybeAdaptToCelUnknownSet(result.value())));
return result;
} catch (CelRuntimeException e) {
throw CelEvaluationExceptionBuilder.newBuilder(e).setMetadata(metadata, expr.id()).build();
Expand All @@ -257,7 +273,7 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr)
}

private static boolean isUnknownValue(Object value) {
return value instanceof CelUnknownSet || InterpreterUtil.isUnknown(value);
return InterpreterUtil.isAccumulatedUnknowns(value);
}

private static boolean isUnknownOrError(Object value) {
Expand Down Expand Up @@ -552,18 +568,20 @@ private IntermediateResult mergeBooleanUnknowns(IntermediateResult lhs, Intermed
throws CelEvaluationException {
// TODO: migrate clients to a common type that reports both expr-id unknowns
// and attribute sets.
if (lhs.value() instanceof CelUnknownSet && rhs.value() instanceof CelUnknownSet) {
Object lhsVal = lhs.value();
Object rhsVal = rhs.value();
if (lhsVal instanceof AccumulatedUnknowns && rhsVal instanceof AccumulatedUnknowns) {
return IntermediateResult.create(
((CelUnknownSet) lhs.value()).merge((CelUnknownSet) rhs.value()));
} else if (lhs.value() instanceof CelUnknownSet) {
((AccumulatedUnknowns) lhsVal).merge((AccumulatedUnknowns) rhsVal));
} else if (lhsVal instanceof AccumulatedUnknowns) {
return lhs;
} else if (rhs.value() instanceof CelUnknownSet) {
} else if (rhsVal instanceof AccumulatedUnknowns) {
return rhs;
}

// Otherwise fallback to normal impl
return IntermediateResult.create(
InterpreterUtil.shortcircuitUnknownOrThrowable(lhs.value(), rhs.value()));
InterpreterUtil.shortcircuitUnknownOrThrowable(lhsVal, rhsVal));
}

private enum ShortCircuitableOperators {
Expand Down Expand Up @@ -1050,7 +1068,7 @@ private LazyExpression(CelExpr celExpr) {

/** This class tracks the state meaningful to a single evaluation pass. */
static class ExecutionFrame {
private final CelEvaluationListener evaluationListener;
private final Optional<CelEvaluationListener> evaluationListener;
private final int maxIterations;
private final ArrayDeque<RuntimeUnknownResolver> resolvers;
private final Optional<? extends FunctionResolver> lateBoundFunctionResolver;
Expand All @@ -1059,7 +1077,7 @@ static class ExecutionFrame {
@VisibleForTesting int scopeLevel;

private ExecutionFrame(
CelEvaluationListener evaluationListener,
Optional<CelEvaluationListener> evaluationListener,
RuntimeUnknownResolver resolver,
Optional<? extends FunctionResolver> lateBoundFunctionResolver,
int maxIterations) {
Expand All @@ -1071,7 +1089,7 @@ private ExecutionFrame(
this.maxIterations = maxIterations;
}

private CelEvaluationListener getEvaluationListener() {
private Optional<CelEvaluationListener> getEvaluationListener() {
return evaluationListener;
}

Expand Down
Loading