Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 92 additions & 28 deletions runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelList;
import dev.cel.common.ast.CelExpr.CelMap;
import dev.cel.common.ast.CelExpr.CelSelect;
import dev.cel.common.ast.CelExpr.CelStruct;
import dev.cel.common.ast.CelExpr.CelStruct.Entry;
import dev.cel.common.ast.CelReference;
Expand Down Expand Up @@ -58,6 +59,7 @@ public final class ProgramPlanner {
private final CelValueProvider valueProvider;
private final DefaultDispatcher dispatcher;
private final AttributeFactory attributeFactory;
private final CelContainer container;

/**
* Plans a {@link Program} from the provided parsed-only or type-checked {@link
Expand Down Expand Up @@ -168,7 +170,6 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) {
evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx);
}

// TODO: Handle all specialized calls (logical operators, conditionals, equals etc)
String functionName = resolvedFunction.functionName();
Operator operator = Operator.findReverse(functionName).orElse(null);
if (operator != null) {
Expand Down Expand Up @@ -209,16 +210,7 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) {

private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) {
CelStruct struct = celExpr.struct();
CelType structType =
typeProvider
.findType(struct.messageName())
.orElseThrow(
() -> new IllegalArgumentException("Undefined type name: " + struct.messageName()));
if (!structType.kind().equals(CelKind.STRUCT)) {
throw new IllegalArgumentException(
String.format(
"Expected struct type for %s, got %s", structType.name(), structType.kind()));
}
StructType structType = resolveStructType(struct);

ImmutableList<Entry> entries = struct.entries();
String[] keys = new String[entries.size()];
Expand All @@ -230,7 +222,7 @@ private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) {
values[i] = plan(entry.value(), ctx);
}

return EvalCreateStruct.create(valueProvider, (StructType) structType, keys, values);
return EvalCreateStruct.create(valueProvider, structType, keys, values);
}

private Interpretable planCreateList(CelExpr celExpr, PlannerContext ctx) {
Expand Down Expand Up @@ -269,7 +261,7 @@ private Interpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) {
private ResolvedFunction resolveFunction(
CelExpr expr, ImmutableMap<Long, CelReference> referenceMap) {
CelCall call = expr.call();
Optional<CelExpr> target = call.target();
Optional<CelExpr> maybeTarget = call.target();
String functionName = call.function();

CelReference reference = referenceMap.get(expr.id());
Expand All @@ -281,22 +273,89 @@ private ResolvedFunction resolveFunction(
.setFunctionName(functionName)
.setOverloadId(reference.overloadIds().get(0));

target.ifPresent(builder::setTarget);
maybeTarget.ifPresent(builder::setTarget);

return builder.build();
}
}

// Parsed-only.
// TODO: Handle containers.
if (!target.isPresent()) {
// Parsed-only function resolution.
//
// There are two distinct cases we must handle:
//
// 1. Non-qualified function calls. This will resolve into either:
// - A simple global call foo()
// - A fully qualified global call through normal container resolution foo.bar.qux()
// 2. Qualified function calls:
// - A member call on an identifier foo.bar()
// - A fully qualified global call, through normal container resolution or abbreviations
// foo.bar.qux()
if (!maybeTarget.isPresent()) {
for (String cand : container.resolveCandidateNames(functionName)) {
CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null);
if (overload != null) {
return ResolvedFunction.newBuilder().setFunctionName(cand).build();
}
}

// Normal global call
return ResolvedFunction.newBuilder().setFunctionName(functionName).build();
} else {
return ResolvedFunction.newBuilder()
.setFunctionName(functionName)
.setTarget(target.get())
.build();
}

CelExpr target = maybeTarget.get();
String qualifiedPrefix = toQualifiedName(target).orElse(null);
if (qualifiedPrefix != null) {
String qualifiedName = qualifiedPrefix + "." + functionName;
for (String cand : container.resolveCandidateNames(qualifiedName)) {
CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null);
if (overload != null) {
return ResolvedFunction.newBuilder().setFunctionName(cand).build();
}
}
}

// Normal member call
return ResolvedFunction.newBuilder().setFunctionName(functionName).setTarget(target).build();
}

private StructType resolveStructType(CelStruct struct) {
String messageName = struct.messageName();
for (String typeName : container.resolveCandidateNames(messageName)) {
CelType structType = typeProvider.findType(typeName).orElse(null);
if (structType == null) {
continue;
}

if (!structType.kind().equals(CelKind.STRUCT)) {
throw new IllegalArgumentException(
String.format(
"Expected struct type for %s, got %s", structType.name(), structType.kind()));
}

return (StructType) structType;
}

throw new IllegalArgumentException("Undefined type name: " + messageName);
}

/** Converts a given expression into a qualified name, if possible. */
private Optional<String> toQualifiedName(CelExpr operand) {
switch (operand.getKind()) {
case IDENT:
return Optional.of(operand.ident().name());
case SELECT:
CelSelect select = operand.select();
String maybeQualified = toQualifiedName(select.operand()).orElse(null);
if (maybeQualified != null) {
return Optional.of(maybeQualified + "." + select.field());
}

break;
default:
// fall-through
}

return Optional.empty();
}

@AutoValue
Expand Down Expand Up @@ -338,17 +397,22 @@ private static PlannerContext create(CelAbstractSyntaxTree ast) {
}

public static ProgramPlanner newPlanner(
CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) {
return new ProgramPlanner(typeProvider, valueProvider, dispatcher);
CelTypeProvider typeProvider,
CelValueProvider valueProvider,
DefaultDispatcher dispatcher,
CelContainer container) {
return new ProgramPlanner(typeProvider, valueProvider, dispatcher, container);
}

private ProgramPlanner(
CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) {
CelTypeProvider typeProvider,
CelValueProvider valueProvider,
DefaultDispatcher dispatcher,
CelContainer container) {
this.typeProvider = typeProvider;
this.valueProvider = valueProvider;
// TODO: Container support
this.dispatcher = dispatcher;
this.attributeFactory =
AttributeFactory.newAttributeFactory(CelContainer.newBuilder().build(), typeProvider);
this.container = container;
this.attributeFactory = AttributeFactory.newAttributeFactory(container, typeProvider);
}
}
1 change: 1 addition & 0 deletions runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ java_library(
"//common:cel_descriptor_util",
"//common:cel_source",
"//common:compiler_common",
"//common:container",
"//common:error_codes",
"//common:operator",
"//common:options",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelContainer;
import dev.cel.common.CelDescriptorUtil;
import dev.cel.common.CelErrorCode;
import dev.cel.common.CelOptions;
Expand Down Expand Up @@ -99,9 +100,11 @@ public final class ProgramPlannerTest {
DynamicProto.create(DefaultMessageFactory.create(DESCRIPTOR_POOL));
private static final CelValueProvider VALUE_PROVIDER =
ProtoMessageValueProvider.newInstance(CelOptions.DEFAULT, DYNAMIC_PROTO);
private static final CelContainer CEL_CONTAINER =
CelContainer.newBuilder().setName("cel.expr.conformance.proto3").build();

private static final ProgramPlanner PLANNER =
ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher());
ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher(), CEL_CONTAINER);
private static final CelCompiler CEL_COMPILER =
CelCompilerFactory.standardCelCompilerBuilder()
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.DYN))
Expand All @@ -114,6 +117,10 @@ public final class ProgramPlannerTest {
"neg",
newGlobalOverload("neg_int", SimpleType.INT, SimpleType.INT),
newGlobalOverload("neg_double", SimpleType.DOUBLE, SimpleType.DOUBLE)),
newFunctionDeclaration(
"cel.expr.conformance.proto3.power",
newGlobalOverload(
"power_int_int", SimpleType.INT, SimpleType.INT, SimpleType.INT)),
newFunctionDeclaration(
"concat",
newGlobalOverload(
Expand All @@ -122,6 +129,7 @@ public final class ProgramPlannerTest {
"bytes_concat_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES)))
.addMessageTypes(TestAllTypes.getDescriptor())
.addLibraries(CelExtensions.optional())
.setContainer(CEL_CONTAINER)
.build();

/**
Expand Down Expand Up @@ -174,6 +182,14 @@ private static DefaultDispatcher newDispatcher() {
"neg",
CelFunctionBinding.from("neg_int", Long.class, arg -> -arg),
CelFunctionBinding.from("neg_double", Double.class, arg -> -arg));
addBindings(
builder,
"cel.expr.conformance.proto3.power",
CelFunctionBinding.from(
"power_int_int",
Long.class,
Long.class,
(value, power) -> (long) Math.pow(value, power)));
addBindings(
builder,
"concat",
Expand Down Expand Up @@ -379,6 +395,16 @@ public void planCreateStruct_withFields() throws Exception {
.isEqualTo(TestAllTypes.newBuilder().setSingleString("foo").setSingleBool(true).build());
}

@Test
public void plan_createStruct_withContainer() throws Exception {
CelAbstractSyntaxTree ast = compile("TestAllTypes{}");
Program program = PLANNER.plan(ast);

TestAllTypes result = (TestAllTypes) program.eval();

assertThat(result).isEqualTo(TestAllTypes.getDefaultInstance());
}

@Test
public void plan_call_zeroArgs() throws Exception {
CelAbstractSyntaxTree ast = compile("zero()");
Expand Down Expand Up @@ -555,6 +581,21 @@ public void plan_call_conditional_throws(String expression) throws Exception {
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.DIVIDE_BY_ZERO);
}

@Test
@TestParameters("{expression: 'power(2,3)'}")
@TestParameters("{expression: 'proto3.power(2,3)'}")
@TestParameters("{expression: 'conformance.proto3.power(2,3)'}")
@TestParameters("{expression: 'expr.conformance.proto3.power(2,3)'}")
@TestParameters("{expression: 'cel.expr.conformance.proto3.power(2,3)'}")
public void plan_call_withContainer(String expression) throws Exception {
CelAbstractSyntaxTree ast = compile(expression); // invokes cel.expr.conformance.proto3.power
Program program = PLANNER.plan(ast);

Long result = (Long) program.eval();

assertThat(result).isEqualTo(8);
}

private CelAbstractSyntaxTree compile(String expression) throws Exception {
CelAbstractSyntaxTree ast = CEL_COMPILER.parse(expression).getAst();
if (isParseOnly) {
Expand Down
Loading