diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index d1f49a5bc..25d86908b 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -21,6 +21,7 @@ java_library( ":cel_test_suite_yaml_parser", ":cel_user_test_template", ":junit_xml_reporter", + "//testing/testrunner:class_loader_utils", "@maven//:com_google_guava_guava", "@maven//:io_github_classgraph_classgraph", "@maven//:junit_junit", @@ -61,13 +62,12 @@ java_library( "//:auto_value", "//bundle:cel", "//bundle:environment", - "//bundle:environment_exception", "//bundle:environment_yaml_parser", "//common:cel_ast", - "//common:cel_descriptors", "//common:compiler_common", "//common:options", "//common:proto_ast", + "//common/internal:default_instance_message_factory", "//policy", "//policy:compiler_factory", "//policy:parser", @@ -75,6 +75,7 @@ java_library( "//policy:validation_exception", "//runtime", "//testing:expr_value_utils", + "//testing/testrunner:proto_descriptor_utils", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", @@ -145,10 +146,8 @@ java_library( tags = [ ], deps = [ - "//common:cel_descriptors", - "//common/internal:cel_descriptor_pools", - "//common/internal:default_message_factory", - "@maven//:com_google_guava_guava", + "//common/internal:default_instance_message_factory", + "//testing/testrunner:proto_descriptor_utils", "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -177,11 +176,9 @@ java_library( "//:java_truth", "//bundle:cel", "//common:cel_ast", - "//common:cel_descriptors", "//runtime", "//testing:expr_value_utils", "@cel_spec//proto/cel/expr:expr_java_proto", - "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_truth_extensions_truth_proto_extension", ], @@ -196,7 +193,6 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", ":registry_utils", - "//common:cel_descriptors", "@cel_spec//proto/cel/expr:expr_java_proto", "@cel_spec//proto/cel/expr/conformance/test:suite_java_proto", "@maven//:com_google_guava_guava", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java index 2bb87bf76..43b8de293 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java @@ -18,12 +18,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.TextFormat; import com.google.protobuf.TextFormat.ParseException; import com.google.protobuf.TypeRegistry; -import dev.cel.common.CelDescriptorUtil; import dev.cel.expr.conformance.test.InputValue; import dev.cel.expr.conformance.test.TestCase; import dev.cel.expr.conformance.test.TestSection; @@ -55,11 +53,8 @@ private TestSuite parseTestSuite(String textProto) throws IOException { TypeRegistry typeRegistry = TypeRegistry.getEmptyTypeRegistry(); ExtensionRegistry extensionRegistry = ExtensionRegistry.getEmptyRegistry(); if (fileDescriptorSetPath != null) { - ImmutableSet fileDescriptors = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( - RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)); - extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptors); - typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptors); + extensionRegistry = RegistryUtils.getExtensionRegistry(); + typeRegistry = RegistryUtils.getTypeRegistry(); } TextFormat.Parser parser = TextFormat.Parser.newBuilder().setTypeRegistry(typeRegistry).build(); TestSuite.Builder builder = TestSuite.newBuilder(); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java b/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java index 8e3424cab..ede7b8c30 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java @@ -20,11 +20,8 @@ import dev.cel.expr.ExprValue; import dev.cel.expr.MapValue; -import com.google.common.collect.ImmutableSet; -import com.google.protobuf.Descriptors.FileDescriptor; import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.CelDescriptorUtil; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelRuntime.Program; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Output; @@ -88,15 +85,10 @@ private static void assertExprValue(ExprValue exprValue, ExprValue expectedExprV throws IOException { String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); if (fileDescriptorSetPath != null) { - ImmutableSet fileDescriptors = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( - RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)); assertThat(exprValue) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) - .unpackingAnyUsing( - RegistryUtils.getTypeRegistry(fileDescriptors), - RegistryUtils.getExtensionRegistry(fileDescriptors)) + .unpackingAnyUsing(RegistryUtils.getTypeRegistry(), RegistryUtils.getExtensionRegistry()) .isEqualTo(expectedExprValue); } else { assertThat(exprValue) diff --git a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java index 0cc5bd850..515f9ba5f 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java @@ -13,91 +13,50 @@ // limitations under the License. package dev.cel.testing.testrunner; -import com.google.common.io.Files; -import com.google.protobuf.DescriptorProtos.FileDescriptorSet; +import static dev.cel.testing.utils.ProtoDescriptorUtils.getAllDescriptorsFromJvm; + import com.google.protobuf.Descriptors.FieldDescriptor; -import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TypeRegistry; -import dev.cel.common.CelDescriptorUtil; -import dev.cel.common.CelDescriptors; -import dev.cel.common.internal.CelDescriptorPool; -import dev.cel.common.internal.DefaultDescriptorPool; -import dev.cel.common.internal.DefaultMessageFactory; -import java.io.File; +import dev.cel.common.internal.DefaultInstanceMessageFactory; import java.io.IOException; import java.util.NoSuchElementException; -import java.util.Set; /** Utility class for creating registries from a file descriptor set. */ public final class RegistryUtils { - private RegistryUtils() {} - - /** Returns the {@link FileDescriptorSet} for the given file descriptor set path. */ - public static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) - throws IOException { - // We can pass an empty extension registry here because extensions are recovered later when - // creating the extension registry in {@link #createExtensionRegistry}. - return FileDescriptorSet.parseFrom( - Files.toByteArray(new File(fileDescriptorSetPath)), ExtensionRegistry.newInstance()); - } - /** Returns the {@link TypeRegistry} for the given file descriptor set. */ - public static TypeRegistry getTypeRegistry(Set fileDescriptors) - throws IOException { - return createTypeRegistry(fileDescriptors); + public static TypeRegistry getTypeRegistry() throws IOException { + return TypeRegistry.newBuilder() + .add(getAllDescriptorsFromJvm().messageTypeDescriptors()) + .build(); } /** Returns the {@link ExtensionRegistry} for the given file descriptor set. */ - public static ExtensionRegistry getExtensionRegistry(Set fileDescriptors) - throws IOException { - return createExtensionRegistry(fileDescriptors); - } - - private static TypeRegistry createTypeRegistry(Set fileDescriptors) { - CelDescriptors allDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors); - return TypeRegistry.newBuilder().add(allDescriptors.messageTypeDescriptors()).build(); - } - - private static ExtensionRegistry createExtensionRegistry(Set fileDescriptors) { + public static ExtensionRegistry getExtensionRegistry() throws IOException { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - CelDescriptors allDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors); - - CelDescriptorPool pool = DefaultDescriptorPool.create(allDescriptors); - - // We need to create a default message factory because there would always be difference in - // reference between the default instance's descriptor and the descriptor in the pool if the - // file descriptor set is created at runtime, therefore - // we need to create a default message factory to get the default instance for each descriptor - // because it falls back to the DynamicMessages. - // - // For more details, see: b/292174333 - DefaultMessageFactory defaultMessageFactory = DefaultMessageFactory.create(pool); - - allDescriptors + getAllDescriptorsFromJvm() .extensionDescriptors() .forEach( (descriptorName, descriptor) -> { if (descriptor.getType().equals(FieldDescriptor.Type.MESSAGE)) { - Message.Builder defaultInstance = - defaultMessageFactory - .newBuilder(descriptor.getMessageType().getFullName()) + Message output = + DefaultInstanceMessageFactory.getInstance() + .getPrototype(descriptor.getMessageType()) .orElseThrow( () -> new NoSuchElementException( "Could not find a default message for: " + descriptor.getFullName())); - extensionRegistry.add(descriptor, defaultInstance.build()); + extensionRegistry.add(descriptor, output); } else { extensionRegistry.add(descriptor); } }); - return extensionRegistry; } + + private RegistryUtils() {} } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java b/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java index c52dbef74..f853b117b 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java @@ -15,6 +15,8 @@ package dev.cel.testing.testrunner; import static com.google.common.collect.MoreCollectors.onlyElement; +import static dev.cel.testing.utils.ClassLoaderUtils.loadClassesWithMethodAnnotation; +import static dev.cel.testing.utils.ClassLoaderUtils.loadSubclasses; import static java.nio.charset.StandardCharsets.UTF_8; import static java.time.ZoneId.systemDefault; @@ -23,9 +25,7 @@ import dev.cel.testing.testrunner.Annotations.TestSuiteSupplier; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; -import io.github.classgraph.ClassGraph; import io.github.classgraph.ClassInfoList; -import io.github.classgraph.ScanResult; import java.io.File; import java.io.IOException; import java.lang.reflect.Method; @@ -166,12 +166,8 @@ private void setStartMillis(long startMillis) { public static void runTests() throws Exception { String testSuitePath = System.getProperty("test_suite_path"); - // Using `enableAllInfo()` to scan all class files upfront. This avoids repeated parsing - // of class files by individual methods, improving efficiency. - ScanResult scanResult = new ClassGraph().enableAllInfo().scan(); - CelTestSuite testSuite; - testSuite = readCustomTestSuite(scanResult); + testSuite = readCustomTestSuite(); if (testSuitePath != null) { if (testSuite != null) { @@ -184,7 +180,7 @@ public static void runTests() throws Exception { throw new IllegalArgumentException("Neither test_suite_path nor TestSuiteSupplier is set."); } - Class testClass = getUserTestClass(scanResult); + Class testClass = getUserTestClass(); String envXmlFile = System.getenv("XML_OUTPUT_FILE"); JUnitXmlReporter testReporter = new JUnitXmlReporter(envXmlFile); TestContext testContext = new TestContext(); @@ -245,9 +241,9 @@ public String describe() { } } - private static CelTestSuite readCustomTestSuite(ScanResult scanResult) throws Exception { + private static CelTestSuite readCustomTestSuite() throws Exception { ClassInfoList classInfoList = - scanResult.getClassesWithMethodAnnotation(CEL_TESTSUITE_ANNOTATION_CLASS.getName()); + loadClassesWithMethodAnnotation(CEL_TESTSUITE_ANNOTATION_CLASS.getName()); if (classInfoList.isEmpty()) { return null; } @@ -278,8 +274,8 @@ private static Method getMethodWithAnnotation(Class clazz) { return testSuiteSupplierMethod; } - private static Class getUserTestClass(ScanResult scanResult) { - ClassInfoList subClassInfoList = scanResult.getSubclasses(CelUserTestTemplate.class); + private static Class getUserTestClass() { + ClassInfoList subClassInfoList = loadSubclasses(CelUserTestTemplate.class); if (subClassInfoList.size() != 1) { throw new IllegalArgumentException( "Expected 1 subclass for CelUserTestTemplate, but got " diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index a261405bc..267368531 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -25,27 +25,22 @@ import dev.cel.expr.Value; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.protobuf.Any; import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FileDescriptor; -import com.google.protobuf.DynamicMessage; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import dev.cel.bundle.Cel; import dev.cel.bundle.CelEnvironment; import dev.cel.bundle.CelEnvironment.ExtensionConfig; -import dev.cel.bundle.CelEnvironmentException; import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.CelValidationException; +import dev.cel.common.internal.DefaultInstanceMessageFactory; import dev.cel.policy.CelPolicy; import dev.cel.policy.CelPolicyCompilerFactory; import dev.cel.policy.CelPolicyParser; @@ -56,10 +51,12 @@ import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Input.Binding; import dev.cel.testing.testrunner.ResultMatcher.ResultMatcherParams; +import dev.cel.testing.utils.ProtoDescriptorUtils; import java.io.File; import java.io.IOException; import java.nio.file.Paths; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.logging.Logger; @@ -135,8 +132,7 @@ private static CelAbstractSyntaxTree readAstFromCheckedExpression( } private static CelTestContext extendCelTestContext( - CelTestContext celTestContext, CelExprFileSource celExprFileSource) - throws CelEnvironmentException, IOException { + CelTestContext celTestContext, CelExprFileSource celExprFileSource) throws Exception { CelOptions celOptions = celTestContext.celOptions(); CelTestContext.Builder celTestContextBuilder = celTestContext.toBuilder().setCel(extendCel(celTestContext.cel(), celOptions)); @@ -150,8 +146,7 @@ private static CelTestContext extendCelTestContext( return celTestContextBuilder.build(); } - private static Cel extendCel(Cel cel, CelOptions celOptions) - throws IOException, CelEnvironmentException { + private static Cel extendCel(Cel cel, CelOptions celOptions) throws Exception { Cel extendedCel = cel; // Add the file descriptor set to the cel object if provided. @@ -162,7 +157,9 @@ private static Cel extendCel(Cel cel, CelOptions celOptions) if (fileDescriptorSetPath != null) { extendedCel = cel.toCelBuilder() - .addFileTypes(RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)) + .addMessageTypes( + ProtoDescriptorUtils.getAllDescriptorsFromJvm().messageTypeDescriptors()) + .setExtensionRegistry(RegistryUtils.getExtensionRegistry()) .build(); } @@ -256,7 +253,7 @@ private static void evaluate( String.format( "Evaluation failed for test case: %s. Error: %s", testCase.name(), e.getMessage()); error = new CelEvaluationException(errorMessage, e); - logger.severe(errorMessage); + logger.severe(e.toString()); } // Perform the assertion on the result of the evaluation. @@ -288,44 +285,48 @@ private static void evaluate( } private static Object getEvaluationResult( - CelTestCase testCase, CelTestContext celTestContext, Program program) throws Exception { + CelTestCase testCase, CelTestContext celTestContext, Program program) + throws CelEvaluationException, IOException, CelValidationException { if (celTestContext.celLateFunctionBindings().isPresent()) { return program.eval( getBindings(testCase, celTestContext), celTestContext.celLateFunctionBindings().get()); } switch (testCase.input().kind()) { case CONTEXT_MESSAGE: - return program.eval( - unpackAny( - testCase.input().contextMessage(), System.getProperty("file_descriptor_set_path"))); + return program.eval(unpackAny(testCase.input().contextMessage())); case CONTEXT_EXPR: return program.eval(getEvaluatedContextExpr(testCase, celTestContext)); case BINDINGS: return program.eval(getBindings(testCase, celTestContext)); case NO_INPUT: - return program.eval(celTestContext.variableBindings()); + ImmutableMap.Builder newBindings = ImmutableMap.builder(); + for (Map.Entry entry : celTestContext.variableBindings().entrySet()) { + if (entry.getValue() instanceof Any) { + newBindings.put(entry.getKey(), unpackAny((Any) entry.getValue())); + } else { + newBindings.put(entry); + } + } + return program.eval(newBindings.buildOrThrow()); } throw new IllegalArgumentException("Unexpected input type: " + testCase.input().kind()); } - // TODO: Remove DynamicMessage parsing once default instance generation is fixed. - // - // Dynamic Message parsing is added here to make the default instance generation code OSS - // compatible. Otherwise, we'd have to depend on AnyUtil which is not available in OSS. - // However, the functionality fails in OSS as the generated DynamicMessage descriptor is different - // from the message descriptor. - private static Message unpackAny(Any any, String fileDescriptorSetPath) throws IOException { - Preconditions.checkNotNull( - fileDescriptorSetPath, - "File descriptor set is required to unpack Any of type: %s.", - any.getTypeUrl()); - ImmutableSet fileDescriptors = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( - RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)); + private static Message unpackAny(Any any) throws IOException { Descriptor descriptor = - RegistryUtils.getTypeRegistry(fileDescriptors).getDescriptorForTypeUrl(any.getTypeUrl()); - return DynamicMessage.parseFrom( - descriptor, any.getValue(), RegistryUtils.getExtensionRegistry(fileDescriptors)); + RegistryUtils.getTypeRegistry().getDescriptorForTypeUrl(any.getTypeUrl()); + return getDefaultInstance(descriptor) + .getParserForType() + .parseFrom(any.getValue(), RegistryUtils.getExtensionRegistry()); + } + + private static Message getDefaultInstance(Descriptor descriptor) throws IOException { + return DefaultInstanceMessageFactory.getInstance() + .getPrototype(descriptor) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find a default message for: " + descriptor.getFullName())); } private static Message getEvaluatedContextExpr( @@ -340,7 +341,7 @@ private static Message getEvaluatedContextExpr( private static ImmutableMap getBindings( CelTestCase testCase, CelTestContext celTestContext) - throws CelEvaluationException, CelValidationException, IOException { + throws IOException, CelEvaluationException, CelValidationException { Cel cel = celTestContext.cel(); ImmutableMap.Builder inputBuilder = ImmutableMap.builder(); for (Map.Entry entry : testCase.input().bindings().entrySet()) { diff --git a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel index 987336268..d6646fd32 100644 --- a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel @@ -5,6 +5,7 @@ package( default_testonly = True, default_visibility = [ "//testing:__pkg__", + "//testing/testrunner:__pkg__", ], ) @@ -15,9 +16,7 @@ java_library( ], deps = [ "//common:cel_descriptors", - "//common/internal:cel_descriptor_pools", "//common/internal:default_instance_message_factory", - "//common/internal:default_message_factory", "//common/types", "//common/types:type_providers", "//runtime:unknown_attributes", @@ -30,3 +29,28 @@ java_library( "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) + +java_library( + name = "class_loader_utils", + srcs = ["ClassLoaderUtils.java"], + tags = [ + ], + deps = [ + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:io_github_classgraph_classgraph", + ], +) + +java_library( + name = "proto_descriptor_utils", + srcs = ["ProtoDescriptorUtils.java"], + tags = [ + ], + deps = [ + "//common:cel_descriptors", + "//testing/testrunner:class_loader_utils", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + ], +) diff --git a/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java b/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java new file mode 100644 index 000000000..652ec85c6 --- /dev/null +++ b/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java @@ -0,0 +1,84 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dev.cel.testing.utils; + +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Descriptors.Descriptor; +import io.github.classgraph.ClassGraph; +import io.github.classgraph.ClassInfo; +import io.github.classgraph.ClassInfoList; +import io.github.classgraph.ScanResult; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.logging.Logger; + +/** Utility class for loading classes using {@link ClassGraph}. */ +public final class ClassLoaderUtils { + + private static final Logger logger = Logger.getLogger(ClassLoaderUtils.class.getName()); + + // Using `enableAllInfo()` to scan all class files upfront. This avoids repeated parsing + // of class files by individual methods, improving efficiency. + private static final Supplier CLASS_SCAN_RESULT = + Suppliers.memoize(() -> new ClassGraph().enableAllInfo().scan()); + + /** + * Loads all descriptor type classes from the JVM. + * + * @return A list of {@link Descriptor} objects representing the descriptors loaded from the JVM. + * @throws IOException If there is an error during the loading process. + */ + public static ImmutableList loadDescriptors() throws IOException { + ClassInfoList classInfoList = CLASS_SCAN_RESULT.get().getAllStandardClasses(); + ImmutableList.Builder compileTimeLoadedDescriptors = ImmutableList.builder(); + + for (ClassInfo classInfo : classInfoList) { + try { + Class classInfoClass = classInfo.loadClass(); + Descriptor descriptor = (Descriptor) classInfoClass.getMethod("getDescriptor").invoke(null); + compileTimeLoadedDescriptors.add(descriptor); + } catch (InvocationTargetException e) { + logger.severe( + "Failed to load descriptor: " + classInfo.getName() + " with error: " + e); + } catch (Exception e) { + // Ignore classes that do not have a getDescriptor method. + } + } + return compileTimeLoadedDescriptors.build(); + } + + /** + * Loads all subclasses of the given class from the JVM. + * + * @param clazz The class to load subclasses for. + * @return A list of {@link ClassInfo} objects representing the subclasses. + */ + public static ClassInfoList loadSubclasses(Class clazz) { + return CLASS_SCAN_RESULT.get().getSubclasses(clazz.getName()); + } + + /** + * Loads all classes with the given method annotation from the JVM. + * + * @param annotationName The name of the annotation to load classes with. + * @return A list of {@link ClassInfo} objects representing the classes with the annotation. + */ + public static ClassInfoList loadClassesWithMethodAnnotation(String annotationName) { + return CLASS_SCAN_RESULT.get().getClassesWithMethodAnnotation(annotationName); + } + + private ClassLoaderUtils() {} +} diff --git a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java index 4a0571e9b..1589f60ae 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java @@ -20,22 +20,17 @@ import dev.cel.expr.Value; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.primitives.UnsignedLong; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.NullValue; import com.google.protobuf.TypeRegistry; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; -import dev.cel.common.internal.CelDescriptorPool; -import dev.cel.common.internal.DefaultDescriptorPool; import dev.cel.common.internal.DefaultInstanceMessageFactory; -import dev.cel.common.internal.DefaultMessageFactory; import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; @@ -98,17 +93,11 @@ public static Object fromValue(Value value) throws IOException { // fixed. String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); if (fileDescriptorSetPath != null) { - return parseAny(object, fileDescriptorSetPath); + return parseAny(object); } Descriptor descriptor = DEFAULT_TYPE_REGISTRY.getDescriptorForTypeUrl(object.getTypeUrl()); - Message prototype = - DefaultInstanceMessageFactory.getInstance() - .getPrototype(descriptor) - .orElseThrow( - () -> - new NoSuchElementException( - "Could not find a default message for: " + descriptor.getFullName())); + Message prototype = getDefaultInstance(descriptor); return prototype .getParserForType() .parseFrom(object.getValue(), DEFAULT_EXTENSION_REGISTRY); @@ -273,45 +262,28 @@ public static Value toValue(Object object, CelType type) throws Exception { String.format("Unexpected result type: %s", object.getClass())); } - private static Message parseAny(Any value, String fileDescriptorSetPath) throws IOException { - ImmutableSet fileDescriptors = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet( - RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath)); - - TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptors); - ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptors); - - CelDescriptors allDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors); - - CelDescriptorPool pool = DefaultDescriptorPool.create(allDescriptors); - - DefaultMessageFactory defaultMessageFactory = DefaultMessageFactory.create(pool); + private static Message parseAny(Any value) throws IOException { + TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(); + ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(); Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(value.getTypeUrl()); - - return unpackAny(value, defaultMessageFactory, descriptor, extensionRegistry); + return unpackAny(value, descriptor, extensionRegistry); } private static Message unpackAny( - Any value, - DefaultMessageFactory defaultMessageFactory, - Descriptor descriptor, - ExtensionRegistry extensionRegistry) - throws IOException { - // Generate a default message for the given descriptor. - Message defaultInstance = - defaultMessageFactory - .newBuilder(descriptor.getFullName()) - .orElseThrow( - () -> - new NoSuchElementException( - "Could not find a default message for: " + value.getTypeUrl())) - .build(); - - // Parse the default message using the value from the Any object. + Any value, Descriptor descriptor, ExtensionRegistry extensionRegistry) throws IOException { + Message defaultInstance = getDefaultInstance(descriptor); return defaultInstance.getParserForType().parseFrom(value.getValue(), extensionRegistry); } + private static Message getDefaultInstance(Descriptor descriptor) throws IOException { + return DefaultInstanceMessageFactory.getInstance() + .getPrototype(descriptor) + .orElseThrow( + () -> + new NoSuchElementException( + "Could not find a default message for: " + descriptor.getFullName())); + } + private static ExtensionRegistry newDefaultExtensionRegistry() { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); diff --git a/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java b/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java new file mode 100644 index 000000000..455b00693 --- /dev/null +++ b/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java @@ -0,0 +1,71 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dev.cel.testing.utils; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static dev.cel.testing.utils.ClassLoaderUtils.loadDescriptors; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import com.google.protobuf.DescriptorProtos.FileDescriptorSet; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; +import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelDescriptors; +import java.io.File; +import java.io.IOException; + +/** Utility class for working with proto descriptors. */ +public final class ProtoDescriptorUtils { + + /** + * Returns all the descriptors from the JVM. + * + * @return The {@link CelDescriptors} object containing all the descriptors. + */ + public static CelDescriptors getAllDescriptorsFromJvm() throws IOException { + ImmutableList compileTimeLoadedDescriptors = loadDescriptors(); + FileDescriptorSet fileDescriptorSet = + getFileDescriptorSet(System.getProperty("file_descriptor_set_path")); + ImmutableSet runtimeFileDescriptorNames = + CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fileDescriptorSet).stream() + .map(FileDescriptor::getFullName) + .collect(toImmutableSet()); + + // Get all the file descriptors from the descriptors which are loaded from the JVM and use the + // ones which match the ones provided by the user in the file descriptor set. + ImmutableList userProvidedFileDescriptors = + CelDescriptorUtil.getFileDescriptorsForDescriptors(compileTimeLoadedDescriptors).stream() + .filter( + fileDescriptor -> runtimeFileDescriptorNames.contains(fileDescriptor.getFullName())) + .collect(toImmutableList()); + + // Get all the descriptors from the file descriptors above which include nested, extension and + // message type descriptors as well. + return CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(userProvidedFileDescriptors); + } + + private static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) + throws IOException { + // We can pass an empty extension registry here because extensions are recovered later when + // creating the extension registry in {@link #createExtensionRegistry}. + return FileDescriptorSet.parseFrom( + Files.toByteArray(new File(fileDescriptorSetPath)), ExtensionRegistry.newInstance()); + } + + private ProtoDescriptorUtils() {} +} diff --git a/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel index cc4bf377c..bee44b4dd 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel @@ -11,14 +11,6 @@ package( ], ) -proto_descriptor_set( - name = "test_all_types_fds", - deps = [ - "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto", - "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", - ], -) - # Since the user test class is triggered by the cel_test_runner rule, we should not add it to the # junit4_test_suite. # This is just a sample test class for the cel_test_runner rule. @@ -153,6 +145,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":user_test", test_suite = "nested_rule/testrunner_tests.yaml", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -167,6 +163,23 @@ cel_java_test( test_suite = "nested_rule/testrunner_unknown_output_tests.yaml", ) +cel_java_test( + name = "custom_variable_binding_test_runner_sample", + cel_expr = "custom_variable_bindings/policy.yaml", + config = "custom_variable_bindings/config.yaml", + proto_deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", + ], + test_data_path = "//testing/src/test/resources/policy", + test_src = ":custom_variable_binding_user_test", + test_suite = "custom_variable_bindings/tests.yaml", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], +) + cel_java_test( name = "test_runner_yaml_sample_with_eval_error", cel_expr = "nested_rule/eval_error_policy.yaml", @@ -191,6 +204,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":context_pb_user_test", test_suite = "context_pb/tests.yaml", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -204,6 +221,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":env_config_user_test", test_suite = "nested_rule/testrunner_tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -216,6 +237,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":user_test", test_suite = "nested_rule/testrunner_tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -226,16 +251,6 @@ cel_java_test( test_suite = "expr_value_output/tests.textproto", ) -java_library( - name = "custom_test_suite", - srcs = ["CustomTestSuite.java"], - deps = [ - "//testing/testrunner:annotations", - "//testing/testrunner:cel_test_suite", - "@maven//:com_google_guava_guava", - ], -) - cel_java_test( name = "test_runner_sample_with_eval_error", cel_expr = "nested_rule/eval_error_policy.yaml", @@ -247,6 +262,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":env_config_user_test", test_suite = "nested_rule/eval_error_tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -260,6 +279,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":context_pb_user_test", test_suite = "context_pb/context_msg_tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -273,6 +296,10 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":context_pb_user_test", test_suite = "context_pb/tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], ) cel_java_test( @@ -295,6 +322,20 @@ cel_java_test( test_data_path = "//testing/src/test/resources/policy", test_src = ":user_test", test_suite = "protoextension_value_as_input/tests.textproto", + deps = [ + "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + ], +) + +java_library( + name = "custom_test_suite", + srcs = ["CustomTestSuite.java"], + deps = [ + "//testing/testrunner:annotations", + "//testing/testrunner:cel_test_suite", + "@maven//:com_google_guava_guava", + ], ) cel_java_test( diff --git a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java index 4d8aa6b09..c85a8065f 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java @@ -18,7 +18,6 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Any; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.util.TestUtil; import dev.cel.bundle.CelFactory; @@ -80,36 +79,6 @@ public void runPolicyTest_outputMismatch_failureAssertion() throws Exception { assertThat(thrown).hasMessageThat().contains("modified: value.bool_value: true -> false"); } - @Test - public void runPolicyTest_fileDescriptorSetPathNotSet_failureInUnpackAny() throws Exception { - CelTestCase simpleOutputTestCase = - CelTestCase.newBuilder() - .setName("fileDescriptorSetPathNotSet_test") - .setDescription("fileDescriptorSetPathNotSet_test_description") - .setInput( - CelTestSuite.CelTestSection.CelTestCase.Input.ofContextMessage( - Any.pack(TestAllTypes.getDefaultInstance()))) - .setOutput(CelTestSuite.CelTestSection.CelTestCase.Output.ofResultValue(true)) - .build(); - CelExprFileSource celExprFileSource = - CelExprFileSource.fromFile( - TestUtil.getSrcDir() - + "/google3/third_party/java/cel/testing/src/test/java/dev/cel/testing/testrunner/resources/empty_policy.yaml"); - - NullPointerException thrown = - assertThrows( - NullPointerException.class, - () -> - TestRunnerLibrary.evaluateTestCase( - simpleOutputTestCase, CelTestContext.newBuilder().build(), celExprFileSource)); - - assertThat(thrown) - .hasMessageThat() - .contains( - "File descriptor set is required to unpack Any of type:" - + " type.googleapis.com/cel.expr.conformance.proto3.TestAllTypes."); - } - @Test public void runPolicyTest_evaluatedContextExprNotProtoMessage_failure() throws Exception { CelTestCase simpleOutputTestCase = diff --git a/testing/src/test/resources/policy/BUILD.bazel b/testing/src/test/resources/policy/BUILD.bazel index 4dff37014..aadae6ba6 100644 --- a/testing/src/test/resources/policy/BUILD.bazel +++ b/testing/src/test/resources/policy/BUILD.bazel @@ -4,7 +4,7 @@ package( ], default_testonly = True, default_visibility = [ - "//testing:__pkg__", + "//testing:__subpackages__", ], ) diff --git a/testing/testrunner/BUILD.bazel b/testing/testrunner/BUILD.bazel index 0382d4c27..d5ed28f10 100644 --- a/testing/testrunner/BUILD.bazel +++ b/testing/testrunner/BUILD.bazel @@ -85,6 +85,18 @@ java_library( exports = ["//testing/src/main/java/dev/cel/testing/testrunner:registry_utils"], ) +java_library( + name = "class_loader_utils", + visibility = ["//:internal"], + exports = ["//testing/src/main/java/dev/cel/testing/utils:class_loader_utils"], +) + +java_library( + name = "proto_descriptor_utils", + visibility = ["//:internal"], + exports = ["//testing/src/main/java/dev/cel/testing/utils:proto_descriptor_utils"], +) + bzl_library( name = "cel_java_test", srcs = ["cel_java_test.bzl"], diff --git a/testing/testrunner/cel_java_test.bzl b/testing/testrunner/cel_java_test.bzl index 0480c664f..b2a19f996 100644 --- a/testing/testrunner/cel_java_test.bzl +++ b/testing/testrunner/cel_java_test.bzl @@ -18,6 +18,7 @@ load("@rules_java//java:java_binary.bzl", "java_binary") load("@rules_proto//proto:defs.bzl", "proto_descriptor_set") load("@rules_shell//shell:sh_test.bzl", "sh_test") load("@bazel_skylib//lib:paths.bzl", "paths") +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") def cel_java_test( name, @@ -93,6 +94,12 @@ def cel_java_test( data.append(descriptor_set_path) jvm_flags.append("-Dfile_descriptor_set_path=$(location {})".format(descriptor_set_path)) + java_proto_library( + name = name + "_proto_descriptor_set_java_proto", + deps = proto_deps, + ) + deps = deps + [":" + name + "_proto_descriptor_set_java_proto"] + jvm_flags.append("-Dis_raw_expr=%s" % is_raw_expr) java_binary(