From 4180fc98dd7f6a1bf1b6205c9ffa15d6145ac02c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 3 Nov 2025 13:42:24 -0800 Subject: [PATCH] Add isStrict flag to denote function overloads as strict PiperOrigin-RevId: 827633212 --- .../src/main/java/dev/cel/runtime/BUILD.bazel | 6 ++-- .../dev/cel/runtime/CelFunctionBinding.java | 8 +++-- .../cel/runtime/CelLateFunctionBindings.java | 1 + .../dev/cel/runtime/CelResolvedOverload.java | 29 ++++++++++++++++--- .../dev/cel/runtime/CelRuntimeLegacyImpl.java | 2 +- .../dev/cel/runtime/DefaultDispatcher.java | 29 ++----------------- .../dev/cel/runtime/FunctionBindingImpl.java | 13 ++++++++- .../java/dev/cel/runtime/LiteRuntimeImpl.java | 5 +++- .../main/java/dev/cel/runtime/Registrar.java | 6 ++-- .../dev/cel/runtime/ResolvedOverload.java | 15 ++++++++++ .../cel/runtime/CelResolvedOverloadTest.java | 6 ++-- .../cel/runtime/DefaultDispatcherTest.java | 8 +++-- .../cel/runtime/DefaultInterpreterTest.java | 7 ++++- 13 files changed, 87 insertions(+), 48 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index e5df8250b..4a5f6dcc7 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -118,9 +118,9 @@ java_library( tags = [ ], deps = [ - ":base", ":evaluation_exception", ":evaluation_exception_builder", + ":function_overload", ":resolved_overload", ":resolved_overload_internal", "//:auto_value", @@ -138,9 +138,9 @@ cel_android_library( srcs = DISPATCHER_SOURCES, visibility = ["//visibility:private"], deps = [ - ":base_android", ":evaluation_exception", ":evaluation_exception_builder", + ":function_overload_android", ":function_resolver_android", ":resolved_overload_android", ":resolved_overload_internal_android", @@ -725,7 +725,6 @@ java_library( ], deps = [ ":function_overload", - "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], @@ -738,7 +737,6 @@ cel_android_library( ], deps = [ ":function_overload_android", - "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", ], diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java index 8fe2b8a2e..79b0f3f54 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; -import dev.cel.common.annotations.Internal; /** * Binding consisting of an overload id, a Java-native argument signature, and an overload @@ -36,7 +35,7 @@ * *

Examples: string_startsWith_string, mathMax_list, lessThan_money_money */ -@Internal + @Immutable public interface CelFunctionBinding { String getOverloadId(); @@ -45,6 +44,8 @@ public interface CelFunctionBinding { CelFunctionOverload getDefinition(); + boolean isStrict(); + /** Create a unary function binding from the {@code overloadId}, {@code arg}, and {@code impl}. */ @SuppressWarnings("unchecked") static CelFunctionBinding from( @@ -66,6 +67,7 @@ static CelFunctionBinding from( /** Create a function binding from the {@code overloadId}, {@code argTypes}, and {@code impl}. */ static CelFunctionBinding from( String overloadId, Iterable> argTypes, CelFunctionOverload impl) { - return new FunctionBindingImpl(overloadId, ImmutableList.copyOf(argTypes), impl); + return new FunctionBindingImpl( + overloadId, ImmutableList.copyOf(argTypes), impl, /* isStrict= */ true); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java index 282ef1eb9..6762feec0 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java +++ b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java @@ -58,6 +58,7 @@ private static ResolvedOverload createResolvedOverload(CelFunctionBinding bindin return CelResolvedOverload.of( binding.getOverloadId(), (args) -> binding.getDefinition().apply(args), + binding.isStrict(), binding.getArgTypes()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index 9725315cc..26f86a459 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -35,6 +35,21 @@ abstract class CelResolvedOverload implements ResolvedOverload { @Override public abstract ImmutableList> getParameterTypes(); + /* Denotes whether an overload is strict. + * + *

A strict function will not be invoked if any of its arguments are an error or unknown value. + * The runtime automatically propagates the error or unknown instead. + * + *

A non-strict function will be invoked even if its arguments contain errors or unknowns. The + * function's implementation is then responsible for handling these values. This is primarily used + * for short-circuiting logical operators (e.g., `||`, `&&`) and comprehension's + * internal @not_strictly_false function. + * + *

In a vast majority of cases, this should be set to true. + */ + @Override + public abstract boolean isStrict(); + /** The function definition. */ @Override public abstract CelFunctionOverload getDefinition(); @@ -43,16 +58,22 @@ abstract class CelResolvedOverload implements ResolvedOverload { * Creates a new resolved overload from the given overload id, parameter types, and definition. */ public static CelResolvedOverload of( - String overloadId, CelFunctionOverload definition, Class... parameterTypes) { - return of(overloadId, definition, ImmutableList.copyOf(parameterTypes)); + String overloadId, + CelFunctionOverload definition, + boolean isStrict, + Class... parameterTypes) { + return of(overloadId, definition, isStrict, ImmutableList.copyOf(parameterTypes)); } /** * Creates a new resolved overload from the given overload id, parameter types, and definition. */ public static CelResolvedOverload of( - String overloadId, CelFunctionOverload definition, List> parameterTypes) { + String overloadId, + CelFunctionOverload definition, + boolean isStrict, + List> parameterTypes) { return new AutoValue_CelResolvedOverload( - overloadId, ImmutableList.copyOf(parameterTypes), definition); + overloadId, ImmutableList.copyOf(parameterTypes), isStrict, definition); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index 2947445d1..4f58d44db 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -293,7 +293,7 @@ public CelRuntimeLegacyImpl build() { .forEach( (String overloadId, CelFunctionBinding func) -> dispatcher.add( - overloadId, func.getArgTypes(), (args) -> func.getDefinition().apply(args))); + overloadId, func.getArgTypes(), func.isStrict(), func.getDefinition())); RuntimeTypeProvider runtimeTypeProvider; diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index c4c27005e..6228dd253 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -32,7 +32,7 @@ *

Should be final, do not mock; mocking {@link Dispatcher} instead. */ @ThreadSafe -final class DefaultDispatcher implements Dispatcher, Registrar { +final class DefaultDispatcher implements Dispatcher { public static DefaultDispatcher create() { return new DefaultDispatcher(); } @@ -40,32 +40,9 @@ public static DefaultDispatcher create() { @GuardedBy("this") private final Map overloads = new HashMap<>(); - @Override - @SuppressWarnings("unchecked") - public synchronized void add( - String overloadId, Class argType, final Registrar.UnaryFunction function) { - overloads.put( - overloadId, - CelResolvedOverload.of(overloadId, args -> function.apply((T) args[0]), argType)); - } - - @Override - @SuppressWarnings("unchecked") - public synchronized void add( - String overloadId, - Class argType1, - Class argType2, - final Registrar.BinaryFunction function) { - overloads.put( - overloadId, - CelResolvedOverload.of( - overloadId, args -> function.apply((T1) args[0], (T2) args[1]), argType1, argType2)); - } - - @Override public synchronized void add( - String overloadId, List> argTypes, Registrar.Function function) { - overloads.put(overloadId, CelResolvedOverload.of(overloadId, function, argTypes)); + String overloadId, List> argTypes, boolean isStrict, CelFunctionOverload overload) { + overloads.put(overloadId, CelResolvedOverload.of(overloadId, overload, isStrict, argTypes)); } @Override diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java index 501608571..b554ce41a 100644 --- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java @@ -26,6 +26,8 @@ final class FunctionBindingImpl implements CelFunctionBinding { private final CelFunctionOverload definition; + private final boolean isStrict; + @Override public String getOverloadId() { return overloadId; @@ -41,10 +43,19 @@ public CelFunctionOverload getDefinition() { return definition; } + @Override + public boolean isStrict() { + return isStrict; + } + FunctionBindingImpl( - String overloadId, ImmutableList> argTypes, CelFunctionOverload definition) { + String overloadId, + ImmutableList> argTypes, + CelFunctionOverload definition, + boolean isStrict) { this.overloadId = overloadId; this.argTypes = argTypes; this.definition = definition; + this.isStrict = isStrict; } } diff --git a/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java index 152c96160..aba73aed4 100644 --- a/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java @@ -182,7 +182,10 @@ public CelLiteRuntime build() { .forEach( (String overloadId, CelFunctionBinding func) -> dispatcher.add( - overloadId, func.getArgTypes(), (args) -> func.getDefinition().apply(args))); + overloadId, + func.getArgTypes(), + func.isStrict(), + (args) -> func.getDefinition().apply(args))); Interpreter interpreter = new DefaultInterpreter( diff --git a/runtime/src/main/java/dev/cel/runtime/Registrar.java b/runtime/src/main/java/dev/cel/runtime/Registrar.java index 12dbe6cb3..467a7c426 100644 --- a/runtime/src/main/java/dev/cel/runtime/Registrar.java +++ b/runtime/src/main/java/dev/cel/runtime/Registrar.java @@ -15,15 +15,15 @@ package dev.cel.runtime; import com.google.errorprone.annotations.Immutable; -import dev.cel.common.annotations.Internal; import java.util.List; /** * An object which registers the functions that a {@link Dispatcher} calls. * - *

CEL Library Internals. Do Not Use. + * @deprecated Do not use. This interface exists solely for legacy async stack compatibility + * reasons. */ -@Internal +@Deprecated public interface Registrar { /** Interface describing the general signature of all CEL custom function implementations. */ diff --git a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java index bc7544199..d0b4d2f77 100644 --- a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java @@ -35,6 +35,21 @@ interface ResolvedOverload { /** The function definition. */ CelFunctionOverload getDefinition(); + /** + * Denotes whether an overload is strict. + * + *

A strict function will not be invoked if any of its arguments are an error or unknown value. + * The runtime automatically propagates the error or unknown instead. + * + *

A non-strict function will be invoked even if its arguments contain errors or unknowns. The + * function's implementation is then responsible for handling these values. This is primarily used + * for short-circuiting logical operators (e.g., `||`, `&&`) and comprehension's + * internal @not_strictly_false function. + * + *

In a vast majority of cases, a function should be kept strict. + */ + boolean isStrict(); + /** * Returns true if the overload's expected argument types match the types of the given arguments. */ diff --git a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java index 2c81b89d4..0fdc4f65d 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java @@ -32,6 +32,7 @@ CelResolvedOverload getIncrementIntOverload() { Long arg = (Long) args[0]; return arg + 1; }, + /* isStrict= */ true, Long.class); } @@ -43,14 +44,15 @@ public void canHandle_matchingTypes_returnsTrue() { @Test public void canHandle_nullMessageType_returnsTrue() { CelResolvedOverload overload = - CelResolvedOverload.of("identity", (args) -> args[0], TestAllTypes.class); + CelResolvedOverload.of( + "identity", (args) -> args[0], /* isStrict= */ true, TestAllTypes.class); assertThat(overload.canHandle(new Object[] {null})).isTrue(); } @Test public void canHandle_nullPrimitive_returnsFalse() { CelResolvedOverload overload = - CelResolvedOverload.of("identity", (args) -> args[0], Long.class); + CelResolvedOverload.of("identity", (args) -> args[0], /* isStrict= */ true, Long.class); assertThat(overload.canHandle(new Object[] {null})).isFalse(); } diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java index 03eba51ce..fb6831201 100644 --- a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java @@ -35,9 +35,13 @@ public final class DefaultDispatcherTest { public void setup() { overloads = new HashMap<>(); overloads.put( - "overload_1", CelResolvedOverload.of("overload_1", args -> (Long) args[0] + 1, Long.class)); + "overload_1", + CelResolvedOverload.of( + "overload_1", args -> (Long) args[0] + 1, /* isStrict= */ true, Long.class)); overloads.put( - "overload_2", CelResolvedOverload.of("overload_2", args -> (Long) args[0] + 2, Long.class)); + "overload_2", + CelResolvedOverload.of( + "overload_2", args -> (Long) args[0] + 2, /* isStrict= */ true, Long.class)); } @Test diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java index 4f3501d41..e9ac6b5a3 100644 --- a/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java @@ -17,6 +17,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; @@ -74,7 +75,11 @@ public Object adapt(String messageName, Object message) { }; CelAbstractSyntaxTree ast = celCompiler.compile("[1].all(x, [2].all(y, error()))").getAst(); DefaultDispatcher dispatcher = DefaultDispatcher.create(); - dispatcher.add("error", long.class, (args) -> new IllegalArgumentException("Always throws")); + dispatcher.add( + "error", + ImmutableList.of(long.class), + /* isStrict= */ true, + (args) -> new IllegalArgumentException("Always throws")); DefaultInterpreter defaultInterpreter = new DefaultInterpreter(new TypeResolver(), emptyProvider, dispatcher, CelOptions.DEFAULT); DefaultInterpretable interpretable =