diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index 170912d95..e14cfc26c 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -142,6 +142,8 @@ java_library( java_library( name = "registry_utils", srcs = ["RegistryUtils.java"], + tags = [ + ], deps = [ "//common:cel_descriptors", "//common/internal:cel_descriptor_pools", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java index 34bb01e52..0cc5bd850 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java @@ -31,23 +31,27 @@ import java.util.Set; /** Utility class for creating registries from a file descriptor set. */ -final class RegistryUtils { +public final class RegistryUtils { private RegistryUtils() {} /** Returns the {@link FileDescriptorSet} for the given file descriptor set path. */ - static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) throws IOException { + public static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) + throws IOException { + // We can pass an empty extension registry here because extensions are recovered later when + // creating the extension registry in {@link #createExtensionRegistry}. return FileDescriptorSet.parseFrom( Files.toByteArray(new File(fileDescriptorSetPath)), ExtensionRegistry.newInstance()); } /** Returns the {@link TypeRegistry} for the given file descriptor set. */ - static TypeRegistry getTypeRegistry(Set fileDescriptors) throws IOException { + public static TypeRegistry getTypeRegistry(Set fileDescriptors) + throws IOException { return createTypeRegistry(fileDescriptors); } /** Returns the {@link ExtensionRegistry} for the given file descriptor set. */ - static ExtensionRegistry getExtensionRegistry(Set fileDescriptors) + public static ExtensionRegistry getExtensionRegistry(Set fileDescriptors) throws IOException { return createExtensionRegistry(fileDescriptors); } diff --git a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel index 6d4e41fb0..603f466aa 100644 --- a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel @@ -15,9 +15,12 @@ java_library( ], deps = [ "//common:cel_descriptors", + "//common/internal:cel_descriptor_pools", "//common/internal:default_instance_message_factory", + "//common/internal:default_message_factory", "//common/types", "//common/types:type_providers", + "//testing/testrunner:registry_utils", "@cel_spec//proto/cel/expr:expr_java_proto", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java index 8db0d68d1..9a976de75 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java @@ -19,23 +19,29 @@ import dev.cel.expr.Value; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.primitives.UnsignedLong; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.NullValue; import com.google.protobuf.TypeRegistry; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; +import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.DefaultDescriptorPool; import dev.cel.common.internal.DefaultInstanceMessageFactory; +import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeType; +import dev.cel.testing.testrunner.RegistryUtils; import java.io.IOException; import java.util.List; import java.util.Map; @@ -79,6 +85,19 @@ public static Object fromValue(Value value) throws IOException { case OBJECT_VALUE: { Any object = value.getObjectValue(); + + // If the file_descriptor_set_path is set, use the provided file descriptor set created at + // runtime after deserializing the file_descriptor_set file. + // Because of the above reason, DefaultInstanceMessageFactory cannot be used since it + // would always result in a descriptor reference mismatch. Instead, we use + // DefaultMessageFactory to create a DynamicMessage and parse it with `.getValue()`. + // + // TODO: Remove DynamicMessage parsing once default instance generation is + // fixed. + String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); + if (fileDescriptorSetPath != null) { + return parseAny(object, fileDescriptorSetPath); + } Descriptor descriptor = DEFAULT_TYPE_REGISTRY.getDescriptorForTypeUrl(object.getTypeUrl()); Message prototype = @@ -245,6 +264,45 @@ public static Value toValue(Object object, CelType type) throws Exception { String.format("Unexpected result type: %s", object.getClass())); } + private static Message parseAny(Any value, String fileDescriptorSetPath) throws IOException { + ImmutableSet fileDescriptors = + CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( + RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)); + + TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptors); + ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptors); + + CelDescriptors allDescriptors = + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors); + + CelDescriptorPool pool = DefaultDescriptorPool.create(allDescriptors); + + DefaultMessageFactory defaultMessageFactory = DefaultMessageFactory.create(pool); + Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(value.getTypeUrl()); + + return unpackAny(value, defaultMessageFactory, descriptor, extensionRegistry); + } + + private static Message unpackAny( + Any value, + DefaultMessageFactory defaultMessageFactory, + Descriptor descriptor, + ExtensionRegistry extensionRegistry) + throws IOException { + // Generate a default message for the given descriptor. + Message defaultInstance = + defaultMessageFactory + .newBuilder(descriptor.getFullName()) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find a default message for: " + value.getTypeUrl())) + .build(); + + // Parse the default message using the value from the Any object. + return defaultInstance.getParserForType().parseFrom(value.getValue(), extensionRegistry); + } + private static ExtensionRegistry newDefaultExtensionRegistry() { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); diff --git a/testing/testrunner/BUILD.bazel b/testing/testrunner/BUILD.bazel index a91da900e..3d26c10a2 100644 --- a/testing/testrunner/BUILD.bazel +++ b/testing/testrunner/BUILD.bazel @@ -79,6 +79,12 @@ exports_files( srcs = ["run_testrunner_binary.sh"], ) +java_library( + name = "registry_utils", + visibility = ["//:internal"], + exports = ["//testing/src/main/java/dev/cel/testing/testrunner:registry_utils"], +) + bzl_library( name = "cel_java_test", srcs = ["cel_java_test.bzl"],