From a35aab15b074bd5e9c5288fbe898946d65f9f1db Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 20 Aug 2025 13:47:48 -0700 Subject: [PATCH] Optimize composed expressions via constant folding and common subexpression elimination --- .../optimizers/ConstantFoldingOptimizer.java | 4 +- .../src/main/java/dev/cel/policy/BUILD.bazel | 5 ++ .../dev/cel/policy/CelPolicyCompilerImpl.java | 11 ++- .../java/dev/cel/policy/RuleComposer.java | 33 ++++++--- .../cel/policy/CelPolicyCompilerImplTest.java | 10 +-- .../java/dev/cel/policy/PolicyTestHelper.java | 70 +++---------------- 6 files changed, 55 insertions(+), 78 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index dd5a3c211..4e94664d8 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -202,7 +202,9 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) CelNavigableMutableExpr parent = identNode.parent().orElse(null); while (parent != null) { if (parent.getKind().equals(Kind.COMPREHENSION)) { - if (parent.expr().comprehension().accuVar().equals(identNode.expr().ident().name())) { + String identName = identNode.expr().ident().name(); + if (parent.expr().comprehension().accuVar().equals(identName) || + parent.expr().comprehension().iterVar().equals(identName)) { // Prevent folding a subexpression if it contains a variable declared by a // comprehension. The subexpression cannot be compiled without the full context of the // surrounding comprehension. diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index 9d4a3fad9..d32c7ecde 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -215,6 +215,8 @@ java_library( "//optimizer", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", + "//optimizer/optimizers:common_subexpression_elimination", + "//optimizer/optimizers:constant_folding", "//validator", "//validator:ast_validator", "//validator:validator_builder", @@ -247,7 +249,10 @@ java_library( "//common:cel_ast", "//common:compiler_common", "//common:mutable_ast", + "//common/ast", + "//common/ast:mutable_expr", "//common/formats:value_string", + "//common/navigation:mutable_navigation", "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java index 52900577b..794d94039 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java @@ -36,6 +36,9 @@ import dev.cel.optimizer.CelOptimizationException; import dev.cel.optimizer.CelOptimizer; import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.optimizer.optimizers.ConstantFoldingOptimizer; +import dev.cel.optimizer.optimizers.SubexpressionOptimizer; +import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions; import dev.cel.policy.CelCompiledRule.CelCompiledMatch; import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result; import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result.Kind; @@ -101,7 +104,13 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( - RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit)) + RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit), + ConstantFoldingOptimizer.getInstance(), + SubexpressionOptimizer.newInstance(SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()) + ) .build(); CelAbstractSyntaxTree ast; diff --git a/policy/src/main/java/dev/cel/policy/RuleComposer.java b/policy/src/main/java/dev/cel/policy/RuleComposer.java index 54735c1fb..5bae1f4b3 100644 --- a/policy/src/main/java/dev/cel/policy/RuleComposer.java +++ b/policy/src/main/java/dev/cel/policy/RuleComposer.java @@ -19,12 +19,16 @@ import static java.util.stream.Collectors.toCollection; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelMutableAst; import dev.cel.common.CelValidationException; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.formats.ValueString; +import dev.cel.common.navigation.CelNavigableMutableAst; +import dev.cel.common.navigation.CelNavigableMutableExpr; import dev.cel.extensions.CelOptionalLibrary.Function; import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; @@ -151,23 +155,30 @@ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRul } } - CelMutableAst result = matchAst; - for (CelCompiledVariable variable : Lists.reverse(compiledRule.variables())) { - result = - astMutator.replaceSubtreeWithNewBindMacro( - result, - variablePrefix + variable.name(), - CelMutableAst.fromCelAst(variable.ast()), - result.expr(), - result.expr().id(), - true); - } + CelMutableAst result = inlineCompiledVariables(matchAst, compiledRule.variables()); result = astMutator.renumberIdsConsecutively(result); return RuleOptimizationResult.create(result, isOptionalResult); } + private CelMutableAst inlineCompiledVariables(CelMutableAst ast, List compiledVariables) { + CelMutableAst mutatedAst = ast; + for (CelCompiledVariable compiledVariable : Lists.reverse(compiledVariables)) { + String variableName = variablePrefix + compiledVariable.name(); + ImmutableList exprsToReplace = CelNavigableMutableAst.fromAst(mutatedAst).getRoot().allNodes().filter( + node -> node.expr().getKind().equals(Kind.IDENT) && node.expr().ident().name().equals(variableName)) + .collect(toImmutableList()); + + for (CelNavigableMutableExpr expr : exprsToReplace) { + CelMutableAst variableAst = CelMutableAst.fromCelAst(compiledVariable.ast()); + mutatedAst = astMutator.replaceSubtree(mutatedAst, variableAst, expr.id()); + } + } + + return mutatedAst; + } + static RuleComposer newInstance( CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) { return new RuleComposer(compiledRule, variablePrefix, iterationLimit); diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index ee041192c..5e533ab71 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -136,8 +136,8 @@ public void compileYamlPolicy_multilineContainsError_throws( @Test public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception { - String longExpr = - "0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50"; + Cel cel = newCel().toCelBuilder().addVar("msg", SimpleType.DYN).build(); + String longExpr = "msg.b.c.d.e.f"; String policyContent = String.format( "name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr); @@ -146,11 +146,13 @@ public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Except CelPolicyValidationException e = assertThrows( CelPolicyValidationException.class, - () -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy)); + () -> CelPolicyCompilerFactory.newPolicyCompiler(cel) + .setAstDepthLimit(5) + .build().compile(policy)); assertThat(e) .hasMessageThat() - .isEqualTo("ERROR: :-1:0: AST's depth exceeds the configured limit: 50."); + .isEqualTo("ERROR: :-1:0: AST's depth exceeds the configured limit: 5."); } @Test diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 002025094..751509f4a 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -41,88 +41,36 @@ enum TestYamlPolicy { NESTED_RULE( "nested_rule", true, - "cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"]," - + " cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false, \"ir\": false}," - + " (resource.origin in variables.banned_regions && " - + "!(resource.origin in variables.permitted_regions)) " - + "? optional.of({\"banned\": true}) : optional.none()).or(" - + "optional.of((resource.origin in variables.permitted_regions)" - + " ? {\"banned\": false} : {\"banned\": true})))"), + "cel.@block([resource.origin, @index0 in [\"us\", \"uk\", \"es\"], {\"banned\": true}], ((@index0 in {\"us\": false, \"ru\": false, \"ir\": false} && !@index1) ? optional.of(@index2) : optional.none()).or(optional.of(@index1 ? {\"banned\": false} : @index2)))"), NESTED_RULE2( "nested_rule2", false, - "cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"]," - + " resource.?user.orValue(\"\").startsWith(\"bad\") ?" - + " cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false, \"ir\": false}," - + " (resource.origin in variables.banned_regions && !(resource.origin in" - + " variables.permitted_regions)) ? {\"banned\": \"restricted_region\"} : {\"banned\":" - + " \"bad_actor\"}) : (!(resource.origin in variables.permitted_regions) ? {\"banned\":" - + " \"unconfigured_region\"} : {}))"), + "cel.@block([resource.origin, !(@index0 in [\"us\", \"uk\", \"es\"])], resource.?user.orValue(\"\").startsWith(\"bad\") ? ((@index0 in {\"us\": false, \"ru\": false, \"ir\": false} && @index1) ? {\"banned\": \"restricted_region\"} : {\"banned\": \"bad_actor\"}) : (@index1 ? {\"banned\": \"unconfigured_region\"} : {}))"), NESTED_RULE3( "nested_rule3", true, - "cel.bind(variables.permitted_regions, [\"us\", \"uk\", \"es\"]," - + " resource.?user.orValue(\"\").startsWith(\"bad\") ?" - + " optional.of(cel.bind(variables.banned_regions, {\"us\": false, \"ru\": false," - + " \"ir\": false}, (resource.origin in variables.banned_regions && !(resource.origin" - + " in variables.permitted_regions)) ? {\"banned\": \"restricted_region\"} :" - + " {\"banned\": \"bad_actor\"})) : (!(resource.origin in variables.permitted_regions)" - + " ? optional.of({\"banned\": \"unconfigured_region\"}) : optional.none()))"), + "cel.@block([resource.origin, !(@index0 in [\"us\", \"uk\", \"es\"])], resource.?user.orValue(\"\").startsWith(\"bad\") ? optional.of((@index0 in {\"us\": false, \"ru\": false, \"ir\": false} && @index1) ? {\"banned\": \"restricted_region\"} : {\"banned\": \"bad_actor\"}) : (@index1 ? optional.of({\"banned\": \"unconfigured_region\"}) : optional.none()))"), REQUIRED_LABELS( "required_labels", true, - "" - + "cel.bind(variables.want, spec.labels, cel.bind(variables.missing, " - + "variables.want.filter(l, !(l in resource.labels)), cel.bind(variables.invalid, " - + "resource.labels.filter(l, l in variables.want && variables.want[l] != " - + "resource.labels[l]), (variables.missing.size() > 0) ? " - + "optional.of(\"missing one or more required labels: [\"\" + " - + "variables.missing.join(\",\") + \"\"]\") : ((variables.invalid.size() > 0) ? " - + "optional.of(\"invalid values provided on one or more labels: [\"\" + " - + "variables.invalid.join(\",\") + \"\"]\") : optional.none()))))"), + "cel.@block([spec.labels.filter(@it:0:0, !(@it:0:0 in resource.labels)), spec.labels, resource.labels, @index2.filter(@it:0:0, @it:0:0 in @index1 && @index1[@it:0:0] != @index2[@it:0:0])], (@index0.size() > 0) ? optional.of(\"missing one or more required labels: [\"\" + @index0.join(\",\") + \"\"]\") : ((@index3.size() > 0) ? optional.of(\"invalid values provided on one or more labels: [\"\" + @index3.join(\",\") + \"\"]\") : optional.none()))"), RESTRICTED_DESTINATIONS( "restricted_destinations", false, - "cel.bind(variables.matches_origin_ip, locationCode(origin.ip) == spec.origin," - + " cel.bind(variables.has_nationality, has(request.auth.claims.nationality)," - + " cel.bind(variables.matches_nationality, variables.has_nationality &&" - + " request.auth.claims.nationality == spec.origin, cel.bind(variables.matches_dest_ip," - + " locationCode(destination.ip) in spec.restricted_destinations," - + " cel.bind(variables.matches_dest_label, resource.labels.location in" - + " spec.restricted_destinations, cel.bind(variables.matches_dest," - + " variables.matches_dest_ip || variables.matches_dest_label," - + " (variables.matches_nationality && variables.matches_dest) ? true :" - + " ((!variables.has_nationality && variables.matches_origin_ip &&" - + " variables.matches_dest) ? true : false)))))))"), + "cel.@block([request.auth.claims, has(@index0.nationality), resource.labels.location in spec.restricted_destinations], (@index1 && @index0.nationality == spec.origin && (locationCode(destination.ip) in spec.restricted_destinations || @index2)) ? true : ((!@index1 && locationCode(origin.ip) == spec.origin && (locationCode(destination.ip) in spec.restricted_destinations || @index2)) ? true : false))"), K8S( "k8s", true, - "cel.bind(variables.env, resource.labels.?environment.orValue(\"prod\")," - + " cel.bind(variables.break_glass, resource.labels.?break_glass.orValue(\"false\") ==" - + " \"true\", !(variables.break_glass || resource.containers.all(c," - + " c.startsWith(variables.env + \".\"))) ? optional.of(\"only \" + variables.env + \"" - + " containers are allowed in namespace \" + resource.namespace) :" - + " optional.none()))"), + "cel.@block([resource.labels.?environment.orValue(\"prod\")], !(resource.labels.?break_glass.orValue(\"false\") == \"true\" || resource.containers.all(@it:0:0, @it:0:0.startsWith(@index0 + \".\"))) ? optional.of(\"only \" + @index0 + \" containers are allowed in namespace \" + resource.namespace) : optional.none())"), PB( "pb", true, - "(spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64) ? optional.of(\"invalid" - + " spec, got single_int32=\" + string(spec.single_int32) + \", wanted <= 10\") :" - + " ((spec.standalone_enum == cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR" - + " || dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGAR ==" - + " dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGOO) ? optional.of(\"invalid" - + " spec, neither nested nor imported enums may refer to BAR\") :" - + " optional.none())"), + "cel.@block([spec.single_int32], (@index0 > 10) ? optional.of(\"invalid spec, got single_int32=\" + string(@index0) + \", wanted <= 10\") : ((spec.standalone_enum == cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAR || dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGAR == dev.cel.testing.testdata.proto3.StandaloneGlobalEnum.SGOO) ? optional.of(\"invalid spec, neither nested nor imported enums may refer to BAR\") : optional.none()))"), LIMITS( "limits", true, - "cel.bind(variables.greeting, \"hello\", cel.bind(variables.farewell, \"goodbye\"," - + " cel.bind(variables.person, \"me\", cel.bind(variables.message_fmt, \"%s, %s\"," - + " (now.getHours() >= 20) ? cel.bind(variables.message, variables.farewell + \", \" +" - + " variables.person, (now.getHours() < 21) ? optional.of(variables.message + \"!\") :" - + " ((now.getHours() < 22) ? optional.of(variables.message + \"!!\") : ((now.getHours()" - + " < 24) ? optional.of(variables.message + \"!!!\") : optional.none()))) :" - + " optional.of(variables.greeting + \", \" + variables.person)))))"); + "cel.@block([now.getHours()], (@index0 >= 20) ? ((@index0 < 21) ? optional.of(\"goodbye, me!\") : ((@index0 < 22) ? optional.of(\"goodbye, me!!\") : ((@index0 < 24) ? optional.of(\"goodbye, me!!!\") : optional.none()))) : optional.of(\"hello, me\"))") + ; private final String name; private final boolean producesOptionalResult;