diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 3eed49257..c1a096930 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -139,7 +139,7 @@ public Object adaptProtoToValue(MessageOrBuilder proto) { // If the proto is not a well-known type, then the input Message is what's expected as the // output return value. WellKnownProto wellKnownProto = - WellKnownProto.getByTypeName(typeName(proto.getDescriptorForType())); + WellKnownProto.getByTypeName(typeName(proto.getDescriptorForType())).orElse(null); if (wellKnownProto == null) { return proto; } @@ -280,7 +280,7 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { * considered, such as a packing an {@code google.protobuf.StringValue} into a {@code Any} value. */ public Message adaptValueToProto(Object value, String protoTypeName) { - WellKnownProto wellKnownProto = WellKnownProto.getByTypeName(protoTypeName); + WellKnownProto wellKnownProto = WellKnownProto.getByTypeName(protoTypeName).orElse(null); if (wellKnownProto == null) { if (value instanceof Message) { return (Message) value; @@ -326,8 +326,7 @@ private static boolean isWrapperType(FieldDescriptor fieldDescriptor) { return false; } String fieldTypeName = fieldDescriptor.getMessageType().getFullName(); - WellKnownProto wellKnownProto = WellKnownProto.getByTypeName(fieldTypeName); - return wellKnownProto != null && wellKnownProto.isWrapperType(); + return WellKnownProto.isWrapperType(fieldTypeName); } private static int intCheckedCast(long value) { diff --git a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java index 476891181..78041e3be 100644 --- a/common/src/main/java/dev/cel/common/internal/WellKnownProto.java +++ b/common/src/main/java/dev/cel/common/internal/WellKnownProto.java @@ -36,8 +36,8 @@ import com.google.protobuf.UInt64Value; import com.google.protobuf.Value; import dev.cel.common.annotations.Internal; +import java.util.Optional; import java.util.function.Function; -import org.jspecify.annotations.Nullable; /** * WellKnownProto types used throughout CEL. These types are specially handled to ensure that @@ -46,77 +46,79 @@ */ @Internal public enum WellKnownProto { - ANY_VALUE("google.protobuf.Any", Any.class.getName()), - DURATION("google.protobuf.Duration", Duration.class.getName()), - JSON_LIST_VALUE("google.protobuf.ListValue", ListValue.class.getName()), - JSON_STRUCT_VALUE("google.protobuf.Struct", Struct.class.getName()), - JSON_VALUE("google.protobuf.Value", Value.class.getName()), - TIMESTAMP("google.protobuf.Timestamp", Timestamp.class.getName()), + ANY_VALUE("google.protobuf.Any", Any.class), + DURATION("google.protobuf.Duration", Duration.class), + JSON_LIST_VALUE("google.protobuf.ListValue", ListValue.class), + JSON_STRUCT_VALUE("google.protobuf.Struct", Struct.class), + JSON_VALUE("google.protobuf.Value", Value.class), + TIMESTAMP("google.protobuf.Timestamp", Timestamp.class), // Wrapper types - FLOAT_VALUE("google.protobuf.FloatValue", FloatValue.class.getName(), /* isWrapperType= */ true), - INT32_VALUE("google.protobuf.Int32Value", Int32Value.class.getName(), /* isWrapperType= */ true), - INT64_VALUE("google.protobuf.Int64Value", Int64Value.class.getName(), /* isWrapperType= */ true), - STRING_VALUE( - "google.protobuf.StringValue", StringValue.class.getName(), /* isWrapperType= */ true), - BOOL_VALUE("google.protobuf.BoolValue", BoolValue.class.getName(), /* isWrapperType= */ true), - BYTES_VALUE("google.protobuf.BytesValue", BytesValue.class.getName(), /* isWrapperType= */ true), - DOUBLE_VALUE( - "google.protobuf.DoubleValue", DoubleValue.class.getName(), /* isWrapperType= */ true), - UINT32_VALUE( - "google.protobuf.UInt32Value", UInt32Value.class.getName(), /* isWrapperType= */ true), - UINT64_VALUE( - "google.protobuf.UInt64Value", UInt64Value.class.getName(), /* isWrapperType= */ true), + FLOAT_VALUE("google.protobuf.FloatValue", FloatValue.class, /* isWrapperType= */ true), + INT32_VALUE("google.protobuf.Int32Value", Int32Value.class, /* isWrapperType= */ true), + INT64_VALUE("google.protobuf.Int64Value", Int64Value.class, /* isWrapperType= */ true), + STRING_VALUE("google.protobuf.StringValue", StringValue.class, /* isWrapperType= */ true), + BOOL_VALUE("google.protobuf.BoolValue", BoolValue.class, /* isWrapperType= */ true), + BYTES_VALUE("google.protobuf.BytesValue", BytesValue.class, /* isWrapperType= */ true), + DOUBLE_VALUE("google.protobuf.DoubleValue", DoubleValue.class, /* isWrapperType= */ true), + UINT32_VALUE("google.protobuf.UInt32Value", UInt32Value.class, /* isWrapperType= */ true), + UINT64_VALUE("google.protobuf.UInt64Value", UInt64Value.class, /* isWrapperType= */ true), // These aren't explicitly called out as wrapper types in the spec, but behave like one, because // they are still converted into an equivalent primitive type. - EMPTY("google.protobuf.Empty", Empty.class.getName(), /* isWrapperType= */ true), - FIELD_MASK("google.protobuf.FieldMask", FieldMask.class.getName(), /* isWrapperType= */ true), + EMPTY("google.protobuf.Empty", Empty.class, /* isWrapperType= */ true), + FIELD_MASK("google.protobuf.FieldMask", FieldMask.class, /* isWrapperType= */ true), ; - private static final ImmutableMap WELL_KNOWN_PROTO_MAP; + private static final ImmutableMap TYPE_NAME_TO_WELL_KNOWN_PROTO_MAP = + stream(WellKnownProto.values()) + .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); - static { - WELL_KNOWN_PROTO_MAP = - stream(WellKnownProto.values()) - .collect(toImmutableMap(WellKnownProto::typeName, Function.identity())); - } + private static final ImmutableMap, WellKnownProto> + CLASS_TO_NAME_TO_WELL_KNOWN_PROTO_MAP = + stream(WellKnownProto.values()) + .collect(toImmutableMap(WellKnownProto::messageClass, Function.identity())); - private final String wellKnownProtoFullName; - private final String javaClassName; + private final String wellKnownProtoTypeName; + private final Class clazz; private final boolean isWrapperType; + /** Gets the fully qualified prototype name (ex: google.protobuf.FloatValue) */ public String typeName() { - return wellKnownProtoFullName; + return wellKnownProtoTypeName; } - public String javaClassName() { - return this.javaClassName; + /** Gets the underlying java class for this WellKnownProto. */ + public Class messageClass() { + return clazz; } - public static @Nullable WellKnownProto getByTypeName(String typeName) { - return WELL_KNOWN_PROTO_MAP.get(typeName); + public static Optional getByTypeName(String typeName) { + return Optional.ofNullable(TYPE_NAME_TO_WELL_KNOWN_PROTO_MAP.get(typeName)); } - public static boolean isWrapperType(String typeName) { - WellKnownProto wellKnownProto = getByTypeName(typeName); - if (wellKnownProto == null) { - return false; - } + public static Optional getByClass(Class clazz) { + return Optional.ofNullable(CLASS_TO_NAME_TO_WELL_KNOWN_PROTO_MAP.get(clazz)); + } - return wellKnownProto.isWrapperType(); + /** + * Returns true if the provided {@code typeName} is a well known type, and it's a wrapper. False + * otherwise. + */ + public static boolean isWrapperType(String typeName) { + return getByTypeName(typeName).map(WellKnownProto::isWrapperType).orElse(false); } public boolean isWrapperType() { return isWrapperType; } - WellKnownProto(String wellKnownProtoFullName, String javaClassName) { - this(wellKnownProtoFullName, javaClassName, /* isWrapperType= */ false); + WellKnownProto(String wellKnownProtoTypeName, Class clazz) { + this(wellKnownProtoTypeName, clazz, /* isWrapperType= */ false); } - WellKnownProto(String wellKnownProtoFullName, String javaClassName, boolean isWrapperType) { - this.wellKnownProtoFullName = wellKnownProtoFullName; - this.javaClassName = javaClassName; + WellKnownProto(String wellKnownProtoFullName, Class clazz, boolean isWrapperType) { + this.wellKnownProtoTypeName = wellKnownProtoFullName; + this.clazz = clazz; this.isWrapperType = isWrapperType; } } diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index 16d1a8956..16f5cc215 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -107,7 +107,7 @@ public CelValue fromProtoMessageToCelValue(MessageOrBuilder message) { } WellKnownProto wellKnownProto = - WellKnownProto.getByTypeName(message.getDescriptorForType().getFullName()); + WellKnownProto.getByTypeName(message.getDescriptorForType().getFullName()).orElse(null); if (wellKnownProto == null) { return ProtoMessageValue.create((Message) message, celDescriptorPool, this); } diff --git a/common/src/test/java/dev/cel/common/internal/WellKnownProtoTest.java b/common/src/test/java/dev/cel/common/internal/WellKnownProtoTest.java index 7fc6eabfd..bb75a341a 100644 --- a/common/src/test/java/dev/cel/common/internal/WellKnownProtoTest.java +++ b/common/src/test/java/dev/cel/common/internal/WellKnownProtoTest.java @@ -16,8 +16,10 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.protobuf.FloatValue; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -53,7 +55,12 @@ public void isWrapperType_withTypeName_false(String typeName) { } @Test - public void getJavaClassName() { - assertThat(WellKnownProto.ANY_VALUE.javaClassName()).isEqualTo("com.google.protobuf.Any"); + public void getByClass_success() { + assertThat(WellKnownProto.getByClass(FloatValue.class)).hasValue(WellKnownProto.FLOAT_VALUE); + } + + @Test + public void getByClass_unknownClass_returnsEmpty() { + assertThat(WellKnownProto.getByClass(List.class)).isEmpty(); } }