Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions bundle/src/test/java/dev/cel/bundle/CelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static dev.cel.common.CelOverloadDecl.newMemberOverload;
import static org.junit.Assert.assertThrows;

import com.google.common.primitives.UnsignedLong;
import com.google.protobuf.UInt64Value;
import dev.cel.expr.CheckedExpr;
import dev.cel.expr.Constant;
import dev.cel.expr.Decl;
Expand Down Expand Up @@ -2193,6 +2195,74 @@ public void toBuilder_isImmutable() {
assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder());
}

@Test
public void mapSelection_uintWrapper() throws Exception {
Cel cel = CelFactory.standardCelBuilder()
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.addMessageTypes(TestAllTypes.getDescriptor())
.build();

CelAbstractSyntaxTree ast = cel.compile("args.i[1]").getAst();

Object result = cel.createProgram(ast).eval(
ImmutableMap.of("args",
ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L)))));

assertThat(result).isEqualTo(UnsignedLong.valueOf(123L));
}

@Test
public void messageCreation_listContainsUintWrapperCreation() throws Exception {
Cel cel = CelFactory.standardCelBuilder()
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.addMessageTypes(TestAllTypes.getDescriptor())
.build();

CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: [google.protobuf.UInt64Value{value: 123u}]}").getAst();

Object result = cel.createProgram(ast).eval(
ImmutableMap.of("args",
ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L)))));

assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build());
}

@Test
public void messageCreation_listContainsUintWrapper() throws Exception {
Cel cel = CelFactory.standardCelBuilder()
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.addMessageTypes(TestAllTypes.getDescriptor())
.build();

CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{repeated_uint64: args.i}").getAst();

Object result = cel.createProgram(ast).eval(
ImmutableMap.of("args",
ImmutableMap.of("i", ImmutableList.of(UInt64Value.of(123L)))));

assertThat(result).isEqualTo(TestAllTypes.newBuilder().addRepeatedUint64(123L).build());
}

@Test
public void messageCreation_mapContainsUintWrapper() throws Exception {
Cel cel = CelFactory.standardCelBuilder()
.addVar("args", MapType.create(SimpleType.DYN, SimpleType.DYN))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto3"))
.addMessageTypes(TestAllTypes.getDescriptor())
.build();

CelAbstractSyntaxTree ast = cel.compile("TestAllTypes{map_int64_uint64 : args.i}").getAst();

Object result = cel.createProgram(ast).eval(
ImmutableMap.of("args",
ImmutableMap.of("i", ImmutableMap.of(1L, UInt64Value.of(123L)))));

assertThat(result).isEqualTo(TestAllTypes.newBuilder().putMapInt64Uint64(1L, 123L).build());
}

private static TypeProvider aliasingProvider(ImmutableMap<String, Type> typeAliases) {
return new TypeProvider() {
@Override
Expand Down
51 changes: 43 additions & 8 deletions common/src/main/java/dev/cel/common/internal/ProtoAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import com.google.protobuf.MessageOrBuilder;
import dev.cel.common.CelOptions;
import dev.cel.common.annotations.Internal;
Expand Down Expand Up @@ -244,28 +245,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) {
case SFIXED32:
case SINT32:
case INT32:
return INT_CONVERTER;
return unwrapAndConvert(INT_CONVERTER);
case FIXED32:
case UINT32:
if (celOptions.enableUnsignedLongs()) {
return UNSIGNED_UINT32_CONVERTER;
return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER);
}
return SIGNED_UINT32_CONVERTER;
return unwrapAndConvert(SIGNED_UINT32_CONVERTER);
case FIXED64:
case UINT64:
if (celOptions.enableUnsignedLongs()) {
return UNSIGNED_UINT64_CONVERTER;
return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER);
}
return BidiConverter.IDENTITY;
return BidiConverter.of(
BidiConverter.IDENTITY.forwardConverter(),
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
case FLOAT:
return DOUBLE_CONVERTER;
return unwrapAndConvert(DOUBLE_CONVERTER);
case DOUBLE:
case SFIXED64:
case SINT64:
case INT64:
return BidiConverter.of(
BidiConverter.IDENTITY.forwardConverter(),
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
case BYTES:
if (celOptions.evaluateCanonicalTypesToNativeValues()) {
return BidiConverter.<Object, Object>of(
ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto);
ProtoAdapter::adaptProtoByteStringToValue,
value -> adaptCelByteStringToProto(unwrap(value)));
}

return BidiConverter.IDENTITY;
return BidiConverter.of(
BidiConverter.IDENTITY.forwardConverter(),
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
case STRING:
return BidiConverter.of(
BidiConverter.IDENTITY.forwardConverter(),
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
case BOOL:
return BidiConverter.of(
BidiConverter.IDENTITY.forwardConverter(),
value -> BidiConverter.IDENTITY.backwardConverter().convert(unwrap(value)));
case ENUM:
return BidiConverter.<Object, Long>of(
value -> (long) ((EnumValueDescriptor) value).getNumber(),
Expand Down Expand Up @@ -371,4 +392,18 @@ private static int unsignedIntCheckedCast(long value) {
throw new CelNumericOverflowException(e);
}
}

private Object unwrap(Object value) {
if (value instanceof MessageLite) {
return adaptProtoToValue((MessageOrBuilder) value);
}
return value;
}

private BidiConverter<Number, Object> unwrapAndConvert(
final BidiConverter<Number, Number> original) {
return BidiConverter.of(
original.forwardConverter()::convert,
value -> original.backwardConverter().convert((Number) unwrap(value)));
}
}
41 changes: 41 additions & 0 deletions runtime/src/test/resources/wrappers.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,47 @@ declare dyn_var {
bindings: {dyn_var=NULL_VALUE}
result: NULL_VALUE

Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']
declare int32_list {
value list(int)
}
declare int64_list {
value list(int)
}
declare uint32_list {
value list(uint)
}
declare uint64_list {
value list(uint)
}
declare float_list {
value list(double)
}
declare double_list {
value list(double)
}
declare bool_list {
value list(bool)
}
declare string_list {
value list(string)
}
declare bytes_list {
value list(bytes)
}
=====>
bindings: {int32_list=[value: 1
], int64_list=[value: 2
], uint32_list=[value: 3
], uint64_list=[value: 4
], float_list=[value: 5.5
], double_list=[value: 6.6
], bool_list=[value: true
], string_list=[value: "hello"
], bytes_list=[value: "world"
]}
result: true

Source: google.protobuf.Timestamp{ seconds: 253402300800 }
=====>
bindings: {}
Expand Down
36 changes: 36 additions & 0 deletions testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,42 @@ public void wrappers() throws Exception {
source = "dyn_var";
runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE));

clearAllDeclarations();
declareVariable("int32_list", ListType.create(SimpleType.INT));
declareVariable("int64_list", ListType.create(SimpleType.INT));
declareVariable("uint32_list", ListType.create(SimpleType.UINT));
declareVariable("uint64_list", ListType.create(SimpleType.UINT));
declareVariable("float_list", ListType.create(SimpleType.DOUBLE));
declareVariable("double_list", ListType.create(SimpleType.DOUBLE));
declareVariable("bool_list", ListType.create(SimpleType.BOOL));
declareVariable("string_list", ListType.create(SimpleType.STRING));
declareVariable("bytes_list", ListType.create(SimpleType.BYTES));

container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName());
source =
"TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && "
+ "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && "
+ "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && "
+ "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && "
+ "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && "
+ "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && "
+ "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && "
+ "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && "
+ "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']";

runTest(
ImmutableMap.<String, Object>builder()
.put("int32_list", ImmutableList.of(Int32Value.of(1)))
.put("int64_list", ImmutableList.of(Int64Value.of(2)))
.put("uint32_list", ImmutableList.of(UInt32Value.of(3)))
.put("uint64_list", ImmutableList.of(UInt64Value.of(4)))
.put("float_list", ImmutableList.of(FloatValue.of(5.5f)))
.put("double_list", ImmutableList.of(DoubleValue.of(6.6)))
.put("bool_list", ImmutableList.of(BoolValue.of(true)))
.put("string_list", ImmutableList.of(StringValue.of("hello")))
.put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world"))))
.build());

clearAllDeclarations();
// Currently allowed, but will be an error
// See https://github.com/google/cel-spec/pull/501
Expand Down
Loading