Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ java_library(
":cel_test_suite_yaml_parser",
":cel_user_test_template",
":junit_xml_reporter",
"//testing/testrunner:class_loader_utils",
"@maven//:com_google_guava_guava",
"@maven//:io_github_classgraph_classgraph",
"@maven//:junit_junit",
Expand Down Expand Up @@ -61,20 +62,20 @@ java_library(
"//:auto_value",
"//bundle:cel",
"//bundle:environment",
"//bundle:environment_exception",
"//bundle:environment_yaml_parser",
"//common:cel_ast",
"//common:cel_descriptors",
"//common:compiler_common",
"//common:options",
"//common:proto_ast",
"//common/internal:default_instance_message_factory",
"//policy",
"//policy:compiler_factory",
"//policy:parser",
"//policy:parser_factory",
"//policy:validation_exception",
"//runtime",
"//testing:expr_value_utils",
"//testing/testrunner:proto_descriptor_utils",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
Expand Down Expand Up @@ -145,10 +146,8 @@ java_library(
tags = [
],
deps = [
"//common:cel_descriptors",
"//common/internal:cel_descriptor_pools",
"//common/internal:default_message_factory",
"@maven//:com_google_guava_guava",
"//common/internal:default_instance_message_factory",
"//testing/testrunner:proto_descriptor_utils",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down Expand Up @@ -177,11 +176,9 @@ java_library(
"//:java_truth",
"//bundle:cel",
"//common:cel_ast",
"//common:cel_descriptors",
"//runtime",
"//testing:expr_value_utils",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_google_truth_extensions_truth_proto_extension",
],
Expand All @@ -196,7 +193,6 @@ java_library(
":cel_test_suite",
":cel_test_suite_exception",
":registry_utils",
"//common:cel_descriptors",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@cel_spec//proto/cel/expr/conformance/test:suite_java_proto",
"@maven//:com_google_guava_guava",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.TextFormat;
import com.google.protobuf.TextFormat.ParseException;
import com.google.protobuf.TypeRegistry;
import dev.cel.common.CelDescriptorUtil;
import dev.cel.expr.conformance.test.InputValue;
import dev.cel.expr.conformance.test.TestCase;
import dev.cel.expr.conformance.test.TestSection;
Expand Down Expand Up @@ -55,11 +53,8 @@ private TestSuite parseTestSuite(String textProto) throws IOException {
TypeRegistry typeRegistry = TypeRegistry.getEmptyTypeRegistry();
ExtensionRegistry extensionRegistry = ExtensionRegistry.getEmptyRegistry();
if (fileDescriptorSetPath != null) {
ImmutableSet<FileDescriptor> fileDescriptors =
CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(
RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath));
extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptors);
typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptors);
extensionRegistry = RegistryUtils.getExtensionRegistry();
typeRegistry = RegistryUtils.getTypeRegistry();
}
TextFormat.Parser parser = TextFormat.Parser.newBuilder().setTypeRegistry(typeRegistry).build();
TestSuite.Builder builder = TestSuite.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@

import dev.cel.expr.ExprValue;
import dev.cel.expr.MapValue;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.Descriptors.FileDescriptor;
import dev.cel.bundle.Cel;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelDescriptorUtil;
import dev.cel.runtime.CelEvaluationException;
import dev.cel.runtime.CelRuntime.Program;
import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Output;
Expand Down Expand Up @@ -88,15 +85,10 @@ private static void assertExprValue(ExprValue exprValue, ExprValue expectedExprV
throws IOException {
String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path");
if (fileDescriptorSetPath != null) {
ImmutableSet<FileDescriptor> fileDescriptors =
CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(
RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath));
assertThat(exprValue)
.ignoringRepeatedFieldOrderOfFieldDescriptors(
MapValue.getDescriptor().findFieldByName("entries"))
.unpackingAnyUsing(
RegistryUtils.getTypeRegistry(fileDescriptors),
RegistryUtils.getExtensionRegistry(fileDescriptors))
.unpackingAnyUsing(RegistryUtils.getTypeRegistry(), RegistryUtils.getExtensionRegistry())
.isEqualTo(expectedExprValue);
} else {
assertThat(exprValue)
Expand Down
71 changes: 15 additions & 56 deletions testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,91 +13,50 @@
// limitations under the License.
package dev.cel.testing.testrunner;

import com.google.common.io.Files;
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import static dev.cel.testing.utils.ProtoDescriptorUtils.getAllDescriptorsFromJvm;

import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.TypeRegistry;
import dev.cel.common.CelDescriptorUtil;
import dev.cel.common.CelDescriptors;
import dev.cel.common.internal.CelDescriptorPool;
import dev.cel.common.internal.DefaultDescriptorPool;
import dev.cel.common.internal.DefaultMessageFactory;
import java.io.File;
import dev.cel.common.internal.DefaultInstanceMessageFactory;
import java.io.IOException;
import java.util.NoSuchElementException;
import java.util.Set;

/** Utility class for creating registries from a file descriptor set. */
public final class RegistryUtils {

private RegistryUtils() {}

/** Returns the {@link FileDescriptorSet} for the given file descriptor set path. */
public static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath)
throws IOException {
// We can pass an empty extension registry here because extensions are recovered later when
// creating the extension registry in {@link #createExtensionRegistry}.
return FileDescriptorSet.parseFrom(
Files.toByteArray(new File(fileDescriptorSetPath)), ExtensionRegistry.newInstance());
}

/** Returns the {@link TypeRegistry} for the given file descriptor set. */
public static TypeRegistry getTypeRegistry(Set<FileDescriptor> fileDescriptors)
throws IOException {
return createTypeRegistry(fileDescriptors);
public static TypeRegistry getTypeRegistry() throws IOException {
return TypeRegistry.newBuilder()
.add(getAllDescriptorsFromJvm().messageTypeDescriptors())
.build();
}

/** Returns the {@link ExtensionRegistry} for the given file descriptor set. */
public static ExtensionRegistry getExtensionRegistry(Set<FileDescriptor> fileDescriptors)
throws IOException {
return createExtensionRegistry(fileDescriptors);
}

private static TypeRegistry createTypeRegistry(Set<FileDescriptor> fileDescriptors) {
CelDescriptors allDescriptors =
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors);
return TypeRegistry.newBuilder().add(allDescriptors.messageTypeDescriptors()).build();
}

private static ExtensionRegistry createExtensionRegistry(Set<FileDescriptor> fileDescriptors) {
public static ExtensionRegistry getExtensionRegistry() throws IOException {
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();

CelDescriptors allDescriptors =
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors);

CelDescriptorPool pool = DefaultDescriptorPool.create(allDescriptors);

// We need to create a default message factory because there would always be difference in
// reference between the default instance's descriptor and the descriptor in the pool if the
// file descriptor set is created at runtime, therefore
// we need to create a default message factory to get the default instance for each descriptor
// because it falls back to the DynamicMessages.
//
// For more details, see: b/292174333
DefaultMessageFactory defaultMessageFactory = DefaultMessageFactory.create(pool);

allDescriptors
getAllDescriptorsFromJvm()
.extensionDescriptors()
.forEach(
(descriptorName, descriptor) -> {
if (descriptor.getType().equals(FieldDescriptor.Type.MESSAGE)) {
Message.Builder defaultInstance =
defaultMessageFactory
.newBuilder(descriptor.getMessageType().getFullName())
Message output =
DefaultInstanceMessageFactory.getInstance()
.getPrototype(descriptor.getMessageType())
.orElseThrow(
() ->
new NoSuchElementException(
"Could not find a default message for: "
+ descriptor.getFullName()));
extensionRegistry.add(descriptor, defaultInstance.build());
extensionRegistry.add(descriptor, output);
} else {
extensionRegistry.add(descriptor);
}
});

return extensionRegistry;
}

private RegistryUtils() {}
}
20 changes: 8 additions & 12 deletions testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package dev.cel.testing.testrunner;

import static com.google.common.collect.MoreCollectors.onlyElement;
import static dev.cel.testing.utils.ClassLoaderUtils.loadClassesWithMethodAnnotation;
import static dev.cel.testing.utils.ClassLoaderUtils.loadSubclasses;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.ZoneId.systemDefault;

Expand All @@ -23,9 +25,7 @@
import dev.cel.testing.testrunner.Annotations.TestSuiteSupplier;
import dev.cel.testing.testrunner.CelTestSuite.CelTestSection;
import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase;
import io.github.classgraph.ClassGraph;
import io.github.classgraph.ClassInfoList;
import io.github.classgraph.ScanResult;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -166,12 +166,8 @@ private void setStartMillis(long startMillis) {
public static void runTests() throws Exception {
String testSuitePath = System.getProperty("test_suite_path");

// Using `enableAllInfo()` to scan all class files upfront. This avoids repeated parsing
// of class files by individual methods, improving efficiency.
ScanResult scanResult = new ClassGraph().enableAllInfo().scan();

CelTestSuite testSuite;
testSuite = readCustomTestSuite(scanResult);
testSuite = readCustomTestSuite();

if (testSuitePath != null) {
if (testSuite != null) {
Expand All @@ -184,7 +180,7 @@ public static void runTests() throws Exception {
throw new IllegalArgumentException("Neither test_suite_path nor TestSuiteSupplier is set.");
}

Class<?> testClass = getUserTestClass(scanResult);
Class<?> testClass = getUserTestClass();
String envXmlFile = System.getenv("XML_OUTPUT_FILE");
JUnitXmlReporter testReporter = new JUnitXmlReporter(envXmlFile);
TestContext testContext = new TestContext();
Expand Down Expand Up @@ -245,9 +241,9 @@ public String describe() {
}
}

private static CelTestSuite readCustomTestSuite(ScanResult scanResult) throws Exception {
private static CelTestSuite readCustomTestSuite() throws Exception {
ClassInfoList classInfoList =
scanResult.getClassesWithMethodAnnotation(CEL_TESTSUITE_ANNOTATION_CLASS.getName());
loadClassesWithMethodAnnotation(CEL_TESTSUITE_ANNOTATION_CLASS.getName());
if (classInfoList.isEmpty()) {
return null;
}
Expand Down Expand Up @@ -278,8 +274,8 @@ private static Method getMethodWithAnnotation(Class<?> clazz) {
return testSuiteSupplierMethod;
}

private static Class<?> getUserTestClass(ScanResult scanResult) {
ClassInfoList subClassInfoList = scanResult.getSubclasses(CelUserTestTemplate.class);
private static Class<?> getUserTestClass() {
ClassInfoList subClassInfoList = loadSubclasses(CelUserTestTemplate.class);
if (subClassInfoList.size() != 1) {
throw new IllegalArgumentException(
"Expected 1 subclass for CelUserTestTemplate, but got "
Expand Down
Loading
Loading