diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java index e320cfe70..52900577b 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java @@ -158,13 +158,17 @@ private void assertAstDepthIsSafe(CelAbstractSyntaxTree ast, Cel cel) private CelCompiledRule compileRuleImpl( CelPolicy.Rule rule, Cel ruleCel, CompilerContext compilerContext) { + // A local CEL environment used to compile a single rule. This temporary environment + // is used to declare policy variables iteratively in a given policy, ensuring proper scoping + // across a single / nested rule. + Cel localCel = ruleCel; ImmutableList.Builder variableBuilder = ImmutableList.builder(); for (Variable variable : rule.variables()) { ValueString expression = variable.expression(); CelAbstractSyntaxTree varAst; CelType outputType = SimpleType.DYN; try { - varAst = ruleCel.compile(expression.value()).getAst(); + varAst = localCel.compile(expression.value()).getAst(); outputType = varAst.getResultType(); } catch (CelValidationException e) { compilerContext.addIssue(expression.id(), e.getErrors()); @@ -174,7 +178,7 @@ private CelCompiledRule compileRuleImpl( String variableName = variable.name().value(); CelVarDecl newVariable = CelVarDecl.newVarDeclaration(variablesPrefix + variableName, outputType); - ruleCel = ruleCel.toCelBuilder().addVarDeclarations(newVariable).build(); + localCel = localCel.toCelBuilder().addVarDeclarations(newVariable).build(); variableBuilder.add(CelCompiledVariable.create(variableName, varAst, newVariable)); } @@ -182,7 +186,7 @@ private CelCompiledRule compileRuleImpl( for (Match match : rule.matches()) { CelAbstractSyntaxTree conditionAst; try { - conditionAst = ruleCel.compile(match.condition().value()).getAst(); + conditionAst = localCel.compile(match.condition().value()).getAst(); if (!conditionAst.getResultType().equals(SimpleType.BOOL)) { compilerContext.addIssue( match.condition().id(), @@ -199,7 +203,7 @@ private CelCompiledRule compileRuleImpl( CelAbstractSyntaxTree outputAst; ValueString output = match.result().output(); try { - outputAst = ruleCel.compile(output.value()).getAst(); + outputAst = localCel.compile(output.value()).getAst(); } catch (CelValidationException e) { compilerContext.addIssue(output.id(), e.getErrors()); continue; @@ -209,7 +213,7 @@ private CelCompiledRule compileRuleImpl( break; case RULE: CelCompiledRule nestedRule = - compileRuleImpl(match.result().rule(), ruleCel, compilerContext); + compileRuleImpl(match.result().rule(), localCel, compilerContext); matchResult = Result.ofRule(nestedRule); break; default: @@ -221,7 +225,7 @@ private CelCompiledRule compileRuleImpl( CelCompiledRule compiledRule = CelCompiledRule.create( - rule.id(), rule.ruleId(), variableBuilder.build(), matchBuilder.build(), cel); + rule.id(), rule.ruleId(), variableBuilder.build(), matchBuilder.build(), ruleCel); // Validate that all branches in the policy are reachable checkUnreachableCode(compiledRule, compilerContext); diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 3b6a62e53..9106caf70 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -20,6 +20,7 @@ java_library( "//common/formats:value_string", "//common/internal", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", + "//common/types", "//compiler", "//extensions:optional_library", "//parser:macro", @@ -35,6 +36,7 @@ java_library( "//runtime", "//runtime:function_binding", "//runtime:late_function_binding", + "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index d96fba461..ee041192c 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -31,6 +31,8 @@ import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; +import dev.cel.common.types.OptionalType; +import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.extensions.CelOptionalLibrary; import dev.cel.parser.CelStandardMacro; @@ -43,6 +45,7 @@ import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; +import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; import java.util.Map; @@ -76,6 +79,28 @@ public void compileYamlPolicy_success(@TestParameter TestYamlPolicy yamlPolicy) assertThat(CelUnparserFactory.newUnparser().unparse(ast)).isEqualTo(yamlPolicy.getUnparsed()); } + @Test + public void compileYamlPolicy_withImportsOnNestedRules() throws Exception { + String policySource = + "imports:\n" + + " - name: cel.expr.conformance.proto3.TestAllTypes\n" + + " - name: dev.cel.testing.testdata.SingleFile\n" + + "rule:\n" + + " match:\n" + + " - rule:\n" + + " id: 'nested rule with imports'\n" + + " match:\n" + + " - condition: 'TestAllTypes{}.single_string == SingleFile{}.name'\n" + + " output: 'true'\n"; + Cel cel = newCel(); + CelPolicy policy = POLICY_PARSER.parse(policySource); + + CelAbstractSyntaxTree ast = + CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); + + assertThat(ast.getResultType()).isEqualTo(OptionalType.create(SimpleType.BOOL)); + } + @Test public void compileYamlPolicy_containsCompilationError_throws( @TestParameter TestErrorYamlPolicy testCase) throws Exception { @@ -292,7 +317,7 @@ private static Cel newCel() { .addCompilerLibraries(CelOptionalLibrary.INSTANCE) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) .addFileTypes(StandaloneGlobalEnum.getDescriptor().getFile()) - .addMessageTypes(TestAllTypes.getDescriptor()) + .addMessageTypes(TestAllTypes.getDescriptor(), SingleFile.getDescriptor()) .setOptions(CEL_OPTIONS) .addFunctionBindings( CelFunctionBinding.from( diff --git a/testing/src/test/resources/protos/BUILD.bazel b/testing/src/test/resources/protos/BUILD.bazel index 9106c666a..af361b174 100644 --- a/testing/src/test/resources/protos/BUILD.bazel +++ b/testing/src/test/resources/protos/BUILD.bazel @@ -20,6 +20,8 @@ proto_library( java_proto_library( name = "single_file_java_proto", + tags = [ + ], deps = [":single_file_proto"], )