From cd05924fb69c74fcd2fdaf523a7dba6a6dc36c7a Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 31 Oct 2025 14:35:17 -0700 Subject: [PATCH] Consolidate function overload and resolver interfaces PiperOrigin-RevId: 826632322 --- runtime/BUILD.bazel | 24 +++-- .../src/main/java/dev/cel/runtime/BUILD.bazel | 94 ++++++++++++------- .../dev/cel/runtime/CelFunctionOverload.java | 13 ++- .../dev/cel/runtime/CelFunctionResolver.java | 18 +++- .../cel/runtime/CelLateFunctionBindings.java | 4 +- .../dev/cel/runtime/CelResolvedOverload.java | 12 ++- .../dev/cel/runtime/CelStandardFunctions.java | 62 +++++------- .../dev/cel/runtime/DefaultDispatcher.java | 42 +-------- .../dev/cel/runtime/DefaultInterpreter.java | 12 +-- .../main/java/dev/cel/runtime/Dispatcher.java | 2 +- .../dev/cel/runtime/FunctionOverload.java | 46 --------- .../dev/cel/runtime/FunctionResolver.java | 44 --------- .../java/dev/cel/runtime/Interpretable.java | 4 +- .../main/java/dev/cel/runtime/Registrar.java | 6 +- .../dev/cel/runtime/ResolvedOverload.java | 2 +- .../runtime/UnknownTrackingInterpretable.java | 2 +- .../java/dev/cel/runtime/standard/BUILD.bazel | 6 +- .../runtime/standard/CelStandardOverload.java | 4 +- .../src/test/java/dev/cel/runtime/BUILD.bazel | 4 +- .../cel/runtime/CelResolvedOverloadTest.java | 9 +- .../cel/runtime/CelStandardFunctionsTest.java | 4 +- .../cel/runtime/DefaultDispatcherTest.java | 8 +- runtime/standard/BUILD.bazel | 10 ++ 23 files changed, 187 insertions(+), 245 deletions(-) delete mode 100644 runtime/src/main/java/dev/cel/runtime/FunctionOverload.java delete mode 100644 runtime/src/main/java/dev/cel/runtime/FunctionResolver.java diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index b7d59ce96..3bcb86766 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -76,12 +76,6 @@ cel_android_library( exports = ["//runtime/src/main/java/dev/cel/runtime:late_function_binding_android"], ) -java_library( - name = "function_overload_impl", - visibility = ["//:internal"], - exports = ["//runtime/src/main/java/dev/cel/runtime:function_overload_impl"], -) - java_library( name = "evaluation_exception_builder", exports = ["//runtime/src/main/java/dev/cel/runtime:evaluation_exception_builder"], @@ -220,3 +214,21 @@ cel_android_library( visibility = ["//:internal"], exports = ["//runtime/src/main/java/dev/cel/runtime:lite_runtime_impl_android"], ) + +java_library( + name = "resolved_overload", + visibility = ["//:internal"], + exports = ["//runtime/src/main/java/dev/cel/runtime:resolved_overload"], +) + +cel_android_library( + name = "resolved_overload_android", + visibility = ["//:internal"], + exports = ["//runtime/src/main/java/dev/cel/runtime:resolved_overload_android"], +) + +java_library( + name = "resolved_overload_internal", + visibility = ["//:internal"], + exports = ["//runtime/src/main/java/dev/cel/runtime:resolved_overload_internal"], +) diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 3df163f79..e5df8250b 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -51,13 +51,6 @@ FUNCTION_BINDING_SOURCES = [ "FunctionBindingImpl.java", ] -# keep sorted -FUNCTION_OVERLOAD_IMPL_SOURCES = [ - "FunctionOverload.java", - "FunctionResolver.java", - "ResolvedOverload.java", -] - # keep sorted INTERPRABLE_SOURCES = [ "GlobalResolver.java", @@ -128,10 +121,12 @@ java_library( ":base", ":evaluation_exception", ":evaluation_exception_builder", - ":function_overload_impl", + ":resolved_overload", + ":resolved_overload_internal", "//:auto_value", "//common:error_codes", "//common/annotations", + "//runtime:function_resolver", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -146,7 +141,9 @@ cel_android_library( ":base_android", ":evaluation_exception", ":evaluation_exception_builder", - ":function_overload_impl_android", + ":function_resolver_android", + ":resolved_overload_android", + ":resolved_overload_internal_android", "//:auto_value", "//common:error_codes", "//common/annotations", @@ -250,7 +247,7 @@ java_library( tags = [ ], deps = [ - ":function_overload_impl", + ":function_overload", ":metadata", "//common:cel_ast", "//common/annotations", @@ -264,11 +261,10 @@ cel_android_library( srcs = BASE_SOURCES, visibility = ["//visibility:private"], deps = [ - ":function_overload_impl_android", + ":function_overload_android", ":metadata", "//common:cel_ast_android", "//common/annotations", - "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", ], @@ -288,10 +284,10 @@ java_library( ":evaluation_exception", ":evaluation_exception_builder", ":evaluation_listener", - ":function_overload_impl", ":interpretable", ":interpreter_util", ":metadata", + ":resolved_overload_internal", ":runtime_helpers", ":runtime_type_provider", ":type_resolver", @@ -306,6 +302,7 @@ java_library( "//common/types", "//common/types:type_providers", "//common/values:cel_byte_string", + "//runtime:function_resolver", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -326,10 +323,11 @@ cel_android_library( ":evaluation_exception", ":evaluation_exception_builder", ":evaluation_listener_android", - ":function_overload_impl_android", + ":function_resolver_android", ":interpretable_android", ":interpreter_util_android", ":metadata", + ":resolved_overload_internal_android", ":runtime_helpers_android", ":runtime_type_provider_android", ":type_resolver_android", @@ -486,7 +484,6 @@ RUNTIME_SOURCES = [ LATE_FUNCTION_BINDING_SOURCES = [ "CelLateFunctionBindings.java", - "CelResolvedOverload.java", ] java_library( @@ -498,9 +495,9 @@ java_library( ":dispatcher", ":evaluation_exception", ":function_binding", - ":function_overload", - ":function_overload_impl", ":function_resolver", + ":resolved_overload", + ":resolved_overload_internal", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -516,9 +513,9 @@ cel_android_library( ":dispatcher_android", ":evaluation_exception", ":function_binding_android", - ":function_overload_android", - ":function_overload_impl_android", ":function_resolver_android", + ":resolved_overload_android", + ":resolved_overload_internal_android", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -586,8 +583,8 @@ java_library( deps = [ ":evaluation_exception", ":evaluation_listener", - ":function_overload_impl", "//common/annotations", + "//runtime:function_resolver", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:org_jspecify_jspecify", ], @@ -600,7 +597,7 @@ cel_android_library( deps = [ ":evaluation_exception", ":evaluation_listener_android", - ":function_overload_impl_android", + ":function_resolver_android", "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:org_jspecify_jspecify", @@ -652,6 +649,7 @@ java_library( "//runtime/standard:not_equals", "//runtime/standard:size", "//runtime/standard:standard_function", + "//runtime/standard:standard_overload", "//runtime/standard:starts_with", "//runtime/standard:string", "//runtime/standard:subtract", @@ -707,6 +705,7 @@ cel_android_library( "//runtime/standard:not_equals_android", "//runtime/standard:size_android", "//runtime/standard:standard_function_android", + "//runtime/standard:standard_overload_android", "//runtime/standard:starts_with_android", "//runtime/standard:string_android", "//runtime/standard:subtract_android", @@ -751,7 +750,8 @@ java_library( tags = [ ], deps = [ - ":function_overload_impl", + ":evaluation_exception", + ":resolved_overload_internal", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -761,7 +761,8 @@ cel_android_library( name = "function_resolver_android", srcs = ["CelFunctionResolver.java"], deps = [ - ":function_overload_impl_android", + ":evaluation_exception", + ":resolved_overload_internal_android", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -775,7 +776,7 @@ java_library( tags = [ ], deps = [ - ":function_overload_impl", + ":evaluation_exception", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -786,19 +787,18 @@ cel_android_library( "CelFunctionOverload.java", ], deps = [ - ":function_overload_impl_android", + ":evaluation_exception", "@maven//:com_google_errorprone_error_prone_annotations", ], ) java_library( - name = "function_overload_impl", - srcs = FUNCTION_OVERLOAD_IMPL_SOURCES, + name = "resolved_overload_internal", + srcs = ["ResolvedOverload.java"], tags = [ ], deps = [ - ":evaluation_exception", - "//common/annotations", + ":function_overload", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_protobuf_protobuf_java", @@ -806,11 +806,11 @@ java_library( ) cel_android_library( - name = "function_overload_impl_android", - srcs = FUNCTION_OVERLOAD_IMPL_SOURCES, + name = "resolved_overload_internal_android", + srcs = ["ResolvedOverload.java"], + visibility = ["//visibility:private"], deps = [ - ":evaluation_exception", - "//common/annotations", + ":function_overload_android", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_protobuf_protobuf_javalite", @@ -1167,3 +1167,31 @@ cel_android_library( "@maven//:com_google_errorprone_error_prone_annotations", ], ) + +java_library( + name = "resolved_overload", + srcs = ["CelResolvedOverload.java"], + tags = [ + ], + deps = [ + ":function_overload", + ":resolved_overload_internal", + "//:auto_value", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + ], +) + +cel_android_library( + name = "resolved_overload_android", + srcs = ["CelResolvedOverload.java"], + tags = [ + ], + deps = [ + ":function_overload_android", + ":resolved_overload_internal_android", + "//:auto_value", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven_android//:com_google_guava_guava", + ], +) diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java index f924c7d62..a1341cb21 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java @@ -19,7 +19,10 @@ /** Interface describing the general signature of all CEL custom function implementations. */ @Immutable @FunctionalInterface -public interface CelFunctionOverload extends FunctionOverload { +public interface CelFunctionOverload { + + /** Evaluate a set of arguments throwing a {@code CelException} on error. */ + Object apply(Object[] args) throws CelEvaluationException; /** * Helper interface for describing unary functions where the type-parameter is used to improve @@ -27,7 +30,9 @@ public interface CelFunctionOverload extends FunctionOverload { */ @Immutable @FunctionalInterface - interface Unary extends FunctionOverload.Unary {} + interface Unary { + Object apply(T arg) throws CelEvaluationException; + } /** * Helper interface for describing binary functions where the type parameters are used to improve @@ -35,5 +40,7 @@ interface Unary extends FunctionOverload.Unary {} */ @Immutable @FunctionalInterface - interface Binary extends FunctionOverload.Binary {} + interface Binary { + Object apply(T1 arg1, T2 arg2) throws CelEvaluationException; + } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java index d769ff238..8df7fd0dc 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java @@ -15,10 +15,26 @@ package dev.cel.runtime; import javax.annotation.concurrent.ThreadSafe; +import java.util.List; +import java.util.Optional; /** * Interface to a resolver for CEL functions based on the function name, overload ids, and * arguments. */ @ThreadSafe -public interface CelFunctionResolver extends FunctionResolver {} +public interface CelFunctionResolver { + + /** + * Finds a specific function overload to invoke based on given parameters. + * + * @param functionName the logical name of the function being invoked. + * @param overloadIds A list of function overload ids. The dispatcher selects the unique overload + * from this list with matching arguments. + * @param args The arguments to pass to the function. + * @return an optional value of the resolved overload. + * @throws CelEvaluationException if the overload resolution is ambiguous, + */ + Optional findOverload( + String functionName, List overloadIds, Object[] args) throws CelEvaluationException; +} diff --git a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java index 7f83e38fd..282ef1eb9 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java +++ b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java @@ -57,7 +57,7 @@ public static CelLateFunctionBindings from(List functions) { private static ResolvedOverload createResolvedOverload(CelFunctionBinding binding) { return CelResolvedOverload.of( binding.getOverloadId(), - binding.getArgTypes(), - (args) -> binding.getDefinition().apply(args)); + (args) -> binding.getDefinition().apply(args), + binding.getArgTypes()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index e23749f15..9725315cc 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -17,6 +17,7 @@ import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import java.util.List; /** * Representation of a function overload which has been resolved to a specific set of argument types @@ -24,7 +25,7 @@ */ @AutoValue @Immutable -public abstract class CelResolvedOverload implements ResolvedOverload { +abstract class CelResolvedOverload implements ResolvedOverload { /** The overload id of the function. */ @Override @@ -42,15 +43,16 @@ public abstract class CelResolvedOverload implements ResolvedOverload { * Creates a new resolved overload from the given overload id, parameter types, and definition. */ public static CelResolvedOverload of( - String overloadId, Class[] parameterTypes, CelFunctionOverload definition) { - return of(overloadId, ImmutableList.copyOf(parameterTypes), definition); + String overloadId, CelFunctionOverload definition, Class... parameterTypes) { + return of(overloadId, definition, ImmutableList.copyOf(parameterTypes)); } /** * Creates a new resolved overload from the given overload id, parameter types, and definition. */ public static CelResolvedOverload of( - String overloadId, ImmutableList> parameterTypes, CelFunctionOverload definition) { - return new AutoValue_CelResolvedOverload(overloadId, parameterTypes, definition); + String overloadId, CelFunctionOverload definition, List> parameterTypes) { + return new AutoValue_CelResolvedOverload( + overloadId, ImmutableList.copyOf(parameterTypes), definition); } } diff --git a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java index 299dd7a89..7d781d4af 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/CelStandardFunctions.java @@ -39,6 +39,7 @@ import dev.cel.runtime.standard.BytesFunction; import dev.cel.runtime.standard.BytesFunction.BytesOverload; import dev.cel.runtime.standard.CelStandardFunction; +import dev.cel.runtime.standard.CelStandardOverload; import dev.cel.runtime.standard.ContainsFunction; import dev.cel.runtime.standard.ContainsFunction.ContainsOverload; import dev.cel.runtime.standard.DivideOperator; @@ -116,7 +117,7 @@ @Immutable public final class CelStandardFunctions { - private final ImmutableSet standardOverloads; + private final ImmutableSet standardOverloads; public static final ImmutableSet ALL_STANDARD_FUNCTIONS = ImmutableSet.of( @@ -163,8 +164,8 @@ public final class CelStandardFunctions { /** * Enumeration of Standard Function bindings. * - *

Note: The conditional, logical_or, logical_and, not_strictly_false, and type functions are - * currently special-cased, and does not appear in this enum. + *

Note: The conditional, logical_or, logical_and, and type functions are currently + * special-cased, and does not appear in this enum. */ public enum StandardFunction { LOGICAL_NOT(BooleanOperator.LOGICAL_NOT), @@ -331,7 +332,7 @@ public enum StandardFunction { public static final class Overload { /** Overloads for internal functions that may have been rewritten by macros (ex: @in) */ - public enum InternalOperator implements StandardOverload { + public enum InternalOperator implements CelStandardOverload { IN_LIST(InOverload.IN_LIST::newFunctionBinding), IN_MAP(InOverload.IN_MAP::newFunctionBinding); @@ -349,7 +350,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for functions that test relations. */ - public enum Relation implements StandardOverload { + public enum Relation implements CelStandardOverload { EQUALS(EqualsOverload.EQUALS::newFunctionBinding), NOT_EQUALS(NotEqualsOverload.NOT_EQUALS::newFunctionBinding); @@ -367,7 +368,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for performing arithmetic operations. */ - public enum Arithmetic implements StandardOverload { + public enum Arithmetic implements CelStandardOverload { ADD_INT64(AddOverload.ADD_INT64::newFunctionBinding), ADD_UINT64(AddOverload.ADD_UINT64::newFunctionBinding), ADD_BYTES(AddOverload.ADD_BYTES::newFunctionBinding), @@ -415,7 +416,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for indexing a list or a map. */ - public enum Index implements StandardOverload { + public enum Index implements CelStandardOverload { INDEX_LIST(IndexOverload.INDEX_LIST::newFunctionBinding), INDEX_MAP(IndexOverload.INDEX_MAP::newFunctionBinding); @@ -433,7 +434,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for retrieving the size of a literal or a collection. */ - public enum Size implements StandardOverload { + public enum Size implements CelStandardOverload { SIZE_BYTES(SizeOverload.SIZE_BYTES::newFunctionBinding), BYTES_SIZE(SizeOverload.BYTES_SIZE::newFunctionBinding), SIZE_LIST(SizeOverload.SIZE_LIST::newFunctionBinding), @@ -457,7 +458,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for performing type conversions. */ - public enum Conversions implements StandardOverload { + public enum Conversions implements CelStandardOverload { BOOL_TO_BOOL(BoolOverload.BOOL_TO_BOOL::newFunctionBinding), STRING_TO_BOOL(BoolOverload.STRING_TO_BOOL::newFunctionBinding), INT64_TO_INT64(IntOverload.INT64_TO_INT64::newFunctionBinding), @@ -514,7 +515,7 @@ public CelFunctionBinding newFunctionBinding( * Overloads for functions performing string matching, such as regular expressions or contains * check. */ - public enum StringMatchers implements StandardOverload { + public enum StringMatchers implements CelStandardOverload { MATCHES(MatchesOverload.MATCHES::newFunctionBinding), MATCHES_STRING(MatchesOverload.MATCHES_STRING::newFunctionBinding), CONTAINS_STRING(ContainsOverload.CONTAINS_STRING::newFunctionBinding), @@ -535,7 +536,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for logical operators that return a bool as a result. */ - public enum BooleanOperator implements StandardOverload { + public enum BooleanOperator implements CelStandardOverload { LOGICAL_NOT(LogicalNotOverload.LOGICAL_NOT::newFunctionBinding); private final FunctionBindingCreator bindingCreator; @@ -552,7 +553,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for functions performing date/time operations. */ - public enum DateTime implements StandardOverload { + public enum DateTime implements CelStandardOverload { TIMESTAMP_TO_YEAR(GetFullYearOverload.TIMESTAMP_TO_YEAR::newFunctionBinding), TIMESTAMP_TO_YEAR_WITH_TZ( GetFullYearOverload.TIMESTAMP_TO_YEAR_WITH_TZ::newFunctionBinding), @@ -605,7 +606,7 @@ public CelFunctionBinding newFunctionBinding( } /** Overloads for performing numeric comparisons. */ - public enum Comparison implements StandardOverload { + public enum Comparison implements CelStandardOverload { LESS_BOOL(LessOverload.LESS_BOOL::newFunctionBinding, false), LESS_INT64(LessOverload.LESS_INT64::newFunctionBinding, false), LESS_UINT64(LessOverload.LESS_UINT64::newFunctionBinding, false), @@ -702,20 +703,20 @@ public boolean isHeterogeneousComparison() { private Overload() {} } - private final ImmutableSet standardOverloads; + private final ImmutableSet standardOverloads; - StandardFunction(StandardOverload... overloads) { + StandardFunction(CelStandardOverload... overloads) { this.standardOverloads = ImmutableSet.copyOf(overloads); } @VisibleForTesting - ImmutableSet getOverloads() { + ImmutableSet getOverloads() { return standardOverloads; } } @VisibleForTesting - ImmutableSet getOverloads() { + ImmutableSet getOverloads() { return standardOverloads; } @@ -723,19 +724,13 @@ ImmutableSet getOverloads() { public ImmutableSet newFunctionBindings( RuntimeEquality runtimeEquality, CelOptions celOptions) { ImmutableSet.Builder builder = ImmutableSet.builder(); - for (StandardOverload overload : standardOverloads) { + for (CelStandardOverload overload : standardOverloads) { builder.add(overload.newFunctionBinding(celOptions, runtimeEquality)); } return builder.build(); } - /** General interface for defining a standard function overload. */ - @Immutable - public interface StandardOverload { - CelFunctionBinding newFunctionBinding(CelOptions celOptions, RuntimeEquality runtimeEquality); - } - /** Builder for constructing the set of standard function/identifiers. */ public static final class Builder { private ImmutableSet includeFunctions; @@ -805,7 +800,7 @@ public CelStandardFunctions build() { "You may only populate one of the following builder methods: includeFunctions," + " excludeFunctions or filterFunctions"); - ImmutableSet.Builder standardOverloadBuilder = ImmutableSet.builder(); + ImmutableSet.Builder standardOverloadBuilder = ImmutableSet.builder(); for (StandardFunction standardFunction : StandardFunction.values()) { if (hasIncludeFunctions) { if (this.includeFunctions.contains(standardFunction)) { @@ -820,15 +815,16 @@ public CelStandardFunctions build() { continue; } if (hasFilterFunction) { - ImmutableSet.Builder filteredOverloadsBuilder = ImmutableSet.builder(); - for (StandardOverload standardOverload : standardFunction.standardOverloads) { + ImmutableSet.Builder filteredOverloadsBuilder = + ImmutableSet.builder(); + for (CelStandardOverload standardOverload : standardFunction.standardOverloads) { boolean includeOverload = functionFilter.include(standardFunction, standardOverload); if (includeOverload) { standardOverloadBuilder.add(standardOverload); } } - ImmutableSet filteredOverloads = filteredOverloadsBuilder.build(); + ImmutableSet filteredOverloads = filteredOverloadsBuilder.build(); if (!filteredOverloads.isEmpty()) { standardOverloadBuilder.addAll(filteredOverloads); } @@ -848,7 +844,7 @@ public CelStandardFunctions build() { */ @FunctionalInterface public interface FunctionFilter { - boolean include(StandardFunction standardFunction, StandardOverload standardOverload); + boolean include(StandardFunction standardFunction, CelStandardOverload standardOverload); } } @@ -857,13 +853,7 @@ public static Builder newBuilder() { return new Builder(); } - @FunctionalInterface - @Immutable - private interface FunctionBindingCreator { - CelFunctionBinding create(CelOptions celOptions, RuntimeEquality runtimeEquality); - } - - private CelStandardFunctions(ImmutableSet standardOverloads) { + private CelStandardFunctions(ImmutableSet standardOverloads) { this.standardOverloads = standardOverloads; } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 1b89939ea..c4c27005e 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -14,9 +14,7 @@ package dev.cel.runtime; -import com.google.auto.value.AutoValue; import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; import javax.annotation.concurrent.ThreadSafe; @@ -48,8 +46,7 @@ public synchronized void add( String overloadId, Class argType, final Registrar.UnaryFunction function) { overloads.put( overloadId, - ResolvedOverloadImpl.of( - overloadId, new Class[] {argType}, args -> function.apply((T) args[0]))); + CelResolvedOverload.of(overloadId, args -> function.apply((T) args[0]), argType)); } @Override @@ -61,18 +58,14 @@ public synchronized void add( final Registrar.BinaryFunction function) { overloads.put( overloadId, - ResolvedOverloadImpl.of( - overloadId, - new Class[] {argType1, argType2}, - args -> function.apply((T1) args[0], (T2) args[1]))); + CelResolvedOverload.of( + overloadId, args -> function.apply((T1) args[0], (T2) args[1]), argType1, argType2)); } @Override public synchronized void add( String overloadId, List> argTypes, Registrar.Function function) { - overloads.put( - overloadId, - ResolvedOverloadImpl.of(overloadId, argTypes.toArray(new Class[0]), function)); + overloads.put(overloadId, CelResolvedOverload.of(overloadId, function, argTypes)); } @Override @@ -144,31 +137,4 @@ public Dispatcher.ImmutableCopy immutableCopy() { } private DefaultDispatcher() {} - - @AutoValue - @Immutable - abstract static class ResolvedOverloadImpl implements ResolvedOverload { - /** The overload id of the function. */ - @Override - public abstract String getOverloadId(); - - /** The types of the function parameters. */ - @Override - public abstract ImmutableList> getParameterTypes(); - - /** The function definition. */ - @Override - public abstract FunctionOverload getDefinition(); - - static ResolvedOverload of( - String overloadId, Class[] parameterTypes, FunctionOverload definition) { - return of(overloadId, ImmutableList.copyOf(parameterTypes), definition); - } - - static ResolvedOverload of( - String overloadId, ImmutableList> parameterTypes, FunctionOverload definition) { - return new AutoValue_DefaultDispatcher_ResolvedOverloadImpl( - overloadId, parameterTypes, definition); - } - } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index afdca0e3e..b471d972e 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -154,7 +154,7 @@ public Object eval(GlobalResolver resolver, CelEvaluationListener listener) } @Override - public Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionResolver) + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException { return evalTrackingUnknowns( RuntimeUnknownResolver.fromResolver(resolver), @@ -165,7 +165,7 @@ public Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionRe @Override public Object eval( GlobalResolver resolver, - FunctionResolver lateBoundFunctionResolver, + CelFunctionResolver lateBoundFunctionResolver, CelEvaluationListener listener) throws CelEvaluationException { return evalTrackingUnknowns( @@ -177,7 +177,7 @@ public Object eval( @Override public Object evalTrackingUnknowns( RuntimeUnknownResolver resolver, - Optional functionResolver, + Optional functionResolver, Optional listener) throws CelEvaluationException { ExecutionFrame frame = newExecutionFrame(resolver, functionResolver, listener); @@ -223,7 +223,7 @@ ExecutionFrame newTestExecutionFrame(GlobalResolver resolver) { private ExecutionFrame newExecutionFrame( RuntimeUnknownResolver resolver, - Optional functionResolver, + Optional functionResolver, Optional listener) { int comprehensionMaxIterations = celOptions.enableComprehension() ? celOptions.comprehensionMaxIterations() : 0; @@ -1105,7 +1105,7 @@ static class ExecutionFrame { private final Optional evaluationListener; private final int maxIterations; private final ArrayDeque resolvers; - private final Optional lateBoundFunctionResolver; + private final Optional lateBoundFunctionResolver; private RuntimeUnknownResolver currentResolver; private int iterations; @VisibleForTesting int scopeLevel; @@ -1113,7 +1113,7 @@ static class ExecutionFrame { private ExecutionFrame( Optional evaluationListener, RuntimeUnknownResolver resolver, - Optional lateBoundFunctionResolver, + Optional lateBoundFunctionResolver, int maxIterations) { this.evaluationListener = evaluationListener; this.resolvers = new ArrayDeque<>(); diff --git a/runtime/src/main/java/dev/cel/runtime/Dispatcher.java b/runtime/src/main/java/dev/cel/runtime/Dispatcher.java index 017e76685..e7bc6163f 100644 --- a/runtime/src/main/java/dev/cel/runtime/Dispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/Dispatcher.java @@ -25,7 +25,7 @@ */ @ThreadSafe @Internal -interface Dispatcher extends FunctionResolver { +interface Dispatcher extends CelFunctionResolver { /** * Returns an {@link ImmutableCopy} from current instance. diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/FunctionOverload.java deleted file mode 100644 index bf825f004..000000000 --- a/runtime/src/main/java/dev/cel/runtime/FunctionOverload.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime; - -import com.google.errorprone.annotations.Immutable; - -/** Interface describing the general signature of all CEL custom function implementations. */ -@FunctionalInterface -@Immutable -interface FunctionOverload { - - /** Evaluate a set of arguments throwing a {@code CelException} on error. */ - Object apply(Object[] args) throws CelEvaluationException; - - /** - * Helper interface for describing unary functions where the type-parameter is used to improve - * compile-time correctness of function bindings. - */ - @Immutable - @FunctionalInterface - interface Unary { - Object apply(T arg) throws CelEvaluationException; - } - - /** - * Helper interface for describing binary functions where the type parameters are used to improve - * compile-time correctness of function bindings. - */ - @Immutable - @FunctionalInterface - interface Binary { - Object apply(T1 arg1, T2 arg2) throws CelEvaluationException; - } -} diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionResolver.java b/runtime/src/main/java/dev/cel/runtime/FunctionResolver.java deleted file mode 100644 index 888780901..000000000 --- a/runtime/src/main/java/dev/cel/runtime/FunctionResolver.java +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime; - -import javax.annotation.concurrent.ThreadSafe; -import dev.cel.common.annotations.Internal; -import java.util.List; -import java.util.Optional; - -/** - * Interface to a resolver for CEL functions based on the function name, overload ids, and - * arguments. - * - *

CEL Library Internals. Do Not Use. - */ -@ThreadSafe -@Internal -interface FunctionResolver { - - /** - * Finds a specific function overload to invoke based on given parameters. - * - * @param functionName the logical name of the function being invoked. - * @param overloadIds A list of function overload ids. The dispatcher selects the unique overload - * from this list with matching arguments. - * @param args The arguments to pass to the function. - * @return an optional value of the resolved overload. - * @throws CelEvaluationException if the overload resolution is ambiguous, - */ - Optional findOverload( - String functionName, List overloadIds, Object[] args) throws CelEvaluationException; -} diff --git a/runtime/src/main/java/dev/cel/runtime/Interpretable.java b/runtime/src/main/java/dev/cel/runtime/Interpretable.java index 21e95921d..ece90cb4b 100644 --- a/runtime/src/main/java/dev/cel/runtime/Interpretable.java +++ b/runtime/src/main/java/dev/cel/runtime/Interpretable.java @@ -45,7 +45,7 @@ Object eval(GlobalResolver resolver, CelEvaluationListener listener) * directly such as recording telemetry or evaluation state in a more granular fashion than a more * general evaluation listener might permit. */ - Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionResolver) + Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException; /** @@ -58,7 +58,7 @@ Object eval(GlobalResolver resolver, FunctionResolver lateBoundFunctionResolver) */ Object eval( GlobalResolver resolver, - FunctionResolver lateBoundFunctionResolver, + CelFunctionResolver lateBoundFunctionResolver, CelEvaluationListener listener) throws CelEvaluationException; } diff --git a/runtime/src/main/java/dev/cel/runtime/Registrar.java b/runtime/src/main/java/dev/cel/runtime/Registrar.java index 3a3afa1d0..12dbe6cb3 100644 --- a/runtime/src/main/java/dev/cel/runtime/Registrar.java +++ b/runtime/src/main/java/dev/cel/runtime/Registrar.java @@ -28,21 +28,21 @@ public interface Registrar { /** Interface describing the general signature of all CEL custom function implementations. */ @Immutable - interface Function extends FunctionOverload {} + interface Function extends CelFunctionOverload {} /** * Helper interface for describing unary functions where the type-parameter is used to improve * compile-time correctness of function bindings. */ @Immutable - interface UnaryFunction extends FunctionOverload.Unary {} + interface UnaryFunction extends CelFunctionOverload.Unary {} /** * Helper interface for describing binary functions where the type parameters are used to improve * compile-time correctness of function bindings. */ @Immutable - interface BinaryFunction extends FunctionOverload.Binary {} + interface BinaryFunction extends CelFunctionOverload.Binary {} /** Adds a unary function to the dispatcher. */ void add(String overloadId, Class argType, UnaryFunction function); diff --git a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java index f71748814..bc7544199 100644 --- a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java @@ -33,7 +33,7 @@ interface ResolvedOverload { List> getParameterTypes(); /** The function definition. */ - FunctionOverload getDefinition(); + CelFunctionOverload getDefinition(); /** * Returns true if the overload's expected argument types match the types of the given arguments. diff --git a/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java b/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java index 9a1a4964a..8422910a1 100644 --- a/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java +++ b/runtime/src/main/java/dev/cel/runtime/UnknownTrackingInterpretable.java @@ -35,7 +35,7 @@ public interface UnknownTrackingInterpretable { */ Object evalTrackingUnknowns( RuntimeUnknownResolver resolver, - Optional lateBoundFunctionResolver, + Optional lateBoundFunctionResolver, Optional listener) throws CelEvaluationException; } diff --git a/runtime/src/main/java/dev/cel/runtime/standard/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/standard/BUILD.bazel index 08177a474..851f375a4 100644 --- a/runtime/src/main/java/dev/cel/runtime/standard/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/standard/BUILD.bazel @@ -1422,7 +1422,8 @@ cel_android_library( java_library( name = "standard_overload", srcs = ["CelStandardOverload.java"], - visibility = ["//visibility:private"], + tags = [ + ], deps = [ "//common:options", "//runtime:function_binding", @@ -1434,7 +1435,8 @@ java_library( cel_android_library( name = "standard_overload_android", srcs = ["CelStandardOverload.java"], - visibility = ["//visibility:private"], + tags = [ + ], deps = [ "//common:options", "//runtime:function_binding_android", diff --git a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardOverload.java b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardOverload.java index 2b4e385db..7c6599737 100644 --- a/runtime/src/main/java/dev/cel/runtime/standard/CelStandardOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/standard/CelStandardOverload.java @@ -25,10 +25,12 @@ * overload. */ @Immutable -interface CelStandardOverload { +public interface CelStandardOverload { + /** Constructs a new {@link CelFunctionBinding} for this CEL standard overload. */ CelFunctionBinding newFunctionBinding(CelOptions celOptions, RuntimeEquality runtimeEquality); + /** TODO: To be removed in the upcoming CL */ @FunctionalInterface @Immutable interface FunctionBindingCreator { diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 5137cad39..81e206fdb 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -65,7 +65,6 @@ java_library( "//runtime:evaluation_exception_builder", "//runtime:evaluation_listener", "//runtime:function_binding", - "//runtime:function_overload_impl", "//runtime:interpretable", "//runtime:interpreter", "//runtime:interpreter_util", @@ -75,12 +74,15 @@ java_library( "//runtime:proto_message_activation_factory", "//runtime:proto_message_runtime_equality", "//runtime:proto_message_runtime_helpers", + "//runtime:resolved_overload", + "//runtime:resolved_overload_internal", "//runtime:runtime_equality", "//runtime:runtime_helpers", "//runtime:standard_functions", "//runtime:type_resolver", "//runtime:unknown_attributes", "//runtime:unknown_options", + "//runtime/standard:standard_overload", "//testing/protos:message_with_enum_cel_java_proto", "//testing/protos:message_with_enum_java_proto", "//testing/protos:multi_file_cel_java_proto", diff --git a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java index 8b0fd7193..2c81b89d4 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java @@ -16,7 +16,6 @@ import static com.google.common.truth.Truth.assertThat; -import com.google.common.collect.ImmutableList; import dev.cel.expr.conformance.proto3.TestAllTypes; import org.junit.Test; import org.junit.runner.RunWith; @@ -29,11 +28,11 @@ public final class CelResolvedOverloadTest { CelResolvedOverload getIncrementIntOverload() { return CelResolvedOverload.of( "increment_int", - ImmutableList.of(Long.class), (args) -> { Long arg = (Long) args[0]; return arg + 1; - }); + }, + Long.class); } @Test @@ -44,14 +43,14 @@ public void canHandle_matchingTypes_returnsTrue() { @Test public void canHandle_nullMessageType_returnsTrue() { CelResolvedOverload overload = - CelResolvedOverload.of("identity", ImmutableList.of(TestAllTypes.class), (args) -> args[0]); + CelResolvedOverload.of("identity", (args) -> args[0], TestAllTypes.class); assertThat(overload.canHandle(new Object[] {null})).isTrue(); } @Test public void canHandle_nullPrimitive_returnsFalse() { CelResolvedOverload overload = - CelResolvedOverload.of("identity", ImmutableList.of(Long.class), (args) -> args[0]); + CelResolvedOverload.of("identity", (args) -> args[0], Long.class); assertThat(overload.canHandle(new Object[] {null})).isFalse(); } diff --git a/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java b/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java index 09da78615..eac68fb63 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java @@ -25,7 +25,7 @@ import dev.cel.compiler.CelCompilerFactory; import dev.cel.runtime.CelStandardFunctions.StandardFunction; import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Arithmetic; -import dev.cel.runtime.CelStandardFunctions.StandardOverload; +import dev.cel.runtime.standard.CelStandardOverload; import org.junit.Test; import org.junit.runner.RunWith; @@ -90,7 +90,7 @@ public void standardFunctions_includeFunctions() { assertThat(celStandardFunctions.getOverloads()) .containsExactlyElementsIn( - ImmutableSet.builder() + ImmutableSet.builder() .addAll(CelStandardFunctions.StandardFunction.ADD.getOverloads()) .addAll(CelStandardFunctions.StandardFunction.SUBTRACT.getOverloads()) .build()); diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java index 556d6945c..03eba51ce 100644 --- a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java @@ -35,13 +35,9 @@ public final class DefaultDispatcherTest { public void setup() { overloads = new HashMap<>(); overloads.put( - "overload_1", - CelResolvedOverload.of( - "overload_1", new Class[] {Long.class}, args -> (Long) args[0] + 1)); + "overload_1", CelResolvedOverload.of("overload_1", args -> (Long) args[0] + 1, Long.class)); overloads.put( - "overload_2", - CelResolvedOverload.of( - "overload_2", new Class[] {Long.class}, args -> (Long) args[0] + 2)); + "overload_2", CelResolvedOverload.of("overload_2", args -> (Long) args[0] + 2, Long.class)); } @Test diff --git a/runtime/standard/BUILD.bazel b/runtime/standard/BUILD.bazel index 5f84c105f..a69f69887 100644 --- a/runtime/standard/BUILD.bazel +++ b/runtime/standard/BUILD.bazel @@ -405,3 +405,13 @@ cel_android_library( name = "uint_android", exports = ["//runtime/src/main/java/dev/cel/runtime/standard:uint_android"], ) + +java_library( + name = "standard_overload", + exports = ["//runtime/src/main/java/dev/cel/runtime/standard:standard_overload"], +) + +cel_android_library( + name = "standard_overload_android", + exports = ["//runtime/src/main/java/dev/cel/runtime/standard:standard_overload_android"], +)