From 328fee53cd545f5c330030c423270c7b74fb5565 Mon Sep 17 00:00:00 2001 From: Sean Huh Date: Thu, 29 Jan 2026 12:02:17 -0800 Subject: [PATCH] Fix wrapper adaptations --- .../test/java/dev/cel/bundle/CelImplTest.java | 70 +++++++++++++++++++ .../dev/cel/common/internal/ProtoAdapter.java | 51 +++++++++++--- runtime/src/test/resources/wrappers.baseline | 41 +++++++++++ .../dev/cel/testing/BaseInterpreterTest.java | 36 ++++++++++ 4 files changed, 190 insertions(+), 8 deletions(-) diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 13c3392d3..df5143e37 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -21,6 +21,8 @@ import static dev.cel.common.CelOverloadDecl.newMemberOverload; import static org.junit.Assert.assertThrows; +import com.google.common.primitives.UnsignedLong; +import com.google.protobuf.UInt64Value; import dev.cel.expr.CheckedExpr; import dev.cel.expr.Constant; import dev.cel.expr.Decl; @@ -2193,6 +2195,74 @@ public void toBuilder_isImmutable() { assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder()); } + @Test + public void mapSelection_uintWrapper() throws Exception { + Cel cel = CelFactory.standardCelBuilder() + .addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN)) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + CelAbstractSyntaxTree ast = cel.compile("args.i[1]").getAst(); + + Object result = cel.createProgram(ast).eval( + ImmutableMap.of("args", + ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L))))); + + assertThat(result).isEqualTo(UnsignedLong.valueOf(123L)); + } + + @Test + public void messageCreation_listContainsUintWrapperCreation() throws Exception { + Cel cel = CelFactory.standardCelBuilder() + .addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN)) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: [google.protobuf.UInt64Value{value: 123u}]}").getAst(); + + Object result = cel.createProgram(ast).eval( + ImmutableMap.of("args", + ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L))))); + + assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build()); + } + + @Test + public void messageCreation_listContainsUintWrapper() throws Exception { + Cel cel = CelFactory.standardCelBuilder() + .addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN)) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: args.i}").getAst(); + + Object result = cel.createProgram(ast).eval( + ImmutableMap.of("args", + ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L))))); + + assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build()); + } + + @Test + public void messageCreation_mapContainsUintWrapper() throws Exception { + Cel cel = CelFactory.standardCelBuilder() + .addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN)) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{map_int64_uint64 : args.i}").getAst(); + + Object result = cel.createProgram(ast).eval( + ImmutableMap.of("args", + ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L))))); + + assertThat(result).isEqualTo(TestAllTypes.newBuilder().putMapInt64Uint64(1L, 123L).build()); + } + private static TypeProvider aliasingProvider(ImmutableMap typeAliases) { return new TypeProvider() { @Override diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index 3c3382ef2..2bab64804 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -30,6 +30,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MapEntry; import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; import com.google.protobuf.MessageOrBuilder; import dev.cel.common.CelOptions; import dev.cel.common.annotations.Internal; @@ -244,28 +245,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { case SFIXED32: case SINT32: case INT32: - return INT_CONVERTER; + return unwrapAndConvert(INT_CONVERTER); case FIXED32: case UINT32: if (celOptions.enableUnsignedLongs()) { - return UNSIGNED_UINT32_CONVERTER; + return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER); } - return SIGNED_UINT32_CONVERTER; + return unwrapAndConvert(SIGNED_UINT32_CONVERTER); case FIXED64: case UINT64: if (celOptions.enableUnsignedLongs()) { - return UNSIGNED_UINT64_CONVERTER; + return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER); } - return BidiConverter.IDENTITY; + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value))); case FLOAT: - return DOUBLE_CONVERTER; + return unwrapAndConvert(DOUBLE_CONVERTER); + case DOUBLE: + case SFIXED64: + case SINT64: + case INT64: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value))); case BYTES: if (celOptions.evaluateCanonicalTypesToNativeValues()) { return BidiConverter.of( - ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto); + ProtoAdapter::adaptProtoByteStringToValue, + value -> adaptCelByteStringToProto(unwrap(value))); } - return BidiConverter.IDENTITY; + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value))); + case STRING: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value))); + case BOOL: + return BidiConverter.of( + BidiConverter.IDENTITY.forwardConverter(), + value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value))); case ENUM: return BidiConverter.of( value -> (long) ((EnumValueDescriptor) value).getNumber(), @@ -371,4 +392,18 @@ private static int unsignedIntCheckedCast(long value) { throw new CelNumericOverflowException(e); } } + + private Object unwrap(Object value) { + if (value instanceof MessageLite) { + return adaptProtoToValue((MessageOrBuilder) value); + } + return value; + } + + private BidiConverter unwrapAndConvert( + final BidiConverter original) { + return BidiConverter.of( + original.forwardConverter()::convert, + value -> original.backwardConverter().convert((Number) unwrap(value))); + } } diff --git a/runtime/src/test/resources/wrappers.baseline b/runtime/src/test/resources/wrappers.baseline index a971dcb29..d8212059b 100644 --- a/runtime/src/test/resources/wrappers.baseline +++ b/runtime/src/test/resources/wrappers.baseline @@ -154,6 +154,47 @@ declare dyn_var { bindings: {dyn_var=NULL_VALUE} result: NULL_VALUE +Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world'] +declare int32_list { + value list(int) +} +declare int64_list { + value list(int) +} +declare uint32_list { + value list(uint) +} +declare uint64_list { + value list(uint) +} +declare float_list { + value list(double) +} +declare double_list { + value list(double) +} +declare bool_list { + value list(bool) +} +declare string_list { + value list(string) +} +declare bytes_list { + value list(bytes) +} +=====> +bindings: {int32_list=[value: 1 +], int64_list=[value: 2 +], uint32_list=[value: 3 +], uint64_list=[value: 4 +], float_list=[value: 5.5 +], double_list=[value: 6.6 +], bool_list=[value: true +], string_list=[value: "hello" +], bytes_list=[value: "world" +]} +result: true + Source: google.protobuf.Timestamp{ seconds: 253402300800 } =====> bindings: {} diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 425271e1b..46887af11 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -2058,6 +2058,42 @@ public void wrappers() throws Exception { source = "dyn_var"; runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE)); + clearAllDeclarations(); + declareVariable("int32_list", ListType.create(SimpleType.INT)); + declareVariable("int64_list", ListType.create(SimpleType.INT)); + declareVariable("uint32_list", ListType.create(SimpleType.UINT)); + declareVariable("uint64_list", ListType.create(SimpleType.UINT)); + declareVariable("float_list", ListType.create(SimpleType.DOUBLE)); + declareVariable("double_list", ListType.create(SimpleType.DOUBLE)); + declareVariable("bool_list", ListType.create(SimpleType.BOOL)); + declareVariable("string_list", ListType.create(SimpleType.STRING)); + declareVariable("bytes_list", ListType.create(SimpleType.BYTES)); + + container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName()); + source = + "TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && " + + "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && " + + "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && " + + "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && " + + "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && " + + "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && " + + "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && " + + "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && " + + "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']"; + + runTest( + ImmutableMap.builder() + .put("int32_list", ImmutableList.of(Int32Value.of(1))) + .put("int64_list", ImmutableList.of(Int64Value.of(2))) + .put("uint32_list", ImmutableList.of(UInt32Value.of(3))) + .put("uint64_list", ImmutableList.of(UInt64Value.of(4))) + .put("float_list", ImmutableList.of(FloatValue.of(5.5f))) + .put("double_list", ImmutableList.of(DoubleValue.of(6.6))) + .put("bool_list", ImmutableList.of(BoolValue.of(true))) + .put("string_list", ImmutableList.of(StringValue.of("hello"))) + .put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world")))) + .build()); + clearAllDeclarations(); // Currently allowed, but will be an error // See https://github.com/google/cel-spec/pull/501