From f773edf4bf559e38bdfebad2a865bc2e03389e3b Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 23 Sep 2025 11:35:40 -0700 Subject: [PATCH] Allow constant folding to fold equals operator PiperOrigin-RevId: 810518582 --- .../optimizers/ConstantFoldingOptimizer.java | 46 +++++++++++++++++++ .../ConstantFoldingOptimizerTest.java | 28 ++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index fd52f4138..b40250fe6 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -167,6 +167,16 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { && cond.constant().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE); } + if (functionName.equals(Operator.EQUALS.getFunction()) + || functionName.equals(Operator.NOT_EQUALS.getFunction())) { + if (mutableCall.args().stream() + .anyMatch(node -> isExprConstantOfKind(node, CelConstant.Kind.BOOLEAN_VALUE)) + || mutableCall.args().stream() + .allMatch(node -> node.getKind().equals(Kind.CONSTANT))) { + return true; + } + } + if (functionName.equals(Operator.IN.getFunction())) { return canFoldInOperator(navigableExpr); } @@ -393,6 +403,38 @@ private Optional maybePruneBranches( } } } + } else if (function.equals(Operator.EQUALS.getFunction()) + || function.equals(Operator.NOT_EQUALS.getFunction())) { + CelMutableExpr lhs = call.args().get(0); + CelMutableExpr rhs = call.args().get(1); + boolean lhsIsBoolean = isExprConstantOfKind(lhs, CelConstant.Kind.BOOLEAN_VALUE); + boolean rhsIsBoolean = isExprConstantOfKind(rhs, CelConstant.Kind.BOOLEAN_VALUE); + boolean invertCondition = function.equals(Operator.NOT_EQUALS.getFunction()); + Optional replacementExpr = Optional.empty(); + + if (lhs.getKind().equals(Kind.CONSTANT) && rhs.getKind().equals(Kind.CONSTANT)) { + // If both args are const, don't prune any branches and let maybeFold method evaluate this + // subExpr + return Optional.empty(); + } else if (lhsIsBoolean) { + boolean cond = invertCondition != lhs.constant().booleanValue(); + replacementExpr = + Optional.of( + cond + ? rhs + : CelMutableExpr.ofCall( + CelMutableCall.create(Operator.LOGICAL_NOT.getFunction(), rhs))); + } else if (rhsIsBoolean) { + boolean cond = invertCondition != rhs.constant().booleanValue(); + replacementExpr = + Optional.of( + cond + ? lhs + : CelMutableExpr.ofCall( + CelMutableCall.create(Operator.LOGICAL_NOT.getFunction(), lhs))); + } + + return replacementExpr.map(node -> astMutator.replaceSubtree(mutableAst, node, expr.id())); } return Optional.empty(); @@ -663,6 +705,10 @@ public static Builder newBuilder() { ConstantFoldingOptions() {} } + private static boolean isExprConstantOfKind(CelMutableExpr expr, CelConstant.Kind constantKind) { + return expr.getKind().equals(Kind.CONSTANT) && expr.constant().getKind().equals(constantKind); + } + private ConstantFoldingOptimizer(ConstantFoldingOptions constantFoldingOptions) { this.constantFoldingOptions = constantFoldingOptions; this.astMutator = AstMutator.newInstance(constantFoldingOptions.maxIterationLimit()); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index ea1e77edb..4efbeb1c2 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -66,13 +66,13 @@ public class ConstantFoldingOptimizerTest { CelExtensions.math(CelOptions.DEFAULT), CelExtensions.strings(), CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders()) + CelExtensions.encoders(CelOptions.DEFAULT)) .addRuntimeLibraries( CelOptionalLibrary.INSTANCE, CelExtensions.math(CelOptions.DEFAULT), CelExtensions.strings(), CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders()) + CelExtensions.encoders(CelOptions.DEFAULT)) .build(); private static final CelOptimizer CEL_OPTIMIZER = @@ -189,6 +189,28 @@ public class ConstantFoldingOptimizerTest { @TestParameters("{source: 'sets.contains([1], [1])', expected: 'true'}") @TestParameters( "{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0, r1))', expected: 'true'}") + @TestParameters("{source: 'x == true', expected: 'x'}") + @TestParameters("{source: 'true == x', expected: 'x'}") + @TestParameters("{source: 'x == false', expected: '!x'}") + @TestParameters("{source: 'false == x', expected: '!x'}") + @TestParameters("{source: 'true == false', expected: 'false'}") + @TestParameters("{source: 'true == true', expected: 'true'}") + @TestParameters("{source: 'false == true', expected: 'false'}") + @TestParameters("{source: 'false == false', expected: 'true'}") + @TestParameters("{source: '10 == 42', expected: 'false'}") + @TestParameters("{source: '42 == 42', expected: 'true'}") + @TestParameters("{source: 'x != true', expected: '!x'}") + @TestParameters("{source: 'true != x', expected: '!x'}") + @TestParameters("{source: 'x != false', expected: 'x'}") + @TestParameters("{source: 'false != x', expected: 'x'}") + @TestParameters("{source: 'true != false', expected: 'true'}") + @TestParameters("{source: 'true != true', expected: 'false'}") + @TestParameters("{source: 'false != true', expected: 'true'}") + @TestParameters("{source: 'false != false', expected: 'false'}") + @TestParameters("{source: '10 != 42', expected: 'true'}") + @TestParameters("{source: '42 != 42', expected: 'false'}") + @TestParameters("{source: '[\"foo\",\"bar\"] == [\"foo\",\"bar\"]', expected: 'true'}") + @TestParameters("{source: '[\"bar\",\"foo\"] == [\"foo\",\"bar\"]', expected: 'false'}") // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { @@ -324,6 +346,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}'}") @TestParameters("{source: 'get_true() == get_true()'}") @TestParameters("{source: 'get_true() == true'}") + @TestParameters("{source: 'x == x'}") + @TestParameters("{source: 'x == 42'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst();