From 99427705ae46bc8f3a01f3b4e2c618e77b1f8ec9 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 26 Sep 2025 16:45:38 -0700 Subject: [PATCH] POC: Const fold times --- .../cel/checker/CelStandardDeclarations.java | 2 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 2 + .../optimizers/ConstantFoldingOptimizer.java | 52 ++++++++++++++++--- .../ConstantFoldingOptimizerTest.java | 23 +++++--- 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java index 3865430ec..94b7ad313 100644 --- a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java +++ b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java @@ -1488,7 +1488,7 @@ public CelFunctionDecl functionDecl() { return celFunctionDecl; } - String functionName() { + public String functionName() { return functionName; } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 7674870a8..b880be632 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -19,6 +19,7 @@ java_library( ":default_optimizer_constants", "//:auto_value", "//bundle:cel", + "//checker:standard_decl", "//common:cel_ast", "//common:cel_source", "//common:compiler_common", @@ -35,6 +36,7 @@ java_library( "//runtime", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", + "@maven//:org_threeten_threeten_extra", ], ) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index b40250fe6..cebc1de5c 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -16,6 +16,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION; +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; @@ -45,6 +47,9 @@ import dev.cel.optimizer.CelOptimizationException; import dev.cel.parser.Operator; import dev.cel.runtime.CelEvaluationException; +import java.math.BigDecimal; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -142,6 +147,14 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { return false; } + // Timestamps/durations in CEL are calls, but they are effectively treated as literals. + // Expressions like timestamp(123) cannot be folded directly, but arithmetics involving timestamps + // can be folded. + // Ex: timestamp(123) - timestamp(100) = duration("23s") + if (isExprTimestampOrDuration(navigableExpr)) { + return false; + } + CelMutableCall mutableCall = navigableExpr.expr().call(); String functionName = mutableCall.function(); @@ -197,14 +210,19 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) { return navigableExpr .allNodes() + .filter(node -> node.getKind().equals(Kind.CALL)) + .map(node -> node.expr().call()) .allMatch( - node -> { - if (node.getKind().equals(Kind.CALL)) { - return foldableFunctions.contains(node.expr().call().function()); - } + call -> foldableFunctions.contains(call.function())); + } + + private static boolean isExprTimestampOrDuration(CelNavigableMutableExpr navigableExpr) { + if (!navigableExpr.getKind().equals(Kind.CALL)) { + return true; + } - return true; - }); + CelMutableCall call = navigableExpr.expr().call(); + return call.function().equals(TIMESTAMP.functionName()) || call.function().equals(DURATION.functionName()); } private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) { @@ -318,12 +336,32 @@ private Optional maybeAdaptEvaluatedResult(Object result) { } return Optional.of(CelMutableExpr.ofMap(CelMutableMap.create(mapEntries))); + } else if (result instanceof Duration) { + String durationStrArg = formatDurationArgument((Duration) result); + CelMutableCall durationCall = CelMutableCall.create(DURATION.functionName(), CelMutableExpr.ofConstant(CelConstant.ofValue(durationStrArg))); + return Optional.of(CelMutableExpr.ofCall(durationCall)); + } else if (result instanceof Instant) { + String timestampStrArg = ((Instant) result).toString(); + CelMutableCall timestampCall = CelMutableCall.create(TIMESTAMP.functionName(), CelMutableExpr.ofConstant(CelConstant.ofValue(timestampStrArg))); + return Optional.of(CelMutableExpr.ofCall(timestampCall)); } // Evaluated result cannot be folded (e.g: unknowns) return Optional.empty(); } + private static String formatDurationArgument(Duration duration) { + if (duration.isZero()) { + return "0"; + } + + BigDecimal seconds = BigDecimal.valueOf(duration.getSeconds()); + BigDecimal nanos = BigDecimal.valueOf(duration.getNano(), 9); + BigDecimal totalSeconds = seconds.add(nanos); + + return totalSeconds.stripTrailingZeros().toPlainString() + "s"; + } + private Optional maybeRewriteOptional( Optional optResult, CelMutableAst mutableAst, CelMutableExpr expr) { if (!optResult.isPresent()) { @@ -352,6 +390,8 @@ private Optional maybeRewriteOptional( return Optional.empty(); } + + /** Inspects the non-strict calls to determine whether a branch can be removed. */ private Optional maybePruneBranches( CelMutableAst mutableAst, CelMutableExpr expr) { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 4efbeb1c2..d1eb1bb9a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -46,6 +46,10 @@ @RunWith(TestParameterInjector.class) public class ConstantFoldingOptimizerTest { + private static final CelOptions CEL_OPTIONS = CelOptions.current() + .enableTimestampEpoch(true) + .evaluateCanonicalTypesToNativeValues(true) + .build(); private static final Cel CEL = CelFactory.standardCelBuilder() .addVar("x", SimpleType.DYN) @@ -60,19 +64,20 @@ public class ConstantFoldingOptimizerTest { CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setOptions(CEL_OPTIONS) .addCompilerLibraries( CelExtensions.bindings(), CelOptionalLibrary.INSTANCE, - CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.math(CEL_OPTIONS), CelExtensions.strings(), - CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders(CelOptions.DEFAULT)) + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) .addRuntimeLibraries( CelOptionalLibrary.INSTANCE, - CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.math(CEL_OPTIONS), CelExtensions.strings(), - CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders(CelOptions.DEFAULT)) + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) .build(); private static final CelOptimizer CEL_OPTIMIZER = @@ -211,6 +216,10 @@ public class ConstantFoldingOptimizerTest { @TestParameters("{source: '42 != 42', expected: 'false'}") @TestParameters("{source: '[\"foo\",\"bar\"] == [\"foo\",\"bar\"]', expected: 'true'}") @TestParameters("{source: '[\"bar\",\"foo\"] == [\"foo\",\"bar\"]', expected: 'false'}") + @TestParameters("{source: 'duration(\"1h\") - duration(\"60m\")', expected: 'duration(\"0\")'}") + @TestParameters("{source: 'duration(\"2h23m42s12ms42us92ns\") + duration(\"129481231298125ns\")', expected: 'duration(\"138103.243340217s\")'}") + @TestParameters("{source: 'timestamp(900000) - timestamp(100)', expected: 'duration(\"899900s\")'}") + @TestParameters("{source: 'timestamp(\"2000-01-01T00:02:03.2123Z\") + duration(\"25h2m32s42ms53us29ns\")', expected: 'timestamp(\"2000-01-02T01:04:35.254353029Z\")'}") // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { @@ -348,6 +357,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'get_true() == true'}") @TestParameters("{source: 'x == x'}") @TestParameters("{source: 'x == 42'}") + @TestParameters("{source: 'timestamp(100)'}") + @TestParameters("{source: 'duration(\"1h\")'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst();