diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 000000000..6d2890793 --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +8.5.0 diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index ca1744c88..dddb1fa22 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -15,7 +15,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.MoreCollectors.onlyElement; import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION; import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP; @@ -183,9 +182,9 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { if (functionName.equals(Operator.EQUALS.getFunction()) || functionName.equals(Operator.NOT_EQUALS.getFunction())) { if (mutableCall.args().stream() - .anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE)) + .anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE)) || mutableCall.args().stream() - .allMatch(node -> node.getKind().equals(Kind.CONSTANT))) { + .allMatch(node -> node.getKind().equals(Kind.CONSTANT))) { return true; } } @@ -196,17 +195,69 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { // Default case: all call arguments must be constants. If the argument is a container (ex: // list, map), then its arguments must be a constant. - return areChildrenArgConstant(navigableExpr); + return navigableExpr.children().allMatch(this::canEvaluate); case SELECT: - CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement()); - return areChildrenArgConstant(operand); + return navigableExpr.children().allMatch(this::canEvaluate); case COMPREHENSION: - return !isNestedComprehension(navigableExpr); + if (isNestedComprehension(navigableExpr)) { + return false; + } + CelMutableComprehension comprehension = navigableExpr.expr().comprehension(); + + if (!canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.iterRange())) + || !canEvaluate(CelNavigableMutableExpr.fromExpr(comprehension.accuInit()))) { + return false; + } + + return canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopStep())) + && canEvaluateComprehensionBody(CelNavigableMutableExpr.fromExpr(comprehension.loopCondition())); default: return false; } } + /** + * Checks if a subtree is safe to evaluate (i.e: it evaluates down to a constant expression) + */ + private boolean canEvaluate(CelNavigableMutableExpr expression) { + return expression.allNodes().allMatch(this::isAllowedInConstantExpr); + } + + private boolean canEvaluateComprehensionBody(CelNavigableMutableExpr expression) { + return expression.allNodes().allMatch(node -> { + Kind kind = node.getKind(); + if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) { + return true; + } + return isAllowedInConstantExpr(node); + }); + } + + private boolean isAllowedInConstantExpr(CelNavigableMutableExpr node) { + Kind kind = node.getKind(); + if (kind.equals(Kind.CONSTANT) + || kind.equals(Kind.LIST) + || kind.equals(Kind.MAP) + || kind.equals(Kind.STRUCT) + || kind.equals(Kind.SELECT)) { + return true; + } + if (kind.equals(Kind.CALL)) { + CelMutableCall call = node.expr().call(); + return foldableFunctions.contains(call.function()); + } + + return false; + } + + private boolean isAllowedInFoldableExpr(CelNavigableMutableExpr node) { + Kind kind = node.getKind(); + if (kind.equals(Kind.IDENT) || kind.equals(Kind.COMPREHENSION)) { + return true; + } + return isAllowedInConstantExpr(node); + } + private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) { return navigableExpr .allNodes() @@ -248,22 +299,6 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) return true; } - private static boolean areChildrenArgConstant(CelNavigableMutableExpr expr) { - if (expr.getKind().equals(Kind.CONSTANT)) { - return true; - } - - if (expr.getKind().equals(Kind.CALL) - || expr.getKind().equals(Kind.LIST) - || expr.getKind().equals(Kind.MAP) - || expr.getKind().equals(Kind.SELECT) - || expr.getKind().equals(Kind.STRUCT)) { - return expr.children().allMatch(ConstantFoldingOptimizer::areChildrenArgConstant); - } - - return false; - } - private static boolean isNestedComprehension(CelNavigableMutableExpr expr) { Optional maybeParent = expr.parent(); while (maybeParent.isPresent()) { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 1a3fac852..afc441d2a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -48,6 +48,7 @@ public class ConstantFoldingOptimizerTest { private static final CelOptions CEL_OPTIONS = CelOptions.current() + .populateMacroCalls(true) .enableTimestampEpoch(true) .build(); private static final Cel CEL = @@ -56,12 +57,23 @@ public class ConstantFoldingOptimizerTest { .addVar("y", SimpleType.DYN) .addVar("list_var", ListType.create(SimpleType.STRING)) .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), + CelFunctionDecl.newFunctionDeclaration( + "get_list", + CelOverloadDecl.newGlobalOverload( + "get_list_overload", + ListType.create(SimpleType.INT), + ListType.create(SimpleType.INT))) + ) .addFunctionBindings( - CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) + CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true), + CelFunctionBinding.from( + "get_list_overload", ImmutableList.class, arg -> arg) + ) .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .setOptions(CEL_OPTIONS) @@ -371,6 +383,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'x == 42'}") @TestParameters("{source: 'timestamp(100)'}") @TestParameters("{source: 'duration(\"1h\")'}") + @TestParameters("{source: '[true].exists(x, x == get_true())'}") + @TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); diff --git a/policy/src/main/java/dev/cel/policy/CelPolicy.java b/policy/src/main/java/dev/cel/policy/CelPolicy.java index 9980d0cad..b73d9e0b1 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicy.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicy.java @@ -27,6 +27,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -77,8 +78,7 @@ public abstract static class Builder { public abstract Builder setPolicySource(CelPolicySource policySource); - // This should stay package-private to encourage add/set methods to be used instead. - abstract ImmutableMap.Builder metadataBuilder(); + private final HashMap metadata = new HashMap<>(); public abstract Builder setMetadata(ImmutableMap value); @@ -90,6 +90,10 @@ public List imports() { return Collections.unmodifiableList(importList); } + public Map metadata() { + return Collections.unmodifiableMap(metadata); + } + @CanIgnoreReturnValue public Builder addImport(Import value) { importList.add(value); @@ -104,13 +108,13 @@ public Builder addImports(Collection values) { @CanIgnoreReturnValue public Builder putMetadata(String key, Object value) { - metadataBuilder().put(key, value); + metadata.put(key, value); return this; } @CanIgnoreReturnValue public Builder putMetadata(Map map) { - metadataBuilder().putAll(map); + metadata.putAll(map); return this; } diff --git a/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel new file mode 100644 index 000000000..6c847e0a6 --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel @@ -0,0 +1,29 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//policy/testing:__pkg__", + ], +) + +java_library( + name = "policy_test_suite_helper", + testonly = True, + srcs = [ + "PolicyTestSuiteHelper.java", + ], + deps = [ + "//bundle:cel", + "//common:cel_ast", + "//common:compiler_common", + "//common/formats:value_string", + "//policy", + "//policy:parser", + "//policy:parser_builder", + "//policy:policy_parser_context", + "//runtime:evaluation_exception", + "@maven//:com_google_guava_guava", + "@maven//:org_yaml_snakeyaml", + ], +) diff --git a/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java b/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java new file mode 100644 index 000000000..99bcab727 --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java @@ -0,0 +1,192 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.policy.testing; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.runtime.CelEvaluationException; +import java.io.IOException; +import java.net.URL; +import java.util.List; +import java.util.Map; +import org.yaml.snakeyaml.LoaderOptions; +import org.yaml.snakeyaml.Yaml; +import org.yaml.snakeyaml.constructor.Constructor; + +/** + * Helper to assist with policy testing. + * + **/ +public final class PolicyTestSuiteHelper { + + /** + * TODO + */ + public static PolicyTestSuite readTestSuite(String path) throws IOException { + Yaml yaml = new Yaml(new Constructor(PolicyTestSuite.class, new LoaderOptions())); + String testContent = readFile(path); + + return yaml.load(testContent); + } + + /** + * TODO + * @param yamlPath + * @return + * @throws IOException + */ + public static String readFromYaml(String yamlPath) throws IOException { + return readFile(yamlPath); + } + + /** + * TestSuite describes a set of tests divided by section. + * + *

Visibility must be public for YAML deserialization to work. This is effectively + * package-private since the outer class is. + */ + @VisibleForTesting + public static final class PolicyTestSuite { + private String description; + private List section; + + public void setDescription(String description) { + this.description = description; + } + + public void setSection(List section) { + this.section = section; + } + + public String getDescription() { + return description; + } + + public List getSection() { + return section; + } + + @VisibleForTesting + public static final class PolicyTestSection { + private String name; + private List tests; + + public void setName(String name) { + this.name = name; + } + + public void setTests(List tests) { + this.tests = tests; + } + + public String getName() { + return name; + } + + public List getTests() { + return tests; + } + + @VisibleForTesting + public static final class PolicyTestCase { + private String name; + private Map input; + private String output; + + public void setName(String name) { + this.name = name; + } + + public void setInput(Map input) { + this.input = input; + } + + public void setOutput(String output) { + this.output = output; + } + + public String getName() { + return name; + } + + public Map getInput() { + return input; + } + + public String getOutput() { + return output; + } + + @VisibleForTesting + public static final class PolicyTestInput { + private Object value; + private String expr; + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + } + + public String getExpr() { + return expr; + } + + public void setExpr(String expr) { + this.expr = expr; + } + } + + public ImmutableMap toInputMap(Cel cel) + throws CelValidationException, CelEvaluationException { + ImmutableMap.Builder inputBuilder = ImmutableMap.builderWithExpectedSize( + input.size()); + for (Map.Entry entry : input.entrySet()) { + String exprInput = entry.getValue().getExpr(); + if (isNullOrEmpty(exprInput)) { + inputBuilder.put(entry.getKey(), entry.getValue().getValue()); + } else { + CelAbstractSyntaxTree exprInputAst = cel.compile(exprInput).getAst(); + inputBuilder.put(entry.getKey(), cel.createProgram(exprInputAst).eval()); + } + } + + return inputBuilder.buildOrThrow(); + } + } + } + } + + + private static URL getResource(String path) { + return Resources.getResource(Ascii.toLowerCase(path)); + } + + private static String readFile(String path) throws IOException { + return Resources.toString(getResource(path), UTF_8); + } + + private PolicyTestSuiteHelper() {} +} diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 9106caf70..d51b5dc3e 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -33,6 +33,7 @@ java_library( "//policy:policy_parser_context", "//policy:source", "//policy:validation_exception", + "//policy/testing:policy_test_suite_helper", "//runtime", "//runtime:function_binding", "//runtime:late_function_binding", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index fa0da8a9a..c38e1f8e0 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -14,9 +14,8 @@ package dev.cel.policy; -import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.truth.Truth.assertThat; -import static dev.cel.policy.PolicyTestHelper.readFromYaml; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readFromYaml; import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; @@ -38,17 +37,15 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; import dev.cel.policy.PolicyTestHelper.K8sTagHandler; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase.PolicyTestInput; import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; -import java.util.Map; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -215,17 +212,8 @@ public void evaluateYamlPolicy_withCanonicalTestData( // Compile then evaluate the policy CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - ImmutableMap.Builder inputBuilder = ImmutableMap.builder(); - for (Map.Entry entry : testData.testCase.getInput().entrySet()) { - String exprInput = entry.getValue().getExpr(); - if (isNullOrEmpty(exprInput)) { - inputBuilder.put(entry.getKey(), entry.getValue().getValue()); - } else { - CelAbstractSyntaxTree exprInputAst = cel.compile(exprInput).getAst(); - inputBuilder.put(entry.getKey(), cel.createProgram(exprInputAst).eval()); - } - } - Object evalResult = cel.createProgram(compiledPolicyAst).eval(inputBuilder.buildOrThrow()); + ImmutableMap inputMap = testData.testCase.toInputMap(cel); + Object evalResult = cel.createProgram(compiledPolicyAst).eval(inputMap); // Assert // Note that policies may either produce an optional or a non-optional result, diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 8d9e0084b..dab91afd7 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -1,42 +1,19 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package dev.cel.policy; -import static java.nio.charset.StandardCharsets.UTF_8; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readFromYaml; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readTestSuite; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Ascii; -import com.google.common.io.Resources; import dev.cel.common.formats.ValueString; import dev.cel.policy.CelPolicy.Match; import dev.cel.policy.CelPolicy.Match.Result; import dev.cel.policy.CelPolicy.Rule; import dev.cel.policy.CelPolicyParser.TagVisitor; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; import java.io.IOException; -import java.net.URL; -import java.util.List; -import java.util.Map; -import org.yaml.snakeyaml.LoaderOptions; -import org.yaml.snakeyaml.Yaml; -import org.yaml.snakeyaml.constructor.Constructor; import org.yaml.snakeyaml.nodes.Node; import org.yaml.snakeyaml.nodes.SequenceNode; -/** Package-private class to assist with policy testing. */ final class PolicyTestHelper { - enum TestYamlPolicy { NESTED_RULE( "nested_rule", @@ -135,128 +112,11 @@ String readConfigYamlContent() throws IOException { } PolicyTestSuite readTestYamlContent() throws IOException { - Yaml yaml = new Yaml(new Constructor(PolicyTestSuite.class, new LoaderOptions())); - String testContent = readFile(String.format("policy/%s/tests.yaml", name)); - - return yaml.load(testContent); - } - } - - static String readFromYaml(String yamlPath) throws IOException { - return readFile(yamlPath); - } - - /** - * TestSuite describes a set of tests divided by section. - * - *

Visibility must be public for YAML deserialization to work. This is effectively - * package-private since the outer class is. - */ - @VisibleForTesting - public static final class PolicyTestSuite { - private String description; - private List section; - - public void setDescription(String description) { - this.description = description; - } - - public void setSection(List section) { - this.section = section; - } - - public String getDescription() { - return description; - } - - public List getSection() { - return section; - } - - @VisibleForTesting - public static final class PolicyTestSection { - private String name; - private List tests; - - public void setName(String name) { - this.name = name; - } - - public void setTests(List tests) { - this.tests = tests; - } - - public String getName() { - return name; - } - - public List getTests() { - return tests; - } - - @VisibleForTesting - public static final class PolicyTestCase { - private String name; - private Map input; - private String output; - - public void setName(String name) { - this.name = name; - } - - public void setInput(Map input) { - this.input = input; - } - - public void setOutput(String output) { - this.output = output; - } - - public String getName() { - return name; - } - - public Map getInput() { - return input; - } - - public String getOutput() { - return output; - } - - @VisibleForTesting - public static final class PolicyTestInput { - private Object value; - private String expr; - - public Object getValue() { - return value; - } - - public void setValue(Object value) { - this.value = value; - } - - public String getExpr() { - return expr; - } - - public void setExpr(String expr) { - this.expr = expr; - } - } - } + String testPath = String.format("policy/%s/tests.yaml", name); + return readTestSuite(testPath); } } - private static URL getResource(String path) { - return Resources.getResource(Ascii.toLowerCase(path)); - } - - private static String readFile(String path) throws IOException { - return Resources.toString(getResource(path), UTF_8); - } - static class K8sTagHandler implements TagVisitor { @Override @@ -360,3 +220,5 @@ public void visitMatchTag( private PolicyTestHelper() {} } + + diff --git a/policy/testing/BUILD.bazel b/policy/testing/BUILD.bazel new file mode 100644 index 000000000..834c0a978 --- /dev/null +++ b/policy/testing/BUILD.bazel @@ -0,0 +1,12 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//:internal"], +) + +java_library( + name = "policy_test_suite_helper", + testonly = True, + exports = ["//policy/src/main/java/dev/cel/policy/testing:policy_test_suite_helper"], +) diff --git a/tools/ai/BUILD.bazel b/tools/ai/BUILD.bazel new file mode 100644 index 000000000..97ee7aeef --- /dev/null +++ b/tools/ai/BUILD.bazel @@ -0,0 +1,17 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], +) + +java_library( + name = "agentic_policy_compiler", + exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_compiler"], +) + +alias( + name = "test_policies", + testonly = True, + actual = "//tools/src/test/resources:test_policies", +) diff --git a/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java b/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java new file mode 100644 index 000000000..778837f80 --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java @@ -0,0 +1,176 @@ +package dev.cel.tools.ai; + +import static dev.cel.common.formats.YamlHelper.assertYamlType; + +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.formats.ValueString; +import dev.cel.common.formats.YamlHelper.YamlNodeType; +import dev.cel.policy.CelPolicy; +import dev.cel.policy.CelPolicy.Match; +import dev.cel.policy.CelPolicy.Match.Result; +import dev.cel.policy.CelPolicy.Rule; +import dev.cel.policy.CelPolicy.Variable; +import dev.cel.policy.CelPolicyCompiler; +import dev.cel.policy.CelPolicyCompilerFactory; +import dev.cel.policy.CelPolicyParser; +import dev.cel.policy.CelPolicyParser.TagVisitor; +import dev.cel.policy.CelPolicyParserFactory; +import dev.cel.policy.CelPolicyValidationException; +import dev.cel.policy.PolicyParserContext; +import java.util.ArrayList; +import java.util.List; +import org.yaml.snakeyaml.nodes.MappingNode; +import org.yaml.snakeyaml.nodes.Node; +import org.yaml.snakeyaml.nodes.NodeTuple; +import org.yaml.snakeyaml.nodes.ScalarNode; +import org.yaml.snakeyaml.nodes.SequenceNode; + +public final class AgenticPolicyCompiler { + + private static final CelPolicyParser POLICY_PARSER = + CelPolicyParserFactory.newYamlParserBuilder() + .addTagVisitor(new AgenticPolicyTagHandler()) + .build(); + + private final CelPolicyCompiler policyCompiler; + + public static AgenticPolicyCompiler newInstance(Cel cel) { + return new AgenticPolicyCompiler(cel); + } + + private AgenticPolicyCompiler(Cel cel) { + this.policyCompiler = CelPolicyCompilerFactory.newPolicyCompiler(cel).build(); + } + + public CelAbstractSyntaxTree compile(String policySource) throws CelPolicyValidationException { + CelPolicy policy = POLICY_PARSER.parse(policySource); + return policyCompiler.compile(policy); + } + + private static class AgenticPolicyTagHandler implements TagVisitor { + + @Override + public void visitPolicyTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder) { + + switch (tagName) { + case "default": + if (assertYamlType(ctx, id, node, YamlNodeType.STRING)) { + policyBuilder.putMetadata("default_effect", ((ScalarNode) node).getValue()); + } + break; + + case "variables": + if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return; + List parsedVariables = new ArrayList<>(); + SequenceNode varList = (SequenceNode) node; + + for (Node varNode : varList.getValue()) { + if (assertYamlType(ctx, ctx.collectMetadata(varNode), varNode, YamlNodeType.MAP)) { + MappingNode map = (MappingNode) varNode; + for (NodeTuple tuple : map.getValue()) { + String name = ((ScalarNode) tuple.getKeyNode()).getValue(); + String expr = ((ScalarNode) tuple.getValueNode()).getValue(); + parsedVariables.add(Variable.newBuilder() + .setName(ValueString.of(ctx.collectMetadata(tuple.getKeyNode()), name)) + .setExpression(ValueString.of(ctx.collectMetadata(tuple.getValueNode()), expr)) + .build()); + } + } + } + policyBuilder.putMetadata("top_level_variables", parsedVariables); + break; + + case "rules": + if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return; + SequenceNode rulesNode = (SequenceNode) node; + Rule.Builder subRuleBuilder = Rule.newBuilder(ctx.collectMetadata(rulesNode)); + + if (policyBuilder.metadata().containsKey("top_level_variables")) { + List variables = (List) policyBuilder.metadata().get("top_level_variables"); + subRuleBuilder.addVariables(variables); + } + + for (Node ruleNode : rulesNode.getValue()) { + policyBuilder.putMetadata("effect", "deny"); + policyBuilder.putMetadata("message", ""); + policyBuilder.putMetadata("output_expr", null); + + Match subMatch = ctx.parseMatch(ctx, policyBuilder, ruleNode); + subRuleBuilder.addMatches(subMatch); + } + + if (policyBuilder.metadata().containsKey("default_effect")) { + String defaultEffect = policyBuilder.metadata().get("default_effect").toString(); + Match defaultMatch = Match.newBuilder(ctx.nextId()) + .setCondition(ValueString.of(ctx.nextId(), "true")) + .setResult(Result.ofOutput(ValueString.of(ctx.nextId(), generateMessageOutput(defaultEffect, "")))) + .build(); + subRuleBuilder.addMatches(defaultMatch); + } + policyBuilder.setRule(subRuleBuilder.build()); + break; + + default: + TagVisitor.super.visitPolicyTag(ctx, id, tagName, node, policyBuilder); + break; + } + } + + @Override + public void visitMatchTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder, + Match.Builder matchBuilder) { + + switch (tagName) { + case "description": + if (assertYamlType(ctx, id, node, YamlNodeType.STRING)) { + matchBuilder.setExplanation(ValueString.of(ctx.nextId(), ((ScalarNode) node).getValue())); + } + break; + + case "effect": + case "message": + case "output_expr": + if (!assertYamlType(ctx, id, node, YamlNodeType.STRING)) return; + + String value = ((ScalarNode) node).getValue(); + policyBuilder.putMetadata(tagName, value); + + String currentEffect = (String) policyBuilder.metadata().get("effect"); + String currentMessage = (String) policyBuilder.metadata().get("message"); + String currentOutputExpr = (String) policyBuilder.metadata().get("output_expr"); + + String finalOutput = (currentOutputExpr != null) + ? generateDetailsOutput(currentEffect, currentOutputExpr) + : generateMessageOutput(currentEffect, currentMessage); + + matchBuilder.setResult(Result.ofOutput(ValueString.of(ctx.nextId(), finalOutput))); + break; + + default: + TagVisitor.super.visitMatchTag(ctx, id, tagName, node, policyBuilder, matchBuilder); + break; + } + } + + // The following will likely benefit from having a concrete output structure + private static String generateMessageOutput(String effect, String message) { + String safeMessage = message.replace("'", "\\'"); + return String.format("{'effect': '%s', 'message': '%s'}", effect, safeMessage); + } + + private static String generateDetailsOutput(String effect, String outputExpression) { + return String.format("{'effect': '%s', 'details': %s}", effect, outputExpression); + } + } +} diff --git a/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel b/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel new file mode 100644 index 000000000..150e06636 --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel @@ -0,0 +1,49 @@ +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = [ + "//:license", + ], + default_visibility = ["//visibility:public"], + # default_visibility = [ + # "//tools/ai:__pkg__", + # ], +) + +java_library( + name = "agentic_policy_compiler", + srcs = ["AgenticPolicyCompiler.java"], + deps = [ + ":agent_context_java_proto", + "//bundle:cel", + "//common:cel_ast", + "//common/formats:value_string", + "//common/formats:yaml_helper", + "//common/types", + "//policy", + "//policy:compiler", + "//policy:compiler_factory", + "//policy:parser", + "//policy:parser_factory", + "//policy:policy_parser_context", + "//policy:validation_exception", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:org_yaml_snakeyaml", + ], +) + +proto_library( + name = "agent_context_proto", + srcs = ["agent_context.proto"], + deps = [ + "@com_google_protobuf//:duration_proto", + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +java_proto_library( + name = "agent_context_java_proto", + deps = [":agent_context_proto"], +) diff --git a/tools/src/main/java/dev/cel/tools/ai/agent_context.proto b/tools/src/main/java/dev/cel/tools/ai/agent_context.proto new file mode 100644 index 000000000..2f7a1d455 --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/agent_context.proto @@ -0,0 +1,468 @@ +edition = "2024"; + +package cel.expr.ai; + +import "google/protobuf/duration.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; + +option java_package = "dev.cel.expr.ai"; +option java_outer_classname = "AgentContextProto"; + +// Agent represents the AI System or Service being governed. +// It encapsulates the static configuration (Manifests, Identity) and the +// dynamic runtime state (Context, Inputs, Outputs). +message Agent { + // The unique resource name of the agent. + // e.g. "agents/finance-helper" or "publishers/google/agents/gemini-pro" + string name = 1; + + // Human-readable description of the agent's purpose. + string description = 2; + + // The semantic version of the agent definition. + string version = 3; + + // The underlying model family backing this agent. + Model model = 4; + + // The provider or vendor responsible for hosting/managing this agent. + AgentProvider provider = 5; + + // Identity of the Agent itself (Service Account / Principal) + // Independent of 'request.auth.principal' which may be the end user + // credentials or the agent's identity + AgentAuth auth = 6; + + // The accumulated security context (Trust, Sensitivity, Data Sources). + AgentContext context = 7; + + // The current turn's input (Prompt + Attachments) + AgentMessage input = 8; + + // The pending response (if evaluating egress/output policies) + AgentMessage output = 9; +} + +// AgentAuth represents the identity of the Agent itself. +// Independent of 'request.auth.principal' which may be the end user +// credentials or the agent's identity +message AgentAuth { + // The principal of the agent, prefer SPIFFE format of: + // spiffe:///ns//sa/ + // See: https://spiffe.io/docs/latest/spiffe/concepts/#spiffe-identifiers + string principal = 1; + + // Map of string keys to structured claims about the agent. + // For example, with JWT-based tokens, the claims would include fields + // indicating the following: + // + // - The issuer 'iss' (e.g. url of the identity provider) + // - The audience(s) 'aud' (e.g. the intended recipient(s) of the token) + // - The token's expiration time ('exp') + // - The token's subject ('sub') + google.protobuf.Struct claims = 2; + + // The OAuth scopes granted to the agent. + // This is a list of strings, where each string is a valid OAuth scope + // (e.g. "https://www.googleapis.com/auth/cloud-platform"). + repeated string oauth_scopes = 3; +} + +// AgentContext represents the aggregate security and data governance state +// of the agent's context window. +message AgentContext { + // Aggregated trust level associated with relevant data in the window + // (Min of all inputs). + TrustLevel trust = 1; + + // Origin/Lineage tracking. + repeated DataSource sources = 2; + + // The flattened text content of the current prompt. + string prompt = 3; + + // Describes the provenance of a data included in the context. + message DataSource { + // Unique id describing the originating data source. + string id = 1; // e.g. "bigquery:sales_table" + + // The category of origin for this data. + string provenance = 2; // e.g. "UserPrompt", "Database:Secure", "PublicWeb" + } + + // Extensions for provider-specific structured context metadata. + // + // Information which cannot be considered authoritative, but rather should be + // combined in very specific fashion with other inputs to the policy engine, + // or with out-of-band context should be provided via extension fields to + // allow the data to be supplied to the policy runtime without allowing policy + // authors to reference it directly. + // + // For example, the agent context may contain sensitive information, + // but the parameters supplied to a tool call may be non-sensitive. A + // conservative approach might assume that if the context is sensitive, the + // call must also be sensitive, but this may not be the case; hence, data + // sensitivity should be assessed via helper functions which determines the + // sensitivity most appropriate for the situation. + + extensions 1000 to 9999 [ + verification = DECLARATION, + declaration = { + number: 1000, + reserved: true + }, + declaration = { + number: 1001, + reserved: true + } + ]; +} + +// Describes the integrity/veracity of the data. +message TrustLevel { + // The trust level of the data. + // e.g. "untrusted", "trusted", "trusted_3p" + string level = 1; + + // Findings which support or are associated with this level. + repeated Finding findings = 2; +} + +// ClassificationLabel describes the classification of data within the context. +message ClassificationLabel { + // The common categories for different labels, may correspond to different + // classification systems. + enum Category { + // Unspecified category. + CATEGORY_UNSPECIFIED = 0; + // Sensitivity labels provide a hint about the nature of the data. + // e.g. 'pii', 'internal' + SENSITIVITY = 1; + // Safety labels provide a hint about the nature of the content provided or + // produced. e.g. 'child_safety', 'responsible_ai' + SAFETY = 2; + // Threat labels indicate some kind of attack on the agent or system. + // e.g. 'prompt_injection', 'malicious_uri' + THREAT = 3; + } + + // Common labels are 'pii', 'internal', 'child_safety' + string name = 1; + + // The category of the label. Optional, but recommended. + Category category = 2; + + // Findings which support or are associated with this label. + repeated Finding findings = 3; +} + +// For a given label, either sensitivity or trust, this message describes +// findings and confidence values associated with the label. +message Finding { + // The name of the confidence measure. + // e.g. "picc_score", "affinity_score" + string value = 1; + + // The confidence score between 0 and 1. + double confidence = 2; + + // An optional explanation for the confidence score. + // e.g. "The confidence score is low because the data is from a public + // source." + string explanation = 3; +} + +// AgentMessage represents a single turn in the conversation. +// It acts as a container for multimodal content (Text, Files, Tool Results). +message AgentMessage { + // A discrete unit of content within the message. + message Part { + oneof type { + // User or System text input. + ContentPart prompt = 1; + + // A request to execute a specific tool. + // + // If a call has been completed, the call will have the result or + // error populated. Calls which have not yet been resolved will only have + // the intent (arguments) populated. + ToolCall tool_call = 2; + + // A file or multimodal object (Image, PDF). + ContentPart attachment = 3; + + // An error that occurred during processing. + ErrorPart error = 4; + } + } + + // The actor who constructed the message (e.g., "user", "model", "tool"). + string role = 1; + + // The ordered sequence of content parts. + // + // In the case of a tool call, the result or error will be populated within + // the `ToolCall` message rather than split into a separate `Part`. + repeated Part parts = 2; + + // Arbitrary metadata associated with the message turn. + google.protobuf.Struct metadata = 3; + + // Message creation time + google.protobuf.Timestamp time = 4; +} + +// ContentPart is a catch-all message type capable of encapsulating other +// messages within its `structured_content` field. +// +// For example, a series of sub-agent MCP tool calls and results may be +// encapsulated as an `AgentMessage` in JSON form within the +// `structured_content` field. +// +// The approach is unconventional, but indicates how the data representation +// provided to policy requires helper methods to help make agent policies +// sensible and with support to type-convert from json to proto perhaps being +// a necessary on-demand feature within agent policies. +message ContentPart { + // Unique identifier for this content part. + string id = 1; + + // The type of content. + // + // Common values include: "text", "file", "json" + string type = 2; + + // The MIME type of the content. + // + // Common values include: "text/plain", "application/json", "image/png" + string mime_type = 3; + + // The name of the content. + string name = 4; + + // The description of the content. + string description = 5; + + // The URI of the content. + string uri = 6; + + // The string serialized representation of the content, either plain text or + // serialized JSON reflected from `structured_content`. + string content = 7; + + // The binary representation of the content. + // + // This field is used to represent binary data (e.g., images, PDFs) or + // serialized proto messages which come over the wire as base64-encoded string + // values that are expected to be decoded into binary data. + bytes data = 8; + + // The JSON object representation of the content, if applicable. + google.protobuf.Struct structured_content = 9; + + // Arbitrary metadata associated with the content part. + google.protobuf.Struct annotations = 10; + + // Timestamp associated with the content part. + google.protobuf.Timestamp time = 11; + + // Extensions for content-specific metadata. + extensions 1000 to 9999 [ + verification = DECLARATION, + declaration = { + number: 1000, + reserved: true + } + ]; +} + +// ErrorPart represents a processing error within the agent loop. +message ErrorPart { + // The identifier of the specific ContentPart, ToolCall, or Message that + // caused this error. Used to correlate the failure back to the originating + // action (e.g., matching a failed tool call). + string id = 1; + + // Standardized error code (e.g., gRPC status code or HTTP status). + int64 code = 2; + + // Developer-facing error message describing the failure. + string error_message = 3; + + // Timestamp when the error occurred. + google.protobuf.Timestamp time = 4; +} + +// AgentProvider describes the entity responsible for the agent's operation. +message AgentProvider { + // The base URL or endpoint where the agent service is hosted. + string url = 1; + + // The name of the organization providing the agent (e.g. "Google", + // "Salesforce"). + string organization = 2; +} + +// Model describes the AI model backing the agent. +message Model { + // Identifier of the model family (ex: gemini-pro, gpt-4 ...) + string name = 1; +} + +// ToolManifest describes a collection of tools provided by a specific +// source. +message ToolManifest { + // Metadata about the tool provider itself, including authorization + // requirements. + ToolProvider provider = 1; + + // Collection of Tool instances specified by the provider. + repeated Tool tools = 2; +} + +// Information about how the tools were provided and by whom. +message ToolProvider { + // URL where the tools were provided. + string url = 1; + + // Name of the tool provider. + string organization = 2; // e.g. "google-cloud" + + // URL for the OAuth authorization endpoint supported by this tool provider + string authorization_server_url = 3; + + // Repeated set of OAuth scopes for this tool provider. + repeated string supported_scopes = 4; +} + +// Tool describes a specific function or capability available to the agent. +message Tool { + // The unique name of the tool + string name = 1; // (e.g. "weather_lookup"). + + // Human readable description of what the tool does. + string description = 2; + + // JSON Schema defining the expected arguments. + google.protobuf.Struct input_schema = 3; + + // JSON Schema defining the expected output. + google.protobuf.Struct output_schema = 4; + + // Behavioral hints about the tool. + ToolAnnotations annotations = 5; + + // Arbitrary tool metadata. + google.protobuf.Struct metadata = 6; +} + +// Hints for describing a tool's behavior. +// +// Informed by annotations common to the MCP spec and conventions common to +// other agent frameworks. +message ToolAnnotations { + // If true, the tool does not modify its environment. + // Default: false + bool read_only = 1; + + // If true, the tool may perform destructive updates to its environment. + // If false, the tool performs only additive updates. + // NOTE: This property is meaningful only when `read_only == false` + bool destructive = 2; + + // If true, calling the tool repeatedly with the same arguments will have no + // additional effect on its environment. + // NOTE: This property is meaningful only when `read_only == false`. + bool idempotent = 3; + + // If true, this tool may interact with an "open world" of external entities. + // If false, the tools domain of interaction is closed. For example, the + // world of a web search tool is open, whereas that of a memory tool is not. + // + // Part of the lethal trifecta is using a tool which interacts with an open + // world as this provides an exfiltration path for sensitive data to leak + // to untrusted parties. + bool open_world = 4; + + // If true, this tool is intended to be called asynchronously. + // For example, a tool that starts a simulation process on a server and + // returns immediately. + bool async = 5; + + // The trust level of the tool's output. + // + // Part of the lethal trifecta is using a tool which outputs untrusted data. + TrustLevel output_trust = 6; + + // Extensions for provider-specific structured tool metadata. + // + // Such information should be considered supplementary to policies which + // consider such hints in conjuction with data provided to the tool call. + extensions 1000 to 9999 [ + verification = DECLARATION, + declaration = { + number: 1000, + reserved: true + }, + declaration = { + number: 1001, + reserved: true + }, + declaration = { + number: 1002, + reserved: true + }, + declaration = { + number: 1003, + reserved: true + }, + declaration = { + number: 1004, + reserved: true + } + ]; +} + +// ToolCall represents a specific invocation of a tool by the agent. +// It captures the intent (arguments), the status (result/error), and +// governance metadata (confirmation). +message ToolCall { + // Unique identifier for this tool call. + // Used to correlate the call with its result or error in the history. + string id = 1; + + // The name of the tool being called (e.g., "weather_lookup"). + // This should match a tool defined in the agent's ToolManifest. + string name = 2; + + // The arguments provided to the tool call. + // Policies can inspect these values to enforce data safety (e.g. no PII). + google.protobuf.Struct params = 3; + + // The execution status of the tool call. + // This field is populated if the tool has already been executed (in history). + oneof status { + // The successful output of the tool. + ContentPart result = 4; + + // The error encountered during execution. + ErrorPart error = 5; + } + + // Timestamp when the tool call was initiated. + google.protobuf.Timestamp time = 6; + + // Indicates if the user explicitly confirmed this action. + // Useful for Human-in-the-Loop (HITL) policies. + bool user_confirmed = 7; + + // Extensions for tool call specific metadata. + extensions 1000 to 9999 [ + verification = DECLARATION, + declaration = { + number: 1000, + reserved: true + } + ]; +} diff --git a/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java b/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java new file mode 100644 index 000000000..3662da815 --- /dev/null +++ b/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java @@ -0,0 +1,306 @@ +package dev.cel.tools.ai; + +import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static dev.cel.common.CelOverloadDecl.newMemberOverload; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.common.truth.Expect; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelContainer; +import dev.cel.common.CelValidationException; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.expr.ai.Agent; +import dev.cel.expr.ai.AgentContext; +import dev.cel.expr.ai.AgentMessage; +import dev.cel.expr.ai.Finding; +import dev.cel.expr.ai.Tool; +import dev.cel.expr.ai.ToolAnnotations; +import dev.cel.expr.ai.ToolCall; +import dev.cel.expr.ai.TrustLevel; +import dev.cel.parser.CelStandardMacro; +import dev.cel.policy.testing.PolicyTestSuiteHelper; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelFunctionBinding; +import dev.cel.runtime.CelLateFunctionBindings; +import java.io.IOException; +import java.net.URL; +import java.time.Instant; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class AgenticPolicyCompilerTest { + @Rule + public final Expect expect = Expect.create(); + + private static final Cel CEL = CelFactory.standardCelBuilder() + .setContainer(CelContainer.ofName("cel.expr.ai")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addMessageTypes(Agent.getDescriptor()) + .addMessageTypes(AgentContext.getDescriptor()) + .addMessageTypes(TrustLevel.getDescriptor()) + .addMessageTypes(ToolCall.getDescriptor()) + .addMessageTypes(Tool.getDescriptor()) + .addMessageTypes(ToolAnnotations.getDescriptor()) + .addMessageTypes(AgentMessage.getDescriptor()) + .addMessageTypes(Finding.getDescriptor()) + + .addVar("agent.input", StructTypeReference.create("cel.expr.ai.AgentMessage")) + .addVar("agent.context", StructTypeReference.create("cel.expr.ai.AgentContext")) + .addVar("_test_history", ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage"))) + .addVar("now", SimpleType.TIMESTAMP) + + .addVar("tool.name", SimpleType.STRING) + .addVar("tool.annotations", StructTypeReference.create("cel.expr.ai.ToolAnnotations")) + .addVar("tool.call", StructTypeReference.create("cel.expr.ai.ToolCall")) + .addFunctionDeclarations( + newFunctionDeclaration( + "ai.finding", + newGlobalOverload( + "ai_finding_string_double", + StructTypeReference.create("cel.expr.ai.Finding"), + SimpleType.STRING, + SimpleType.DOUBLE + ) + ), + newFunctionDeclaration( + "threats", + newMemberOverload( + "agent_message_threats", + ListType.create(StructTypeReference.create("cel.expr.ai.Finding")), + StructTypeReference.create("cel.expr.ai.AgentMessage") + ) + ), + newFunctionDeclaration( + "sensitivityLabel", + newMemberOverload( + "tool_call_sensitivity_label", + ListType.create(StructTypeReference.create("cel.expr.ai.Finding")), + StructTypeReference.create("cel.expr.ai.ToolCall"), + SimpleType.STRING + ) + ), + newFunctionDeclaration( + "contains", + newMemberOverload( + "list_finding_contains_list_finding", + SimpleType.BOOL, + ListType.create(StructTypeReference.create("cel.expr.ai.Finding")), + ListType.create(StructTypeReference.create("cel.expr.ai.Finding")) + ) + ), + newFunctionDeclaration( + "agent.history", + newGlobalOverload( + "agent_history", + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")) + ) + ), + newFunctionDeclaration( + "role", + newMemberOverload( + "list_agent_message_role_string", + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")), + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")), + SimpleType.STRING + ) + ), + newFunctionDeclaration( + "after", + newMemberOverload( + "list_agent_message_after_timestamp", + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")), + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")), + SimpleType.TIMESTAMP + ) + ) + ) + .addFunctionBindings( + CelFunctionBinding.from( + "ai_finding_string_double", + ImmutableList.of(String.class, Double.class), + (args) -> Finding.newBuilder() + .setValue((String) args[0]) + .setConfidence((Double) args[1]) + .build() + ), + CelFunctionBinding.from( + "agent_message_threats", + AgentMessage.class, + (msg) -> { + if (msg.getPartsCount() > 0 && msg.getParts(0).hasPrompt()) { + String content = msg.getParts(0).getPrompt().getContent(); + if (content.contains("INJECTION_ATTACK")) { + return ImmutableList.of( + Finding.newBuilder().setValue("prompt_injection").setConfidence(0.95).build() + ); + } + if (content.contains("SUSPICIOUS")) { + return ImmutableList.of( + Finding.newBuilder().setValue("prompt_injection").setConfidence(0.6).build() + ); + } + } + return ImmutableList.of(); + } + ), + CelFunctionBinding.from( + "tool_call_sensitivity_label", + ImmutableList.of(ToolCall.class, String.class), + (args) -> { + ToolCall tool = (ToolCall) args[0]; + String label = (String) args[1]; + if ("pii".equals(label) && tool.getName().contains("PII")) { + return ImmutableList.of( + Finding.newBuilder().setValue("pii").setConfidence(1.0).build() + ); + } + return ImmutableList.of(); + } + ), + CelFunctionBinding.from( + "list_finding_contains_list_finding", + ImmutableList.of(List.class, List.class), + (args) -> { + List actualFindings = (List) args[0]; + List expectedFindings = (List) args[1]; + return expectedFindings.stream().anyMatch(expected -> + actualFindings.stream().anyMatch(actual -> + actual.getValue().equals(expected.getValue()) && + actual.getConfidence() >= expected.getConfidence() + ) + ); + } + ), + CelFunctionBinding.from( + "list_agent_message_role_string", + ImmutableList.of(List.class, String.class), + (args) -> { + List history = (List) args[0]; + String role = (String) args[1]; + return history.stream() + .filter(m -> m.getRole().equals(role)) + .collect(Collectors.toList()); + } + ), + CelFunctionBinding.from( + "list_agent_message_after_timestamp", + ImmutableList.of(List.class, Instant.class), + (args) -> { + List history = (List) args[0]; + Instant cutoff = (Instant) args[1]; + + return history.stream() + .filter(m -> { + com.google.protobuf.Timestamp protoTs = m.getTime(); + Instant msgTime = Instant.ofEpochSecond(protoTs.getSeconds(), protoTs.getNanos()); + return msgTime.compareTo(cutoff) >= 0; + }) + .collect(Collectors.toList()); + } + ) + ) + .build(); + + private static final AgenticPolicyCompiler COMPILER = AgenticPolicyCompiler.newInstance(CEL); + + @Test + public void runAgenticPolicyTestCases(@TestParameter AgenticPolicyTestCase testCase) throws Exception { + CelAbstractSyntaxTree compiledPolicy = compilePolicy(testCase.policyFilePath); + PolicyTestSuite testSuite = PolicyTestSuiteHelper.readTestSuite(testCase.policyTestCaseFilePath); + runTests(CEL, compiledPolicy, testSuite); + } + + private enum AgenticPolicyTestCase { + PROMPT_INJECTION_TESTS( + "prompt_injection.celpolicy", + "prompt_injection_tests.yaml" + ), + REQUIRE_USER_CONFIRMATION_FOR_TOOL( + "require_user_confirmation_for_tool.celpolicy", + "require_user_confirmation_for_tool_tests.yaml" + ), + OPEN_WORLD_TOOL_REPLAY( + "open_world_tool_replay.celpolicy", + "open_world_tool_replay_tests.yaml" + ), + TRUST_CASCADING( + "trust_cascading.celpolicy", + "trust_cascading_tests.yaml" + ), + TIME_BOUND_APPROVAL( + "time_bound_approval.celpolicy", + "time_bound_approval_tests.yaml" + ); + + private final String policyFilePath; + private final String policyTestCaseFilePath; + + AgenticPolicyTestCase(String policyFilePath, String policyTestCaseFilePath) { + this.policyFilePath = policyFilePath; + this.policyTestCaseFilePath = policyTestCaseFilePath; + } + } + + private static CelAbstractSyntaxTree compilePolicy(String policyPath) + throws Exception { + String policy = readFile(policyPath); + return COMPILER.compile(policy); + } + + private static String readFile(String path) throws IOException { + URL url = Resources.getResource(Ascii.toLowerCase(path)); + return Resources.toString(url, UTF_8); + } + + private void runTests(Cel cel, CelAbstractSyntaxTree ast, PolicyTestSuite testSuite) { + for (PolicyTestSection testSection : testSuite.getSection()) { + for (PolicyTestCase testCase : testSection.getTests()) { + String testName = String.format( + "%s: %s", testSection.getName(), testCase.getName()); + try { + ImmutableMap inputMap = testCase.toInputMap(cel); + + List history = + inputMap.containsKey("_test_history") + ? (List) inputMap.get("_test_history") + : ImmutableList.of(); + + @SuppressWarnings("Immutable") + CelLateFunctionBindings bindings = CelLateFunctionBindings.from( + CelFunctionBinding.from( + "agent_history", + ImmutableList.of(), // No args + (args) -> history + ) + ); + + Object evalResult = cel.createProgram(ast).eval(inputMap, bindings); + Object expectedOutput = cel.createProgram(cel.compile(testCase.getOutput()).getAst()).eval(); + expect.withMessage(testName).that(evalResult).isEqualTo(expectedOutput); + } catch (CelValidationException e) { + expect.withMessage("Failed to compile test case for " + testName + ". Reason:\n" + e.getMessage()).fail(); + } catch (CelEvaluationException e) { + expect.withMessage("Failed to evaluate test case for " + testName + ". Reason:\n" + e.getMessage()).fail(); + } + } + } + } +} diff --git a/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel b/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel new file mode 100644 index 000000000..9e43026ac --- /dev/null +++ b/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel @@ -0,0 +1,43 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = True, + srcs = glob( + ["*.java"], + ), + resources = ["//tools/ai:test_policies"], + deps = [ + "//:java_truth", + "//bundle:cel", + "//common:cel_ast", + "//common:compiler_common", + "//common:container", + "//common/formats:value_string", + "//common/types", + "//parser:macro", + "//policy/testing:policy_test_suite_helper", + "//runtime:evaluation_exception", + "//runtime:function_binding", + "//runtime:late_function_binding", + "//tools/ai:agentic_policy_compiler", + "//tools/src/main/java/dev/cel/tools/ai:agent_context_java_proto", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:com_google_protobuf_protobuf_java_util", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/tools/src/test/resources/BUILD.bazel b/tools/src/test/resources/BUILD.bazel new file mode 100644 index 000000000..8fbb42fce --- /dev/null +++ b/tools/src/test/resources/BUILD.bazel @@ -0,0 +1,20 @@ +package( + default_applicable_licenses = [ + "//:license", + ], + default_testonly = True, + default_visibility = [ + "//tools/ai:__pkg__", + ], +) + +filegroup( + name = "test_policies", + testonly = True, + srcs = glob( + [ + "*.celpolicy", + "*.yaml", + ], + ), +) diff --git a/tools/src/test/resources/open_world_tool_replay.celpolicy b/tools/src/test/resources/open_world_tool_replay.celpolicy new file mode 100644 index 000000000..9ef6b4eaf --- /dev/null +++ b/tools/src/test/resources/open_world_tool_replay.celpolicy @@ -0,0 +1,14 @@ +name: "policy.safety.open_world_replay" +default: allow + +rules: + - description: "Limit turn window for open-world tools (internet access)" + condition: | + tool.annotations.open_world + effect: replay + output_expr: | + { + 'type': 'USER', + 'turn_window': 1, + 'reason': 'Tool interacts with the open world.' + } \ No newline at end of file diff --git a/tools/src/test/resources/open_world_tool_replay_tests.yaml b/tools/src/test/resources/open_world_tool_replay_tests.yaml new file mode 100644 index 000000000..44cac1595 --- /dev/null +++ b/tools/src/test/resources/open_world_tool_replay_tests.yaml @@ -0,0 +1,36 @@ +description: "Open World Tool Replay Policy Tests" + +section: +- name: "Capability Checks" + tests: + - name: "Open World Tool (Replay)" + input: + tool.annotations: + expr: > + ToolAnnotations{ open_world: true } + tool.call: + expr: > + ToolCall{ name: "internet_search" } + output: > + { + "effect": "replay", + "details": { + "type": "USER", + "turn_window": 1, + "reason": "Tool interacts with the open world." + } + } + + - name: "Closed World Tool (Allow)" + input: + tool.annotations: + expr: > + ToolAnnotations{ open_world: false } + tool.call: + expr: > + ToolCall{ name: "calculator" } + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file diff --git a/tools/src/test/resources/prompt_injection.celpolicy b/tools/src/test/resources/prompt_injection.celpolicy new file mode 100644 index 000000000..01336d083 --- /dev/null +++ b/tools/src/test/resources/prompt_injection.celpolicy @@ -0,0 +1,20 @@ +name: "policy.safety.prompt.injection" +default: allow + +variables: + - high_confidence_threat: > + agent.input.threats().contains([ai.finding("prompt_injection", 0.9)]) + + - potential_threat: > + agent.input.threats().contains([ai.finding("prompt_injection", 0.5)]) + +rules: + - description: "Block high-confidence injection attacks" + condition: variables.high_confidence_threat + effect: deny + message: "High-confidence prompt injection detected." + + - description: "Require confirmation for suspicious inputs" + condition: variables.potential_threat + effect: confirm + message: "Potential prompt injection detected. User confirmation required." \ No newline at end of file diff --git a/tools/src/test/resources/prompt_injection_tests.yaml b/tools/src/test/resources/prompt_injection_tests.yaml new file mode 100644 index 000000000..58f805dcb --- /dev/null +++ b/tools/src/test/resources/prompt_injection_tests.yaml @@ -0,0 +1,61 @@ +description: "Prompt Injection Policy Tests" + +section: +- name: "Injection Classification Scenarios" + tests: + - name: "High Confidence Injection (Deny)" + input: + agent.input: + expr: > + AgentMessage{ + parts: [ + AgentMessage.Part{ + prompt: ContentPart{ + content: "INJECTION_ATTACK detected" + } + } + ] + } + output: > + { + "effect": "deny", + "message": "High-confidence prompt injection detected." + } + + - name: "Medium Confidence Injection (Confirm)" + input: + agent.input: + expr: > + AgentMessage{ + parts: [ + AgentMessage.Part{ + prompt: ContentPart{ + content: "This looks SUSPICIOUS but maybe safe" + } + } + ] + } + output: > + { + "effect": "confirm", + "message": "Potential prompt injection detected. User confirmation required." + } + + - name: "Safe Input (Allow)" + input: + agent.input: + expr: > + AgentMessage{ + parts: [ + AgentMessage.Part{ + prompt: ContentPart{ + content: "Just a normal user query" + } + } + ] + } + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file diff --git a/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy b/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy new file mode 100644 index 000000000..983e1b72b --- /dev/null +++ b/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy @@ -0,0 +1,20 @@ +name: "require_user_confirmation_for_mcp_tool" +default: deny + +variables: + - high_confidence_pii: > + tool.call.sensitivityLabel('pii').exists(f, f.confidence >= 0.8) + +rules: + - description: "Confirm tool calls if high-confidence PII is detected" + condition: > + variables.high_confidence_pii && + !tool.call.user_confirmed + effect: confirm + message: "This tool call contains sensitive data (PII). User confirmation is required." + + - description: "Allow if no high-confidence PII is detected OR if confirmed" + condition: > + !variables.high_confidence_pii || + tool.call.user_confirmed + effect: allow \ No newline at end of file diff --git a/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml b/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml new file mode 100644 index 000000000..3987b169a --- /dev/null +++ b/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml @@ -0,0 +1,31 @@ +description: "Require tool confirmation tests" + +section: +- name: "tool call test section" + tests: + - name: "reject_sensitive_tool_call" + input: + tool.call: + expr: > + ToolCall{ + name: "tool_with_PII", + user_confirmed: false + } + output: > + { + "effect": "confirm", + "message": "This tool call contains sensitive data (PII). User confirmation is required." + } + - name: "allow_confirmed_tool" + input: + tool.call: + expr: > + ToolCall{ + name: "tool_with_PII", + user_confirmed: true + } + output: > + { + "effect": "allow", + "message": "", + } \ No newline at end of file diff --git a/tools/src/test/resources/time_bound_approval.celpolicy b/tools/src/test/resources/time_bound_approval.celpolicy new file mode 100644 index 000000000..efb45fd6e --- /dev/null +++ b/tools/src/test/resources/time_bound_approval.celpolicy @@ -0,0 +1,23 @@ +name: "policy.safety.time_bound_approval" +default: allow + +variables: + # Define the validity window (30 seconds ago) + - approval_cutoff: now - duration('30s') + + # Find approval messages in the valid window + - valid_approvals: > + agent.history() + .after(variables.approval_cutoff) + .role('model') + .filter(m, has(m.metadata.step) && m.metadata.step == 'approval_granted') + + - has_valid_approval: variables.valid_approvals.size() > 0 + +rules: + - description: "Require approval within the last 30 seconds for sensitive writes" + condition: > + tool.name == 'database_write' && + !variables.has_valid_approval + effect: deny + message: "Authorization expired. Please re-approve the database write operation." \ No newline at end of file diff --git a/tools/src/test/resources/time_bound_approval_tests.yaml b/tools/src/test/resources/time_bound_approval_tests.yaml new file mode 100644 index 000000000..0b87fe24f --- /dev/null +++ b/tools/src/test/resources/time_bound_approval_tests.yaml @@ -0,0 +1,46 @@ +description: "Time-Bound Approval Policy Tests" + +section: +- name: "Time Window Enforcement" + tests: + - name: "Approval Expired (Deny)" + input: + tool.name: + value: "database_write" + now: + expr: timestamp("2024-01-01T12:01:00Z") + _test_history: + expr: > + [ + AgentMessage{ + role: "model", + time: timestamp("2024-01-01T12:00:00Z"), + metadata: { "step": "approval_granted" } + } + ] + output: > + { + "effect": "deny", + "message": "Authorization expired. Please re-approve the database write operation." + } + + - name: "Approval Valid (Allow)" + input: + tool.name: + value: "database_write" + now: + expr: timestamp("2024-01-01T12:00:10Z") + _test_history: + expr: > + [ + AgentMessage{ + role: "model", + time: timestamp("2024-01-01T12:00:00Z"), + metadata: { "step": "approval_granted" } + } + ] + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file diff --git a/tools/src/test/resources/trust_cascading.celpolicy b/tools/src/test/resources/trust_cascading.celpolicy new file mode 100644 index 000000000..c24f140bc --- /dev/null +++ b/tools/src/test/resources/trust_cascading.celpolicy @@ -0,0 +1,35 @@ +name: "policy.trust.cascading" +default: allow + +variables: + # Critical security threats + - is_compromised: > + agent.context.trust.findings.contains([ai.finding("compromised_session", 0.9)]) + + # Compliance and/or hygiene issues with the source + - is_unverified: > + agent.context.trust.findings.contains([ai.finding("unverified_source", 0.8)]) + +rules: + - description: "Block sessions with high-confidence compromise indicators" + condition: variables.is_compromised + effect: deny + message: "Critical Trust Failure: Session is potentially compromised." + + - description: "Replay to request source verification" + condition: variables.is_unverified + effect: replay + output_expr: | + { + 'reason': 'Data source is unverified.', + 'action': 'verify_provenance' + } + + - description: "Replay generic untrusted contexts" + condition: agent.context.trust.level == 'untrusted' + effect: replay + output_expr: | + { + 'reason': 'Context trust is insufficient.', + 'required_level': 'trusted_3p' + } \ No newline at end of file diff --git a/tools/src/test/resources/trust_cascading_tests.yaml b/tools/src/test/resources/trust_cascading_tests.yaml new file mode 100644 index 000000000..465f36e65 --- /dev/null +++ b/tools/src/test/resources/trust_cascading_tests.yaml @@ -0,0 +1,56 @@ +description: "Trust Cascading Policy Tests" + +section: +- name: "Trust Finding Scenarios" + tests: + - name: "Critical Compromise (Deny)" + input: + agent.context: + expr: > + AgentContext{ + trust: TrustLevel{ + level: "untrusted", + findings: [ + Finding{ value: "compromised_session", confidence: 0.95 } + ] + } + } + output: > + { + "effect": "deny", + "message": "Critical Trust Failure: Session is potentially compromised." + } + + - name: "Unverified Source (Replay)" + input: + agent.context: + expr: > + AgentContext{ + trust: TrustLevel{ + level: "untrusted", + findings: [ + Finding{ value: "unverified_source", confidence: 0.85 } + ] + } + } + output: > + { + "effect": "replay", + "details": { + "reason": "Data source is unverified.", + "action": "verify_provenance" + } + } + + - name: "Trusted Context (Allow)" + input: + agent.context: + expr: > + AgentContext{ + trust: TrustLevel{ level: "trusted" } + } + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file