Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ java_library(
exports = ["//common/src/main/java/dev/cel/common/internal:default_instance_message_factory"],
)

java_library(
name = "default_instance_message_lite_factory",
exports = ["//common/src/main/java/dev/cel/common/internal:default_instance_message_lite_factory"],
)

java_library(
name = "well_known_proto",
exports = ["//common/src/main/java/dev/cel/common/internal:well_known_proto"],
Expand Down
35 changes: 33 additions & 2 deletions common/src/main/java/dev/cel/common/internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,25 @@ java_library(
tags = [
],
deps = [
":default_instance_message_lite_factory",
":proto_java_qualified_names",
"//common/annotations",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
],
)

# keep sorted
java_library(
name = "default_instance_message_lite_factory",
srcs = ["DefaultInstanceMessageLiteFactory.java"],
tags = [
],
deps = [
":reflection_util",
"//common/annotations",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
],
)

java_library(
name = "dynamic_proto",
Expand Down Expand Up @@ -274,3 +286,22 @@ java_library(
"@maven//:com_google_re2j_re2j",
],
)

java_library(
name = "proto_java_qualified_names",
srcs = ["ProtoJavaQualifiedNames.java"],
tags = [
],
deps = [
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
],
)

java_library(
name = "reflection_util",
srcs = ["ReflectionUtil.java"],
deps = [
"//common/annotations",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,11 @@

package dev.cel.common.internal;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.io.Files;
import com.google.protobuf.DescriptorProtos.FileOptions;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.ServiceDescriptor;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import dev.cel.common.annotations.Internal;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayDeque;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
* Singleton factory for creating default messages from a protobuf descriptor.
Expand All @@ -39,19 +27,11 @@
*/
@Internal
public final class DefaultInstanceMessageFactory {

// Controls how many times we should recursively inspect a nested message for building fully
// qualified java class name before aborting.
public static final int SAFE_RECURSE_LIMIT = 50;

private static final DefaultInstanceMessageFactory instance = new DefaultInstanceMessageFactory();

private final Map<String, LazyGeneratedMessageDefaultInstance> messageByDescriptorName =
new ConcurrentHashMap<>();
private static final DefaultInstanceMessageFactory INSTANCE = new DefaultInstanceMessageFactory();

/** Gets a single instance of this MessageFactory */
public static DefaultInstanceMessageFactory getInstance() {
return instance;
return INSTANCE;
}

/**
Expand All @@ -63,182 +43,29 @@ public static DefaultInstanceMessageFactory getInstance() {
* descriptor class isn't loaded in the binary.
*/
public Optional<Message> getPrototype(Descriptor descriptor) {
String descriptorName = descriptor.getFullName();
LazyGeneratedMessageDefaultInstance lazyDefaultInstance =
messageByDescriptorName.computeIfAbsent(
descriptorName,
(unused) ->
new LazyGeneratedMessageDefaultInstance(
getFullyQualifiedJavaClassName(descriptor)));

Message defaultInstance = lazyDefaultInstance.getDefaultInstance();
MessageLite defaultInstance =
DefaultInstanceMessageLiteFactory.getInstance()
.getPrototype(
descriptor.getFullName(),
ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(descriptor))
.orElse(null);
if (defaultInstance == null) {
return Optional.empty();
}
// Reference equality is intended. We want to make sure the descriptors are equal
// to guarantee types to be hermetic if linked types is disabled.
if (defaultInstance.getDescriptorForType() != descriptor) {
return Optional.empty();
}
return Optional.of(defaultInstance);
}

/**
* Retrieves the full Java class name from the given descriptor
*
* @return fully qualified class name.
* <p>Example 1: dev.cel.expr.Value
* <p>Example 2: com.google.rpc.context.AttributeContext$Resource (Nested classes)
* <p>Example 3: com.google.api.expr.cel.internal.testdata$SingleFileProto$SingleFile$Path
* (Nested class with java multiple files disabled)
*/
private String getFullyQualifiedJavaClassName(Descriptor descriptor) {
StringBuilder fullClassName = new StringBuilder();

fullClassName.append(getJavaPackageName(descriptor));

String javaOuterClass = getJavaOuterClassName(descriptor);
if (!Strings.isNullOrEmpty(javaOuterClass)) {
fullClassName.append(javaOuterClass).append("$");
}

// Recursively build the target class name in case if the message is nested.
ArrayDeque<String> classNames = new ArrayDeque<>();
Descriptor d = descriptor;

int recurseCount = 0;
while (d != null) {
classNames.push(d.getName());
d = d.getContainingType();
recurseCount++;
if (recurseCount >= SAFE_RECURSE_LIMIT) {
throw new IllegalStateException(
String.format(
"Recursion limit of %d hit while inspecting descriptor: %s",
SAFE_RECURSE_LIMIT, descriptor.getFullName()));
}
}

Joiner.on("$").appendTo(fullClassName, classNames);

return fullClassName.toString();
}

/**
* Gets the java package name from the descriptor. See
* https://developers.google.com/protocol-buffers/docs/reference/java-generated#package for rules
* on package name generation
*/
private String getJavaPackageName(Descriptor descriptor) {
FileOptions options = descriptor.getFile().getOptions();
StringBuilder javaPackageName = new StringBuilder();
if (options.hasJavaPackage()) {
javaPackageName.append(descriptor.getFile().getOptions().getJavaPackage()).append(".");
} else {
javaPackageName
// CEL-Internal-1
.append(descriptor.getFile().getPackage())
.append(".");
if (!(defaultInstance instanceof Message)) {
throw new IllegalArgumentException(
"Expected a full protobuf message, but got: " + defaultInstance.getClass());
}

// CEL-Internal-2
Message fullMessage = (Message) defaultInstance;

return javaPackageName.toString();
}

/**
* Gets a wrapping outer class name from the descriptor. The outer class name differs depending on
* the proto options set. See
* https://developers.google.com/protocol-buffers/docs/reference/java-generated#invocation
*/
private String getJavaOuterClassName(Descriptor descriptor) {
FileOptions options = descriptor.getFile().getOptions();

if (options.getJavaMultipleFiles()) {
// If java_multiple_files is enabled, protoc does not generate a wrapper outer class
return "";
}

if (options.hasJavaOuterClassname()) {
return options.getJavaOuterClassname();
} else {
// If an outer class name is not explicitly set, the name is converted into
// Pascal case based on the snake cased file name
// Ex: messages_proto.proto becomes MessagesProto
String protoFileNameWithoutExtension =
Files.getNameWithoutExtension(descriptor.getFile().getFullName());
String outerClassName =
CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, protoFileNameWithoutExtension);
if (hasConflictingClassName(descriptor.getFile(), outerClassName)) {
outerClassName += "OuterClass";
}
return outerClassName;
}
}

private boolean hasConflictingClassName(FileDescriptor file, String name) {
for (EnumDescriptor enumDesc : file.getEnumTypes()) {
if (name.equals(enumDesc.getName())) {
return true;
}
}
for (ServiceDescriptor serviceDesc : file.getServices()) {
if (name.equals(serviceDesc.getName())) {
return true;
}
}
for (Descriptor messageDesc : file.getMessageTypes()) {
if (name.equals(messageDesc.getName())) {
return true;
}
}
return false;
}

/** A placeholder to lazily load the generated messages' defaultInstances. */
private static final class LazyGeneratedMessageDefaultInstance {
private final String fullClassName;
private volatile Message defaultInstance = null;
private volatile boolean loaded = false;

public LazyGeneratedMessageDefaultInstance(String fullClassName) {
this.fullClassName = fullClassName;
}

public Message getDefaultInstance() {
if (!loaded) {
synchronized (this) {
if (!loaded) {
loadDefaultInstance();
loaded = true;
}
}
}
return defaultInstance;
}

private void loadDefaultInstance() {
try {
defaultInstance =
(Message) Class.forName(fullClassName).getMethod("getDefaultInstance").invoke(null);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new LinkageError(
String.format("getDefaultInstance for class: %s failed.", fullClassName), e);
} catch (NoSuchMethodException e) {
throw new LinkageError(
String.format("getDefaultInstance method does not exist in class: %s.", fullClassName),
e);
} catch (ClassNotFoundException e) {
// The class may not exist in some instances (Ex: evaluating a checked expression from a
// cached source).
}
// Reference equality is intended. We want to make sure the descriptors are equal
// to guarantee types to be hermetic if linked types is disabled.
if (fullMessage.getDescriptorForType() != descriptor) {
return Optional.empty();
}
}

/** Clears the descriptor map. This should not be used outside testing. */
@VisibleForTesting
void resetDescriptorMapForTesting() {
messageByDescriptorName.clear();
return Optional.of(fullMessage);
}

private DefaultInstanceMessageFactory() {}
Expand Down
Loading