diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 251d45650..f3b60d4d7 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -227,12 +227,6 @@ cel_android_library( exports = ["//runtime/src/main/java/dev/cel/runtime:resolved_overload_android"], ) -java_library( - name = "resolved_overload_internal", - visibility = ["//:internal"], - exports = ["//runtime/src/main/java/dev/cel/runtime:resolved_overload_internal"], -) - java_library( name = "internal_function_binder", visibility = ["//:internal"], diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 56f9a7bd2..98d6244aa 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -124,11 +124,10 @@ java_library( ":evaluation_exception", ":evaluation_exception_builder", ":function_overload", + ":function_resolver", ":resolved_overload", - ":resolved_overload_internal", "//:auto_value", "//common:error_codes", - "//runtime:function_resolver", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -145,7 +144,6 @@ cel_android_library( ":function_overload_android", ":function_resolver_android", ":resolved_overload_android", - ":resolved_overload_internal_android", "//:auto_value", "//common:error_codes", "@maven//:com_google_code_findbugs_annotations", @@ -285,10 +283,11 @@ java_library( ":evaluation_exception", ":evaluation_exception_builder", ":evaluation_listener", + ":function_resolver", ":interpretable", ":interpreter_util", ":metadata", - ":resolved_overload_internal", + ":resolved_overload", ":runtime_helpers", ":runtime_type_provider", ":type_resolver", @@ -303,7 +302,6 @@ java_library( "//common/types", "//common/types:type_providers", "//common/values:cel_byte_string", - "//runtime:function_resolver", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -328,7 +326,7 @@ cel_android_library( ":interpretable_android", ":interpreter_util_android", ":metadata", - ":resolved_overload_internal_android", + ":resolved_overload_android", ":runtime_helpers_android", ":runtime_type_provider_android", ":type_resolver_android", @@ -498,7 +496,6 @@ java_library( ":function_binding", ":function_resolver", ":resolved_overload", - ":resolved_overload_internal", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -516,7 +513,6 @@ cel_android_library( ":function_binding_android", ":function_resolver_android", ":resolved_overload_android", - ":resolved_overload_internal_android", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -584,8 +580,8 @@ java_library( deps = [ ":evaluation_exception", ":evaluation_listener", + ":function_resolver", "//common/annotations", - "//runtime:function_resolver", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:org_jspecify_jspecify", ], @@ -752,7 +748,7 @@ java_library( ], deps = [ ":evaluation_exception", - ":resolved_overload_internal", + ":resolved_overload", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -763,7 +759,7 @@ cel_android_library( srcs = ["CelFunctionResolver.java"], deps = [ ":evaluation_exception", - ":resolved_overload_internal_android", + ":resolved_overload_android", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -793,33 +789,6 @@ cel_android_library( ], ) -java_library( - name = "resolved_overload_internal", - srcs = ["ResolvedOverload.java"], - tags = [ - ], - deps = [ - ":function_overload", - ":unknown_attributes", - "@maven//:com_google_code_findbugs_annotations", - "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_protobuf_protobuf_java", - ], -) - -cel_android_library( - name = "resolved_overload_internal_android", - srcs = ["ResolvedOverload.java"], - visibility = ["//visibility:private"], - deps = [ - ":function_overload_android", - ":unknown_attributes_android", - "@maven//:com_google_code_findbugs_annotations", - "@maven//:com_google_errorprone_error_prone_annotations", - "@maven_android//:com_google_protobuf_protobuf_javalite", - ], -) - java_library( name = "runtime", srcs = RUNTIME_SOURCES, @@ -837,6 +806,7 @@ java_library( ":function_resolver", ":interpretable", ":interpreter", + ":program", ":proto_message_activation_factory", ":proto_message_runtime_equality", ":runtime_equality", @@ -856,7 +826,6 @@ java_library( "//common/types:cel_types", "//common/values:cel_value_provider", "//common/values:proto_message_value_provider", - "//runtime:program", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -873,12 +842,12 @@ java_library( deps = [ ":evaluation_exception", ":function_binding", + ":program", "//:auto_value", "//common:cel_ast", "//common:options", "//common/annotations", "//common/values:cel_value_provider", - "//runtime:program", "//runtime/standard:standard_function", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", @@ -898,6 +867,7 @@ java_library( ":interpreter", ":lite_program_impl", ":lite_runtime", + ":program", ":runtime_equality", ":runtime_helpers", ":type_resolver", @@ -905,7 +875,6 @@ java_library( "//common:cel_ast", "//common:options", "//common/values:cel_value_provider", - "//runtime:program", "//runtime/standard:standard_function", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_guava_guava", @@ -1201,10 +1170,11 @@ java_library( ], deps = [ ":function_overload", - ":resolved_overload_internal", + ":unknown_attributes", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -1215,10 +1185,11 @@ cel_android_library( ], deps = [ ":function_overload_android", - ":resolved_overload_internal_android", + ":unknown_attributes_android", "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", + "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) @@ -1253,9 +1224,9 @@ java_library( tags = [ ], deps = [ + ":function_binding", ":function_overload", "//common/annotations", - "//runtime:function_binding", "@maven//:com_google_guava_guava", ], ) @@ -1266,9 +1237,9 @@ cel_android_library( tags = [ ], deps = [ + ":function_binding_android", ":function_overload_android", "//common/annotations", - "//runtime:function_binding_android", "@maven_android//:com_google_guava_guava", ], ) diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java index 8f00eebb3..1c5491c14 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionResolver.java @@ -35,6 +35,6 @@ public interface CelFunctionResolver { * @return an optional value of the resolved overload. * @throws CelEvaluationException if the overload resolution is ambiguous, */ - Optional findOverloadMatchingArgs( + Optional findOverloadMatchingArgs( String functionName, List overloadIds, Object[] args) throws CelEvaluationException; } diff --git a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java index 75be39e92..f533766b2 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java +++ b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java @@ -29,14 +29,14 @@ @Immutable public final class CelLateFunctionBindings implements CelFunctionResolver { - private final ImmutableMap functions; + private final ImmutableMap functions; - private CelLateFunctionBindings(ImmutableMap functions) { + private CelLateFunctionBindings(ImmutableMap functions) { this.functions = functions; } @Override - public Optional findOverloadMatchingArgs( + public Optional findOverloadMatchingArgs( String functionName, List overloadIds, Object[] args) throws CelEvaluationException { return DefaultDispatcher.findOverloadMatchingArgs(functionName, overloadIds, functions, args); } @@ -54,7 +54,7 @@ public static CelLateFunctionBindings from(List functions) { CelLateFunctionBindings::createResolvedOverload))); } - private static ResolvedOverload createResolvedOverload(CelFunctionBinding binding) { + private static CelResolvedOverload createResolvedOverload(CelFunctionBinding binding) { return CelResolvedOverload.of( binding.getOverloadId(), (args) -> binding.getDefinition().apply(args), diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index 26f86a459..74d9e269e 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -17,7 +17,9 @@ import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import com.google.protobuf.MessageLite; import java.util.List; +import java.util.Map; /** * Representation of a function overload which has been resolved to a specific set of argument types @@ -25,14 +27,12 @@ */ @AutoValue @Immutable -abstract class CelResolvedOverload implements ResolvedOverload { +abstract class CelResolvedOverload { /** The overload id of the function. */ - @Override public abstract String getOverloadId(); /** The types of the function parameters. */ - @Override public abstract ImmutableList> getParameterTypes(); /* Denotes whether an overload is strict. @@ -47,11 +47,9 @@ abstract class CelResolvedOverload implements ResolvedOverload { * *

In a vast majority of cases, this should be set to true. */ - @Override public abstract boolean isStrict(); /** The function definition. */ - @Override public abstract CelFunctionOverload getDefinition(); /** @@ -76,4 +74,41 @@ public static CelResolvedOverload of( return new AutoValue_CelResolvedOverload( overloadId, ImmutableList.copyOf(parameterTypes), isStrict, definition); } + + /** + * Returns true if the overload's expected argument types match the types of the given arguments. + */ + boolean canHandle(Object[] arguments) { + ImmutableList> parameterTypes = getParameterTypes(); + if (parameterTypes.size() != arguments.length) { + return false; + } + for (int i = 0; i < parameterTypes.size(); i++) { + Class paramType = parameterTypes.get(i); + Object arg = arguments[i]; + if (arg == null) { + // null can be assigned to messages, maps, and to objects. + // TODO: Remove null special casing + if (paramType != Object.class + && !MessageLite.class.isAssignableFrom(paramType) + && !Map.class.isAssignableFrom(paramType)) { + return false; + } + continue; + } + + if (arg instanceof Exception || arg instanceof CelUnknownSet) { + // Only non-strict functions can accept errors/unknowns as arguments to a function + if (!isStrict()) { + // Skip assignability check below, but continue to validate remaining args + continue; + } + } + + if (!paramType.isAssignableFrom(arg.getClass())) { + return false; + } + } + return true; + } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 35340abc5..adc99deb7 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -31,26 +31,26 @@ @Immutable final class DefaultDispatcher implements CelFunctionResolver { - private final ImmutableMap overloads; + private final ImmutableMap overloads; @Override - public Optional findOverloadMatchingArgs( + public Optional findOverloadMatchingArgs( String functionName, List overloadIds, Object[] args) throws CelEvaluationException { return findOverloadMatchingArgs(functionName, overloadIds, overloads, args); } /** Finds the overload that matches the given function name, overload IDs, and arguments. */ - static Optional findOverloadMatchingArgs( + static Optional findOverloadMatchingArgs( String functionName, List overloadIds, - Map overloads, + Map overloads, Object[] args) throws CelEvaluationException { int matchingOverloadCount = 0; - ResolvedOverload match = null; + CelResolvedOverload match = null; List candidates = null; for (String overloadId : overloadIds) { - ResolvedOverload overload = overloads.get(overloadId); + CelResolvedOverload overload = overloads.get(overloadId); // If the overload is null, it means that the function was not registered; however, it is // possible that the overload refers to a late-bound function. if (overload != null && overload.canHandle(args)) { @@ -85,9 +85,9 @@ static Optional findOverloadMatchingArgs( * * @throws IllegalStateException if there are multiple overloads that are marked non-strict. */ - Optional findSingleNonStrictOverload(List overloadIds) { + Optional findSingleNonStrictOverload(List overloadIds) { for (String overloadId : overloadIds) { - ResolvedOverload overload = overloads.get(overloadId); + CelResolvedOverload overload = overloads.get(overloadId); if (overload != null && !overload.isStrict()) { if (overloadIds.size() > 1) { throw new IllegalStateException( @@ -108,9 +108,9 @@ static Builder newBuilder() { @AutoBuilder(ofClass = DefaultDispatcher.class) abstract static class Builder { - abstract ImmutableMap overloads(); + abstract ImmutableMap overloads(); - abstract ImmutableMap.Builder overloadsBuilder(); + abstract ImmutableMap.Builder overloadsBuilder(); @CanIgnoreReturnValue Builder addOverload( @@ -130,7 +130,7 @@ Builder addOverload( abstract DefaultDispatcher build(); } - DefaultDispatcher(ImmutableMap overloads) { + DefaultDispatcher(ImmutableMap overloads) { this.overloads = overloads; } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 931194d6e..546290a4e 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -494,7 +494,7 @@ private IntermediateResult dispatchCall( Object[] argArray = Arrays.stream(argResults).map(IntermediateResult::value).toArray(); ImmutableList overloadIds = reference.overloadIds(); - ResolvedOverload overload = + CelResolvedOverload overload = findOverloadOrThrow(frame, expr, callExpr.function(), overloadIds, argArray); try { Object dispatchResult = overload.getDefinition().apply(argArray); @@ -517,7 +517,7 @@ private IntermediateResult dispatchCall( } } - private ResolvedOverload findOverloadOrThrow( + private CelResolvedOverload findOverloadOrThrow( ExecutionFrame frame, CelExpr expr, String functionName, @@ -525,7 +525,7 @@ private ResolvedOverload findOverloadOrThrow( Object[] args) throws CelEvaluationException { try { - Optional funcImpl = + Optional funcImpl = dispatcher.findOverloadMatchingArgs(functionName, overloadIds, args); if (funcImpl.isPresent()) { return funcImpl.get(); @@ -1132,7 +1132,7 @@ private RuntimeUnknownResolver getResolver() { return currentResolver; } - private Optional findOverload( + private Optional findOverload( String function, List overloadIds, Object[] args) throws CelEvaluationException { if (lateBoundFunctionResolver.isPresent()) { return lateBoundFunctionResolver diff --git a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java deleted file mode 100644 index 5d632e695..000000000 --- a/runtime/src/main/java/dev/cel/runtime/ResolvedOverload.java +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime; - -import com.google.errorprone.annotations.Immutable; -import com.google.protobuf.MessageLite; -import java.util.List; -import java.util.Map; - -/** - * Representation of a function overload which has been resolved to a specific set of argument types - * and a function definition. - */ -@Immutable -interface ResolvedOverload { - - /** The overload id of the function. */ - String getOverloadId(); - - /** The types of the function parameters. */ - List> getParameterTypes(); - - /** The function definition. */ - CelFunctionOverload getDefinition(); - - /** - * Denotes whether an overload is strict. - * - *

A strict function will not be invoked if any of its arguments are an error or unknown value. - * The runtime automatically propagates the error or unknown instead. - * - *

A non-strict function will be invoked even if its arguments contain errors or unknowns. The - * function's implementation is then responsible for handling these values. This is primarily used - * for short-circuiting logical operators (e.g., `||`, `&&`) and comprehension's - * internal @not_strictly_false function. - * - *

In a vast majority of cases, a function should be kept strict. - */ - boolean isStrict(); - - /** - * Returns true if the overload's expected argument types match the types of the given arguments. - */ - default boolean canHandle(Object[] arguments) { - List> parameterTypes = getParameterTypes(); - if (parameterTypes.size() != arguments.length) { - return false; - } - for (int i = 0; i < parameterTypes.size(); i++) { - Class paramType = parameterTypes.get(i); - Object arg = arguments[i]; - if (arg == null) { - // null can be assigned to messages, maps, and to objects. - if (paramType != Object.class - && !MessageLite.class.isAssignableFrom(paramType) - && !Map.class.isAssignableFrom(paramType)) { - return false; - } - continue; - } - - if (arg instanceof Exception || arg instanceof CelUnknownSet) { - // Only non-strict functions can accept errors/unknowns as arguments to a function - if (!isStrict()) { - // Skip assignability check below, but continue to validate remaining args - continue; - } - } - - if (!paramType.isAssignableFrom(arg.getClass())) { - return false; - } - } - return true; - } -} diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 092f61f04..1541e9c84 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -75,7 +75,6 @@ java_library( "//runtime:proto_message_runtime_equality", "//runtime:proto_message_runtime_helpers", "//runtime:resolved_overload", - "//runtime:resolved_overload_internal", "//runtime:runtime_equality", "//runtime:runtime_helpers", "//runtime:standard_functions", diff --git a/runtime/src/test/java/dev/cel/runtime/CelLateFunctionBindingsTest.java b/runtime/src/test/java/dev/cel/runtime/CelLateFunctionBindingsTest.java index 3f9e1e105..395fb0897 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLateFunctionBindingsTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLateFunctionBindingsTest.java @@ -37,7 +37,7 @@ public void findOverload_singleMatchingFunction_isPresent() throws Exception { CelFunctionBinding.from("increment_int", Long.class, (arg) -> arg + 1), CelFunctionBinding.from( "increment_uint", UnsignedLong.class, (arg) -> arg.plus(UnsignedLong.ONE))); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "increment", ImmutableList.of("increment_int", "increment_uint"), new Object[] {1L}); assertThat(overload).isPresent(); @@ -53,7 +53,7 @@ public void findOverload_noMatchingFunctionSameArgCount_isEmpty() throws Excepti CelFunctionBinding.from("increment_int", Long.class, (arg) -> arg + 1), CelFunctionBinding.from( "increment_uint", UnsignedLong.class, (arg) -> arg.plus(UnsignedLong.ONE))); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "increment", ImmutableList.of("increment_int", "increment_uint"), new Object[] {1.0}); assertThat(overload).isEmpty(); @@ -66,7 +66,7 @@ public void findOverload_noMatchingFunctionDifferentArgCount_isEmpty() throws Ex CelFunctionBinding.from("increment_int", Long.class, (arg) -> arg + 1), CelFunctionBinding.from( "increment_uint", UnsignedLong.class, (arg) -> arg.plus(UnsignedLong.ONE))); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "increment", ImmutableList.of("increment_int", "increment_uint"), @@ -88,7 +88,7 @@ public void findOverload_badInput_throwsException() throws Exception { } return arg.plus(UnsignedLong.ONE); })); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "increment", ImmutableList.of("increment_uint"), new Object[] {UnsignedLong.MAX_VALUE}); assertThat(overload).isPresent(); @@ -123,7 +123,7 @@ public void findOverload_nullPrimitiveArg_isEmpty() throws Exception { CelLateFunctionBindings bindings = CelLateFunctionBindings.from( CelFunctionBinding.from("identity_int", Long.class, (arg) -> arg)); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "identity", ImmutableList.of("identity_int"), new Object[] {null}); assertThat(overload).isEmpty(); @@ -134,7 +134,7 @@ public void findOverload_nullMessageArg_returnsOverload() throws Exception { CelLateFunctionBindings bindings = CelLateFunctionBindings.from( CelFunctionBinding.from("identity_msg", TestAllTypes.class, (arg) -> arg)); - Optional overload = + Optional overload = bindings.findOverloadMatchingArgs( "identity", ImmutableList.of("identity_msg"), new Object[] {null}); assertThat(overload).isPresent();