From 3d1f608a05814d8f4491bae13b284fa1eb5622e3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 30 Jan 2026 14:47:16 -0800 Subject: [PATCH] Add json_name protobuf field option support PiperOrigin-RevId: 863401628 --- .../src/test/java/dev/cel/bundle/BUILD.bazel | 2 +- .../test/java/dev/cel/bundle/CelImplTest.java | 85 +++++++++- .../src/main/java/dev/cel/checker/BUILD.bazel | 1 + .../dev/cel/checker/CelCheckerLegacyImpl.java | 9 +- .../java/dev/cel/checker/ExprChecker.java | 59 ++++++- .../main/java/dev/cel/common/CelOptions.java | 13 ++ .../cel/common/types/ProtoMessageType.java | 26 ++- .../types/ProtoMessageTypeProvider.java | 151 +++++++++++++++--- .../java/dev/cel/common/types/BUILD.bazel | 1 + .../types/ProtoMessageTypeProviderTest.java | 21 +++ .../common/types/ProtoMessageTypeTest.java | 3 +- .../runtime/DescriptorMessageProvider.java | 9 ++ .../src/test/java/dev/cel/runtime/BUILD.bazel | 17 +- .../cel/runtime/CelLiteInterpreterTest.java | 27 ++-- .../cel/runtime/CelValueInterpreterTest.java | 37 ----- .../src/main/java/dev/cel/testing/BUILD.bazel | 1 + .../dev/cel/testing/BaseInterpreterTest.java | 52 ++++-- .../dev/cel/testing/CelBaselineTestCase.java | 11 +- .../test/resources/protos/single_file.proto | 1 + 19 files changed, 402 insertions(+), 124 deletions(-) delete mode 100644 runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java diff --git a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel index 26cbe392d..cd33dd67d 100644 --- a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel @@ -34,7 +34,6 @@ java_library( "//common:proto_ast", "//common:source_location", "//common/ast", - "//common/internal:proto_time_utils", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", "//common/testing", "//common/types", @@ -55,6 +54,7 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_binding", "//runtime:unknown_attributes", + "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr:syntax_java_proto", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 13c3392d3..0708b37c3 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -71,6 +71,7 @@ import dev.cel.common.CelIssue; import dev.cel.common.CelOptions; import dev.cel.common.CelProtoAbstractSyntaxTree; +import dev.cel.common.CelSource.Extension; import dev.cel.common.CelSourceLocation; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; @@ -112,6 +113,7 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.UnknownContext; +import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.time.Instant; import java.util.ArrayList; @@ -743,7 +745,7 @@ public void program_withThrowingFunction() throws Exception { CelFunctionBinding.from( "throws", ImmutableList.of(), - (args) -> { + (unused) -> { throw new CelEvaluationException("this method always throws"); })) .setResultType(SimpleType.BOOL) @@ -771,7 +773,7 @@ public void program_withThrowingFunctionShortcircuited() throws Exception { CelFunctionBinding.from( "throws", ImmutableList.of(), - (args) -> { + (unused) -> { throw CelEvaluationExceptionBuilder.newBuilder("this method always throws") .setCause(new RuntimeException("reason")) .build(); @@ -1143,7 +1145,7 @@ public void program_customVarResolver() throws Exception { program.eval( (name) -> name.equals("variable") ? Optional.of("hello") : Optional.empty())) .isEqualTo(true); - assertThat(program.eval((name) -> Optional.of(""))).isEqualTo(false); + assertThat(program.eval((unused) -> Optional.of(""))).isEqualTo(false); } @Test @@ -2193,6 +2195,83 @@ public void toBuilder_isImmutable() { assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder()); } + @Test + public void eval_withJsonFieldName() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("file.camelCased").getAst(); + + Object result = + cel.createProgram(ast) + .eval(ImmutableMap.of("file", SingleFile.newBuilder().setSnakeCased("foo").build())); + + assertThat(result).isEqualTo("foo"); + } + + @Test + public void eval_withJsonFieldName_runtimeOptionDisabled_throws() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(false).build()) + .build(); + CelAbstractSyntaxTree ast = celCompiler.compile("file.camelCased").getAst(); + + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, + () -> + celRuntime + .createProgram(ast) + .eval(ImmutableMap.of("file", SingleFile.getDefaultInstance()))); + assertThat(e) + .hasMessageThat() + .contains( + "field 'camelCased' is not declared in message 'dev.cel.testing.testdata.SingleFile"); + } + + @Test + public void compile_withJsonFieldName_astTagged() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("file.camelCased").getAst(); + + assertThat(ast.getSource().getExtensions()) + .contains( + Extension.create( + "json_name", Extension.Version.of(1L, 1L), Extension.Component.COMPONENT_RUNTIME)); + } + + @Test + public void compile_withJsonFieldName_protoFieldNameComparison_throws() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> cel.compile("file.camelCased == file.snake_cased").getAst()); + assertThat(e).hasMessageThat().contains("undefined field 'snake_cased'"); + } + private static TypeProvider aliasingProvider(ImmutableMap typeAliases) { return new TypeProvider() { @Override diff --git a/checker/src/main/java/dev/cel/checker/BUILD.bazel b/checker/src/main/java/dev/cel/checker/BUILD.bazel index 6c486bd92..91a409a00 100644 --- a/checker/src/main/java/dev/cel/checker/BUILD.bazel +++ b/checker/src/main/java/dev/cel/checker/BUILD.bazel @@ -177,6 +177,7 @@ java_library( ":standard_decl", "//:auto_value", "//common:cel_ast", + "//common:cel_source", "//common:compiler_common", "//common:container", "//common:operator", diff --git a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java index 41d1ca073..df8a82f43 100644 --- a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java +++ b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java @@ -456,9 +456,12 @@ public CelCheckerLegacyImpl build() { } CelTypeProvider messageTypeProvider = - new ProtoMessageTypeProvider( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - fileTypeSet, celOptions.resolveTypeDependencies())); + ProtoMessageTypeProvider.newBuilder() + .setAllowJsonFieldNames(celOptions.enableJsonFieldNames()) + .setResolveTypeDependencies(celOptions.resolveTypeDependencies()) + .addFileDescriptors(fileTypeSet) + .build(); + if (celTypeProvider != null && fileTypeSet.isEmpty()) { messageTypeProvider = celTypeProvider; } else if (celTypeProvider != null) { diff --git a/checker/src/main/java/dev/cel/checker/ExprChecker.java b/checker/src/main/java/dev/cel/checker/ExprChecker.java index 37b692ecf..4cde31922 100644 --- a/checker/src/main/java/dev/cel/checker/ExprChecker.java +++ b/checker/src/main/java/dev/cel/checker/ExprChecker.java @@ -31,6 +31,7 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelProtoAbstractSyntaxTree; +import dev.cel.common.CelSource; import dev.cel.common.Operator; import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelConstant; @@ -43,12 +44,14 @@ import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.OptionalType; +import dev.cel.common.types.ProtoMessageType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeType; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.jspecify.annotations.Nullable; /** @@ -61,6 +64,11 @@ @Internal @Deprecated public final class ExprChecker { + private static final CelSource.Extension JSON_NAME_EXTENSION = + CelSource.Extension.create( + "json_name", + CelSource.Extension.Version.of(1, 1), + CelSource.Extension.Component.COMPONENT_RUNTIME); /** * Deprecated type-check API. @@ -139,7 +147,11 @@ public static CelAbstractSyntaxTree typecheck( Map typeMap = Maps.transformValues(env.getTypeMap(), checker.inferenceContext::finalize); - return CelAbstractSyntaxTree.newCheckedAst(expr, ast.getSource(), env.getRefMap(), typeMap); + return CelAbstractSyntaxTree.newCheckedAst( + expr, + ast.getSource().toBuilder().addAllExtensions(checker.extensions).build(), + env.getRefMap(), + typeMap); } private final Env env; @@ -150,6 +162,7 @@ public static CelAbstractSyntaxTree typecheck( private final boolean compileTimeOverloadResolution; private final boolean homogeneousLiterals; private final boolean namespacedDeclarations; + private final Set extensions; private ExprChecker( Env env, @@ -167,6 +180,7 @@ private ExprChecker( this.compileTimeOverloadResolution = compileTimeOverloadResolution; this.homogeneousLiterals = homogeneousLiterals; this.namespacedDeclarations = namespacedDeclarations; + this.extensions = new HashSet<>(); } /** Visit the {@code expr} value, routing to overloads based on the kind of expression. */ @@ -376,13 +390,13 @@ private CelExpr visit(CelExpr expr, CelExpr.CelStruct struct) { env.setRef(expr, CelReference.newBuilder().setName(decl.name()).build()); CelType type = decl.type(); - if (type.kind() != CelKind.ERROR) { - if (type.kind() != CelKind.TYPE) { + if (!type.kind().equals(CelKind.ERROR)) { + if (!type.kind().equals(CelKind.TYPE)) { // expected type of types env.reportError(expr.id(), getPosition(expr), "'%s' is not a type", CelTypes.format(type)); } else { messageType = ((TypeType) type).type(); - if (messageType.kind() != CelKind.STRUCT) { + if (!messageType.kind().equals(CelKind.STRUCT)) { env.reportError( expr.id(), getPosition(expr), @@ -726,14 +740,18 @@ private CelType visitSelectField( } if (!Types.isDynOrError(operandType)) { - if (operandType.kind() == CelKind.STRUCT) { + if (operandType.kind().equals(CelKind.STRUCT)) { TypeProvider.FieldType fieldType = getFieldType(expr.id(), getPosition(expr), operandType, field); + ProtoMessageType protoMessageType = resolveProtoMessageType(operandType); + if (protoMessageType != null && protoMessageType.isJsonName(field)) { + extensions.add(JSON_NAME_EXTENSION); + } // Type of the field resultType = fieldType.celType(); - } else if (operandType.kind() == CelKind.MAP) { + } else if (operandType.kind().equals(CelKind.MAP)) { resultType = ((MapType) operandType).valueType(); - } else if (operandType.kind() == CelKind.TYPE_PARAM) { + } else if (operandType.kind().equals(CelKind.TYPE_PARAM)) { // Mark the operand as type DYN to avoid cases where the free type variable might take on // an incorrect type if used in multiple locations. // @@ -763,6 +781,33 @@ private CelType visitSelectField( return resultType; } + private @Nullable ProtoMessageType resolveProtoMessageType(CelType operandType) { + if (operandType instanceof ProtoMessageType) { + return (ProtoMessageType) operandType; + } + + if (operandType.kind().equals(CelKind.STRUCT)) { + // This is either a StructTypeReference or just a Struct. Attempt to search for + // ProtoMessageType that may exist in in the type provider. + TypeType typeDef = + typeProvider + .lookupCelType(operandType.name()) + .filter(t -> t instanceof TypeType) + .map(TypeType.class::cast) + .orElse(null); + if (typeDef == null || typeDef.parameters().size() != 1) { + return null; + } + + CelType maybeProtoMessageType = typeDef.parameters().get(0); + if (maybeProtoMessageType instanceof ProtoMessageType) { + return (ProtoMessageType) maybeProtoMessageType; + } + } + + return null; + } + private CelExpr visitOptionalCall(CelExpr expr, CelExpr.CelCall call) { CelExpr operand = call.args().get(0); CelExpr field = call.args().get(1); diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index d39d53803..d0b020697 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -83,6 +83,8 @@ public enum ProtoUnsetFieldOptions { public abstract boolean enableNamespacedDeclarations(); + public abstract boolean enableJsonFieldNames(); + // Evaluation related options public abstract boolean disableCelStandardEquality(); @@ -150,6 +152,7 @@ public static Builder newBuilder() { .enableTimestampEpoch(false) .enableHeterogeneousNumericComparisons(false) .enableNamespacedDeclarations(true) + .enableJsonFieldNames(false) // Evaluation options .disableCelStandardEquality(true) .evaluateCanonicalTypesToNativeValues(false) @@ -529,6 +532,16 @@ public abstract static class Builder { */ public abstract Builder maxRegexProgramSize(int value); + /** + * Use the `json_name` field option on a protobuf message as the name of the field. + * + *

If enabled, the compiler will only accept the `json_name` and no longer recognize the + * original protobuf field name. Use with caution as this may break existing expressions during + * compilation. The runtime continues to support both names for maintaining backwards + * compatibility. + */ + public abstract Builder enableJsonFieldNames(boolean value); + public abstract CelOptions build(); } } diff --git a/common/src/main/java/dev/cel/common/types/ProtoMessageType.java b/common/src/main/java/dev/cel/common/types/ProtoMessageType.java index 7ac3f4fd3..11e48bbe6 100644 --- a/common/src/main/java/dev/cel/common/types/ProtoMessageType.java +++ b/common/src/main/java/dev/cel/common/types/ProtoMessageType.java @@ -29,14 +29,17 @@ public final class ProtoMessageType extends StructType { private final StructType.FieldResolver extensionResolver; + private final JsonNameResolver jsonNameResolver; ProtoMessageType( String name, ImmutableSet fieldNames, StructType.FieldResolver fieldResolver, - StructType.FieldResolver extensionResolver) { + StructType.FieldResolver extensionResolver, + JsonNameResolver jsonNameResolver) { super(name, fieldNames, fieldResolver); this.extensionResolver = extensionResolver; + this.jsonNameResolver = jsonNameResolver; } /** Find an {@code Extension} by its fully-qualified {@code extensionName}. */ @@ -46,20 +49,35 @@ public Optional findExtension(String extensionName) { .map(type -> Extension.of(extensionName, type, this)); } + /** Returns true if the field name is a json name. */ + public boolean isJsonName(String fieldName) { + return jsonNameResolver.isJsonName(fieldName); + } + /** * Create a new instance of the {@code ProtoMessageType} using the {@code visibleFields} set as a * mask of the fields from the backing proto. */ public ProtoMessageType withVisibleFields(ImmutableSet visibleFields) { - return new ProtoMessageType(name, visibleFields, fieldResolver, extensionResolver); + return new ProtoMessageType( + name, visibleFields, fieldResolver, extensionResolver, jsonNameResolver); } public static ProtoMessageType create( String name, ImmutableSet fieldNames, FieldResolver fieldResolver, - FieldResolver extensionResolver) { - return new ProtoMessageType(name, fieldNames, fieldResolver, extensionResolver); + FieldResolver extensionResolver, + JsonNameResolver jsonNameResolver) { + return new ProtoMessageType( + name, fieldNames, fieldResolver, extensionResolver, jsonNameResolver); + } + + /** Functional interface for resolving whether a field name is a json name. */ + @FunctionalInterface + @Immutable + public interface JsonNameResolver { + boolean isJsonName(String fieldName); } /** {@code Extension} contains the name, type, and target message type of the extension. */ diff --git a/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java b/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java index 4b97178d0..0c49f1fd8 100644 --- a/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java +++ b/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java @@ -14,15 +14,14 @@ package dev.cel.common.types; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import com.google.common.collect.ImmutableCollection; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.DescriptorProtos.FileDescriptorSet; @@ -34,11 +33,11 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.internal.FileDescriptorSetConverter; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.function.Function; /** * The {@code ProtoMessageTypeProvider} implements the {@link CelTypeProvider} interface to provide @@ -68,28 +67,53 @@ public final class ProtoMessageTypeProvider implements CelTypeProvider { .buildOrThrow(); private final ImmutableMap allTypes; + private final boolean allowJsonFieldNames; + /** Returns a new builder for {@link ProtoMessageTypeProvider}. */ + public static Builder newBuilder() { + return new Builder(); + } + + /** + * @deprecated Use {@link #newBuilder()} instead. + */ + @Deprecated public ProtoMessageTypeProvider() { - this(CelDescriptors.builder().build()); + this(CelDescriptors.builder().build(), false); } + /** + * @deprecated Use {@link #newBuilder()} instead. + */ + @Deprecated public ProtoMessageTypeProvider(FileDescriptorSet descriptorSet) { this( CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - FileDescriptorSetConverter.convert(descriptorSet))); + FileDescriptorSetConverter.convert(descriptorSet)), + false); } + /** + * @deprecated Use {@link #newBuilder()} instead. + */ + @Deprecated public ProtoMessageTypeProvider(Iterable descriptors) { this( CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile)))); + ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile))), + false); } + /** + * @deprecated Use {@link #newBuilder()} instead. + */ + @Deprecated public ProtoMessageTypeProvider(ImmutableSet fileDescriptors) { - this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors)); + this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors), false); } - public ProtoMessageTypeProvider(CelDescriptors celDescriptors) { + private ProtoMessageTypeProvider(CelDescriptors celDescriptors, boolean allowJsonFieldNames) { + this.allowJsonFieldNames = allowJsonFieldNames; this.allTypes = ImmutableMap.builder() .putAll(createEnumTypes(celDescriptors.enumDescriptors())) @@ -120,8 +144,17 @@ private ImmutableMap createProtoMessageTypes( if (protoMessageTypes.containsKey(descriptor.getFullName())) { continue; } - ImmutableList fieldNames = - descriptor.getFields().stream().map(FieldDescriptor::getName).collect(toImmutableList()); + + ImmutableSet.Builder fieldNamesBuilder = ImmutableSet.builder(); + ImmutableSet.Builder jsonNamesBuilder = ImmutableSet.builder(); + for (FieldDescriptor fd : descriptor.getFields()) { + fieldNamesBuilder.add(fd.getName()); + if (allowJsonFieldNames) { + fieldNamesBuilder.add(fd.getJsonName()); + jsonNamesBuilder.add(fd.getJsonName()); + } + } + ImmutableSet jsonNames = jsonNamesBuilder.build(); Map extensionFields = new HashMap<>(); for (FieldDescriptor extension : extensionMap.get(descriptor.getFullName())) { @@ -133,9 +166,10 @@ private ImmutableMap createProtoMessageTypes( descriptor.getFullName(), ProtoMessageType.create( descriptor.getFullName(), - ImmutableSet.copyOf(fieldNames), + fieldNamesBuilder.build(), new FieldResolver(this, descriptor)::findField, - new FieldResolver(this, extensions)::findField)); + new FieldResolver(this, extensions)::findField, + jsonNames::contains)); } return ImmutableMap.copyOf(protoMessageTypes); } @@ -158,19 +192,34 @@ private ImmutableMap createEnumTypes( } private static class FieldResolver { - private final CelTypeProvider celTypeProvider; + private final ProtoMessageTypeProvider protoMessageTypeProvider; private final ImmutableMap fields; - private FieldResolver(CelTypeProvider celTypeProvider, Descriptor descriptor) { + private static ImmutableMap collectFieldDescriptorMap( + ProtoMessageTypeProvider protoMessageTypeProvider, Descriptor descriptor) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (FieldDescriptor fd : descriptor.getFields()) { + if (protoMessageTypeProvider.allowJsonFieldNames && !fd.getJsonName().isEmpty()) { + builder.put(fd.getJsonName(), fd); + } else { + builder.put(fd.getName(), fd); + } + } + + return builder.buildOrThrow(); + } + + private FieldResolver( + ProtoMessageTypeProvider protoMessageTypeProvider, Descriptor descriptor) { this( - celTypeProvider, - descriptor.getFields().stream() - .collect(toImmutableMap(FieldDescriptor::getName, Function.identity()))); + protoMessageTypeProvider, + collectFieldDescriptorMap(protoMessageTypeProvider, descriptor)); } private FieldResolver( - CelTypeProvider celTypeProvider, ImmutableMap fields) { - this.celTypeProvider = celTypeProvider; + ProtoMessageTypeProvider protoMessageTypeProvider, + ImmutableMap fields) { + this.protoMessageTypeProvider = protoMessageTypeProvider; this.fields = fields; } @@ -203,11 +252,11 @@ private Optional findFieldInternal(FieldDescriptor fieldDescriptor) { String messageName = descriptor.getFullName(); fieldType = CelTypes.getWellKnownCelType(messageName) - .orElse(celTypeProvider.findType(descriptor.getFullName()).orElse(null)); + .orElse(protoMessageTypeProvider.findType(descriptor.getFullName()).orElse(null)); break; case ENUM: EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType(); - fieldType = celTypeProvider.findType(enumDescriptor.getFullName()).orElse(null); + fieldType = protoMessageTypeProvider.findType(enumDescriptor.getFullName()).orElse(null); break; default: fieldType = PROTO_TYPE_TO_CEL_TYPE.get(fieldDescriptor.getType()); @@ -222,4 +271,64 @@ private Optional findFieldInternal(FieldDescriptor fieldDescriptor) { return Optional.of(fieldType); } } + + /** Builder for {@link ProtoMessageTypeProvider}. */ + public static final class Builder { + private final ImmutableSet.Builder fileDescriptors = ImmutableSet.builder(); + private boolean allowJsonFieldNames; + private boolean resolveTypeDependencies; + + /** Adds a {@link FileDescriptor} to the provider. */ + @CanIgnoreReturnValue + public Builder addFileDescriptors(FileDescriptor... fileDescriptors) { + return addFileDescriptors(Arrays.asList(fileDescriptors)); + } + + /** Adds a collection of {@link FileDescriptor}s to the provider. */ + @CanIgnoreReturnValue + public Builder addFileDescriptors(Iterable fileDescriptors) { + this.fileDescriptors.addAll(fileDescriptors); + return this; + } + + /** Adds a collection of {@link Descriptor}s. The parent file of each descriptor is added. */ + @CanIgnoreReturnValue + public Builder addDescriptors(Iterable descriptors) { + this.fileDescriptors.addAll(Iterables.transform(descriptors, Descriptor::getFile)); + return this; + } + + /** + * Use the `json_name` field option on a protobuf message as the name of the field. + * + *

If enabled, the type checker will only accept the `json_name` and will no longer recognize + * the original protobuf field name. This is to avoid ambiguity between the two names. + */ + @CanIgnoreReturnValue + public Builder setAllowJsonFieldNames(boolean allowJsonFieldNames) { + this.allowJsonFieldNames = allowJsonFieldNames; + return this; + } + + /** + * If true, all transitive dependencies of the added {@link FileDescriptor}s will be resolved + * and their types will be made available to the type provider. By default, this is disabled. + */ + @CanIgnoreReturnValue + public Builder setResolveTypeDependencies(boolean resolveTypeDependencies) { + this.resolveTypeDependencies = resolveTypeDependencies; + return this; + } + + /** Builds the {@link ProtoMessageTypeProvider}. */ + public ProtoMessageTypeProvider build() { + CelDescriptors celDescriptors = + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + fileDescriptors.build(), resolveTypeDependencies); + + return new ProtoMessageTypeProvider(celDescriptors, allowJsonFieldNames); + } + + private Builder() {} + } } diff --git a/common/src/test/java/dev/cel/common/types/BUILD.bazel b/common/src/test/java/dev/cel/common/types/BUILD.bazel index 600a63940..0c8121bbd 100644 --- a/common/src/test/java/dev/cel/common/types/BUILD.bazel +++ b/common/src/test/java/dev/cel/common/types/BUILD.bazel @@ -16,6 +16,7 @@ java_library( "//common/types:cel_types", "//common/types:message_type_provider", "//common/types:type_providers", + "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java index d774903e1..16797b714 100644 --- a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java +++ b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java @@ -20,8 +20,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import dev.cel.common.types.CelTypeProvider.CombinedCelTypeProvider; +import dev.cel.common.types.StructType.Field; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypesExtensions; +import dev.cel.testing.testdata.SingleFileProto.SingleFile; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -254,4 +256,23 @@ public void types_combinedDuplicateProviderIsSameAsFirst() { CombinedCelTypeProvider combined = new CombinedCelTypeProvider(proto3Provider, proto3Provider); assertThat(combined.types()).hasSize(proto3Provider.types().size()); } + + @Test + public void findField_withJsonNameOption() { + ProtoMessageTypeProvider typeProvider = + ProtoMessageTypeProvider.newBuilder() + .addFileDescriptors(SingleFile.getDescriptor().getFile()) + .setAllowJsonFieldNames(true) + .build(); + + ProtoMessageType msgType = + (ProtoMessageType) typeProvider.findType(SingleFile.getDescriptor().getFullName()).get(); + + // Note that these are the same fields, with json_name option set + Optional snakeCasedField = msgType.findField("snake_cased"); + Optional jsonNameField = msgType.findField("camelCased"); + + assertThat(snakeCasedField).isEmpty(); + assertThat(jsonNameField).isPresent(); + } } diff --git a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java index 3feab06fd..d6e90b1b4 100644 --- a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java +++ b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java @@ -46,7 +46,8 @@ public void setUp() { "my.package.TestMessage", FIELD_MAP.keySet(), (field) -> Optional.ofNullable(FIELD_MAP.get(field)), - (extension) -> Optional.ofNullable(EXTENSION_MAP.get(extension))); + (extension) -> Optional.ofNullable(EXTENSION_MAP.get(extension)), + (unused) -> false); } @Test diff --git a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java index 2980aea96..ba0e442ec 100644 --- a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java +++ b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java @@ -182,6 +182,15 @@ private FieldDescriptor findField(Descriptor descriptor, String fieldName) { } } + if (fieldDescriptor == null && celOptions.enableJsonFieldNames()) { + for (FieldDescriptor fd : descriptor.getFields()) { + if (fd.getJsonName().equals(fieldName)) { + fieldDescriptor = fd; + break; + } + } + } + if (fieldDescriptor == null) { throw new IllegalArgumentException( String.format( diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index ed76fdbe7..b05f0b66a 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -19,7 +19,6 @@ java_library( # keep sorted exclude = [ "CelLiteInterpreterTest.java", - "CelValueInterpreterTest.java", "InterpreterTest.java", ] + ANDROID_TESTS, ), @@ -124,20 +123,6 @@ java_library( ], ) -java_library( - name = "cel_value_interpreter_test", - testonly = 1, - srcs = [ - "CelValueInterpreterTest.java", - ], - deps = [ - # "//java/com/google/testing/testsize:annotations", - "//testing:base_interpreter_test", - "@maven//:junit_junit", - "@maven//:com_google_testparameterinjector_test_parameter_injector", - ], -) - cel_android_local_test( name = "android_tests", srcs = ANDROID_TESTS, @@ -182,6 +167,7 @@ java_library( "CelLiteInterpreterTest.java", ], deps = [ + "//common:options", "//common/values:proto_message_lite_value_provider", "//extensions:optional_library", "//runtime", @@ -203,7 +189,6 @@ junit4_test_suites( src_dir = "src/test/java", deps = [ ":cel_lite_interpreter_test", - ":cel_value_interpreter_test", ":interpreter_test", ":tests", ], diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java index 088a2d7b0..b3a1f2efa 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java @@ -15,6 +15,7 @@ package dev.cel.runtime; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.common.CelOptions; import dev.cel.common.values.ProtoMessageLiteValueProvider; import dev.cel.expr.conformance.proto3.TestAllTypesCelDescriptor; import dev.cel.extensions.CelOptionalLibrary; @@ -27,16 +28,16 @@ */ @RunWith(TestParameterInjector.class) public class CelLiteInterpreterTest extends BaseInterpreterTest { - public CelLiteInterpreterTest() { - super( - CelRuntimeFactory.standardCelRuntimeBuilder() - .setValueProvider( - ProtoMessageLiteValueProvider.newInstance( - dev.cel.expr.conformance.proto2.TestAllTypesCelDescriptor.getDescriptor(), - TestAllTypesCelDescriptor.getDescriptor())) - .addLibraries(CelOptionalLibrary.INSTANCE) - .setOptions(newBaseCelOptions().toBuilder().enableCelValue(true).build()) - .build()); + + @Override + protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { + return CelRuntimeFactory.standardCelRuntimeBuilder() + .setValueProvider( + ProtoMessageLiteValueProvider.newInstance( + dev.cel.expr.conformance.proto2.TestAllTypesCelDescriptor.getDescriptor(), + TestAllTypesCelDescriptor.getDescriptor())) + .addLibraries(CelOptionalLibrary.INSTANCE) + .setOptions(celOptions.toBuilder().enableCelValue(true).build()); } @Override @@ -97,4 +98,10 @@ public void jsonValueTypes() { public void messages_error() { skipBaselineVerification(); } + + @Override + public void jsonFieldNames() { + // json_name field option is not yet supported in lite runtime + skipBaselineVerification(); + } } diff --git a/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java deleted file mode 100644 index ad3fae082..000000000 --- a/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime; - -import com.google.testing.junit.testparameterinjector.TestParameterInjector; -// import com.google.testing.testsize.MediumTest; -import dev.cel.testing.BaseInterpreterTest; -import org.junit.runner.RunWith; - -/** Tests for {@link Interpreter} and related functionality using {@code CelValue}. */ -// @MediumTest -@RunWith(TestParameterInjector.class) -public class CelValueInterpreterTest extends BaseInterpreterTest { - - public CelValueInterpreterTest() { - super(newBaseCelOptions().toBuilder().enableCelValue(true).build()); - } - - @Override - public void wrappers() throws Exception { - // Field selection on repeated wrappers broken. - // This test along with CelValue adapter will be removed in a separate CL - skipBaselineVerification(); - } -} diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index d9d6a6564..5ecfd37f2 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -84,6 +84,7 @@ java_library( "//common/internal:proto_time_utils", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", "//common/types", + "//common/types:message_type_provider", "//common/types:type_providers", "//common/values:cel_byte_string", "//extensions:optional_library", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 425271e1b..74bcca00e 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -65,6 +65,7 @@ import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.OpaqueType; +import dev.cel.common.types.ProtoMessageTypeProvider; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; import dev.cel.common.types.TypeParamType; @@ -77,6 +78,7 @@ import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeBuilder; import dev.cel.runtime.CelRuntimeFactory; import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; @@ -114,36 +116,43 @@ public abstract class BaseInterpreterTest extends CelBaselineTestCase { .comprehensionMaxIterations(1_000) .build(); private CelRuntime celRuntime; + protected CelOptions celOptions; protected BaseInterpreterTest() { - this(newRuntime(BASE_CEL_OPTIONS)); - } - - protected BaseInterpreterTest(CelOptions celOptions) { - this(newRuntime(celOptions)); + this.celOptions = BASE_CEL_OPTIONS; + this.celRuntime = newBaseRuntimeBuilder(celOptions).build(); } protected BaseInterpreterTest(CelRuntime celRuntime) { this.celRuntime = celRuntime; + this.celOptions = BASE_CEL_OPTIONS; } - private static CelRuntime newRuntime(CelOptions celOptions) { + protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { return CelRuntimeFactory.standardCelRuntimeBuilder() .addLibraries(CelOptionalLibrary.INSTANCE) .addFileTypes(TEST_FILE_DESCRIPTORS) - .setOptions(celOptions) - .build(); + .setOptions(celOptions); } - protected static CelOptions newBaseCelOptions() { - return BASE_CEL_OPTIONS; + @Override + protected CelAbstractSyntaxTree prepareTest(List descriptors) { + return prepareTest( + ProtoMessageTypeProvider.newBuilder() + .addFileDescriptors(descriptors) + .setAllowJsonFieldNames(celOptions.enableJsonFieldNames()) + .build()); } @Override protected void prepareCompiler(CelTypeProvider typeProvider) { super.prepareCompiler(typeProvider); this.celCompiler = - celCompiler.toCompilerBuilder().addLibraries(CelOptionalLibrary.INSTANCE).build(); + celCompiler + .toCompilerBuilder() + .addLibraries(CelOptionalLibrary.INSTANCE) + .setOptions(celOptions) + .build(); } private CelAbstractSyntaxTree compileTestCase() { @@ -2068,7 +2077,8 @@ public void wrappers() throws Exception { @Test public void longComprehension() { ImmutableList l = LongStream.range(0L, 1000L).boxed().collect(toImmutableList()); - addFunctionBinding(CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l)); + addFunctionBinding( + CelFunctionBinding.from("constantLongList", ImmutableList.of(), unused -> l)); // Comprehension over compile-time constant long list. declareFunction( @@ -2482,4 +2492,22 @@ private static Descriptor getDeserializedTestAllTypeDescriptor() { throw new RuntimeException("Error loading TestAllTypes descriptor", e); } } + + @Test + public void jsonFieldNames() throws Exception { + this.celOptions = celOptions.toBuilder().enableJsonFieldNames(true).build(); + this.celRuntime = newBaseRuntimeBuilder(celOptions).build(); + + TestAllTypes message = TestAllTypes.newBuilder().setSingleInt32(42).build(); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + + source = "x.singleInt32 == 42"; + assertThat(runTest(ImmutableMap.of("x", message))).isEqualTo(true); + + source = "TestAllTypes{singleInt32: 42}.singleInt32 == 42"; + container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + assertThat(runTest()).isEqualTo(true); + + skipBaselineVerification(); + } } diff --git a/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java b/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java index 0bee52af7..79ae88f47 100644 --- a/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java +++ b/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java @@ -20,7 +20,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.protobuf.DescriptorProtos.FileDescriptorSet; -import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; @@ -70,15 +69,11 @@ protected CelAbstractSyntaxTree prepareTest(List descriptors) { return prepareTest(new ProtoMessageTypeProvider(ImmutableSet.copyOf(descriptors))); } - protected CelAbstractSyntaxTree prepareTest(Iterable descriptors) { - return prepareTest(new ProtoMessageTypeProvider(descriptors)); - } - protected CelAbstractSyntaxTree prepareTest(FileDescriptorSet descriptorSet) { return prepareTest(new ProtoMessageTypeProvider(descriptorSet)); } - private CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { + protected CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { prepareCompiler(typeProvider); CelAbstractSyntaxTree ast; @@ -186,9 +181,7 @@ protected String formatVarDecl(CelVarDecl decl) { * Declares a function with one or more overloads * * @param functionName Function name - * @param overloads Function overloads in protobuf representation. If {@link #declareWithCelTypes} - * is set, the protobuf overloads are internally converted into java native versions {@link - * CelOverloadDecl}. + * @param overloads Function overloads in protobuf representation. */ protected void declareFunction(String functionName, CelOverloadDecl... overloads) { this.functionDecls.add(newFunctionDeclaration(functionName, overloads)); diff --git a/testing/src/test/resources/protos/single_file.proto b/testing/src/test/resources/protos/single_file.proto index 0fcf270a1..b5ce518e0 100644 --- a/testing/src/test/resources/protos/single_file.proto +++ b/testing/src/test/resources/protos/single_file.proto @@ -26,4 +26,5 @@ message SingleFile { string name = 1; Path path = 2; + string snake_cased = 3 [json_name = "camelCased"]; }