Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
8.5.0
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION;
import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP;

Expand Down Expand Up @@ -183,9 +182,9 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
if (functionName.equals(Operator.EQUALS.getFunction())
|| functionName.equals(Operator.NOT_EQUALS.getFunction())) {
if (mutableCall.args().stream()
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
.anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE))
|| mutableCall.args().stream()
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
.allMatch(node -> node.getKind().equals(Kind.CONSTANT))) {
return true;
}
}
Expand All @@ -196,17 +195,69 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {

// Default case: all call arguments must be constants. If the argument is a container (ex:
// list, map), then its arguments must be a constant.
return areChildrenArgConstant(navigableExpr);
return navigableExpr.children().allMatch(this::canEvaluate);
case SELECT:
CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement());
return areChildrenArgConstant(operand);
return navigableExpr.children().allMatch(this::canEvaluate);
case COMPREHENSION:
return !isNestedComprehension(navigableExpr);
if (isNestedComprehension(navigableExpr)) {
return false;
}
CelMutableComprehension comprehension = navigableExpr.expr().comprehension();

if (!canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.iterRange()))
|| !canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.accuInit()))) {
return false;
}

return canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopStep()))
&& canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopCondition()));
default:
return false;
}
}

/**
* Checks if a subtree is safe to evaluate (i.e: it evaluates down to a constant expression)
*/
private boolean canEvaluate(CelNavigableMutableExpr expression) {
return expression.allNodes().allMatch(this::isAllowedInConstantExpr);
}

private boolean canEvaluateComprehensionBody(CelNavigableMutableExpr expression) {
return expression.allNodes().allMatch(node -> {
Kind kind = node.getKind();
if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) {
return true;
}
return isAllowedInConstantExpr(node);
});
}

private boolean isAllowedInConstantExpr(CelNavigableMutableExpr node) {
Kind kind = node.getKind();
if (kind.equals(Kind.CONSTANT)
|| kind.equals(Kind.LIST)
|| kind.equals(Kind.MAP)
|| kind.equals(Kind.STRUCT)
|| kind.equals(Kind.SELECT)) {
return true;
}
if (kind.equals(Kind.CALL)) {
CelMutableCall call = node.expr().call();
return foldableFunctions.contains(call.function());
}

return false;
}

private boolean isAllowedInFoldableExpr(CelNavigableMutableExpr node) {
Kind kind = node.getKind();
if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) {
return true;
}
return isAllowedInConstantExpr(node);
}

private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) {
return navigableExpr
.allNodes()
Expand Down Expand Up @@ -248,22 +299,6 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
return true;
}

private static boolean areChildrenArgConstant(CelNavigableMutableExpr expr) {
if (expr.getKind().equals(Kind.CONSTANT)) {
return true;
}

if (expr.getKind().equals(Kind.CALL)
|| expr.getKind().equals(Kind.LIST)
|| expr.getKind().equals(Kind.MAP)
|| expr.getKind().equals(Kind.SELECT)
|| expr.getKind().equals(Kind.STRUCT)) {
return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant);
}

return false;
}

private static boolean isNestedComprehension(CelNavigableMutableExpr expr) {
Optional<CelNavigableMutableExpr> maybeParent = expr.parent();
while (maybeParent.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
public class ConstantFoldingOptimizerTest {
private static final CelOptions CEL_OPTIONS =
CelOptions.current()
.populateMacroCalls(true)
.enableTimestampEpoch(true)
.build();
private static final Cel CEL =
Expand All @@ -56,12 +57,23 @@ public class ConstantFoldingOptimizerTest {
.addVar("y", SimpleType.DYN)
.addVar("list_var", ListType.create(SimpleType.STRING))
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING))
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)),
CelFunctionDecl.newFunctionDeclaration(
"get_list",
CelOverloadDecl.newGlobalOverload(
"get_list_overload",
ListType.create(SimpleType.INT),
ListType.create(SimpleType.INT)))
)
.addFunctionBindings(
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true))
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true),
CelFunctionBinding.from(
"get_list_overload", ImmutableList.class, arg -> arg)
)
.addMessageTypes(TestAllTypes.getDescriptor())
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.setOptions(CEL_OPTIONS)
Expand Down Expand Up @@ -371,6 +383,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
@TestParameters("{source: 'x == 42'}")
@TestParameters("{source: 'timestamp(100)'}")
@TestParameters("{source: 'duration(\"1h\")'}")
@TestParameters("{source: '[true].exists(x, x == get_true())'}")
@TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}")
public void constantFold_noOp(String source) throws Exception {
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();

Expand Down
12 changes: 8 additions & 4 deletions policy/src/main/java/dev/cel/policy/CelPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -77,8 +78,7 @@ public abstract static class Builder {

public abstract Builder setPolicySource(CelPolicySource policySource);

// This should stay package-private to encourage add/set methods to be used instead.
abstract ImmutableMap.Builder<String, Object> metadataBuilder();
private final HashMap<String, Object> metadata = new HashMap<>();

public abstract Builder setMetadata(ImmutableMap<String, Object> value);

Expand All @@ -90,6 +90,10 @@ public List<Import> imports() {
return Collections.unmodifiableList(importList);
}

public Map<String, Object> metadata() {
return Collections.unmodifiableMap(metadata);
}

@CanIgnoreReturnValue
public Builder addImport(Import value) {
importList.add(value);
Expand All @@ -104,13 +108,13 @@ public Builder addImports(Collection<Import> values) {

@CanIgnoreReturnValue
public Builder putMetadata(String key, Object value) {
metadataBuilder().put(key, value);
metadata.put(key, value);
return this;
}

@CanIgnoreReturnValue
public Builder putMetadata(Map<String, Object> map) {
metadataBuilder().putAll(map);
metadata.putAll(map);
return this;
}

Expand Down
29 changes: 29 additions & 0 deletions policy/src/main/java/dev/cel/policy/testing/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
load("@rules_java//java:defs.bzl", "java_library")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//policy/testing:__pkg__",
],
)

java_library(
name = "policy_test_suite_helper",
testonly = True,
srcs = [
"PolicyTestSuiteHelper.java",
],
deps = [
"//bundle:cel",
"//common:cel_ast",
"//common:compiler_common",
"//common/formats:value_string",
"//policy",
"//policy:parser",
"//policy:parser_builder",
"//policy:policy_parser_context",
"//runtime:evaluation_exception",
"@maven//:com_google_guava_guava",
"@maven//:org_yaml_snakeyaml",
],
)
Loading
Loading