From 741ad142f99700578be59479c27998fab12d9fb7 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 2 Dec 2025 13:43:54 -0800 Subject: [PATCH] Support container resolution for calls and struct creation in planner PiperOrigin-RevId: 839411411 --- .../cel/runtime/planner/ProgramPlanner.java | 120 ++++++++++++++---- .../java/dev/cel/runtime/planner/BUILD.bazel | 1 + .../runtime/planner/ProgramPlannerTest.java | 43 ++++++- 3 files changed, 135 insertions(+), 29 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index e3f00a405..bf7729c0f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -28,6 +28,7 @@ import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelList; import dev.cel.common.ast.CelExpr.CelMap; +import dev.cel.common.ast.CelExpr.CelSelect; import dev.cel.common.ast.CelExpr.CelStruct; import dev.cel.common.ast.CelExpr.CelStruct.Entry; import dev.cel.common.ast.CelReference; @@ -58,6 +59,7 @@ public final class ProgramPlanner { private final CelValueProvider valueProvider; private final DefaultDispatcher dispatcher; private final AttributeFactory attributeFactory; + private final CelContainer container; /** * Plans a {@link Program} from the provided parsed-only or type-checked {@link @@ -168,7 +170,6 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) { evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx); } - // TODO: Handle all specialized calls (logical operators, conditionals, equals etc) String functionName = resolvedFunction.functionName(); Operator operator = Operator.findReverse(functionName).orElse(null); if (operator != null) { @@ -209,16 +210,7 @@ private Interpretable planCall(CelExpr expr, PlannerContext ctx) { private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { CelStruct struct = celExpr.struct(); - CelType structType = - typeProvider - .findType(struct.messageName()) - .orElseThrow( - () -> new IllegalArgumentException("Undefined type name: " + struct.messageName())); - if (!structType.kind().equals(CelKind.STRUCT)) { - throw new IllegalArgumentException( - String.format( - "Expected struct type for %s, got %s", structType.name(), structType.kind())); - } + StructType structType = resolveStructType(struct); ImmutableList entries = struct.entries(); String[] keys = new String[entries.size()]; @@ -230,7 +222,7 @@ private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { values[i] = plan(entry.value(), ctx); } - return EvalCreateStruct.create(valueProvider, (StructType) structType, keys, values); + return EvalCreateStruct.create(valueProvider, structType, keys, values); } private Interpretable planCreateList(CelExpr celExpr, PlannerContext ctx) { @@ -269,7 +261,7 @@ private Interpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) { private ResolvedFunction resolveFunction( CelExpr expr, ImmutableMap referenceMap) { CelCall call = expr.call(); - Optional target = call.target(); + Optional maybeTarget = call.target(); String functionName = call.function(); CelReference reference = referenceMap.get(expr.id()); @@ -281,22 +273,89 @@ private ResolvedFunction resolveFunction( .setFunctionName(functionName) .setOverloadId(reference.overloadIds().get(0)); - target.ifPresent(builder::setTarget); + maybeTarget.ifPresent(builder::setTarget); return builder.build(); } } - // Parsed-only. - // TODO: Handle containers. - if (!target.isPresent()) { + // Parsed-only function resolution. + // + // There are two distinct cases we must handle: + // + // 1. Non-qualified function calls. This will resolve into either: + // - A simple global call foo() + // - A fully qualified global call through normal container resolution foo.bar.qux() + // 2. Qualified function calls: + // - A member call on an identifier foo.bar() + // - A fully qualified global call, through normal container resolution or abbreviations + // foo.bar.qux() + if (!maybeTarget.isPresent()) { + for (String cand : container.resolveCandidateNames(functionName)) { + CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null); + if (overload != null) { + return ResolvedFunction.newBuilder().setFunctionName(cand).build(); + } + } + + // Normal global call return ResolvedFunction.newBuilder().setFunctionName(functionName).build(); - } else { - return ResolvedFunction.newBuilder() - .setFunctionName(functionName) - .setTarget(target.get()) - .build(); } + + CelExpr target = maybeTarget.get(); + String qualifiedPrefix = toQualifiedName(target).orElse(null); + if (qualifiedPrefix != null) { + String qualifiedName = qualifiedPrefix + "." + functionName; + for (String cand : container.resolveCandidateNames(qualifiedName)) { + CelResolvedOverload overload = dispatcher.findOverload(cand).orElse(null); + if (overload != null) { + return ResolvedFunction.newBuilder().setFunctionName(cand).build(); + } + } + } + + // Normal member call + return ResolvedFunction.newBuilder().setFunctionName(functionName).setTarget(target).build(); + } + + private StructType resolveStructType(CelStruct struct) { + String messageName = struct.messageName(); + for (String typeName : container.resolveCandidateNames(messageName)) { + CelType structType = typeProvider.findType(typeName).orElse(null); + if (structType == null) { + continue; + } + + if (!structType.kind().equals(CelKind.STRUCT)) { + throw new IllegalArgumentException( + String.format( + "Expected struct type for %s, got %s", structType.name(), structType.kind())); + } + + return (StructType) structType; + } + + throw new IllegalArgumentException("Undefined type name: " + messageName); + } + + /** Converts a given expression into a qualified name, if possible. */ + private Optional toQualifiedName(CelExpr operand) { + switch (operand.getKind()) { + case IDENT: + return Optional.of(operand.ident().name()); + case SELECT: + CelSelect select = operand.select(); + String maybeQualified = toQualifiedName(select.operand()).orElse(null); + if (maybeQualified != null) { + return Optional.of(maybeQualified + "." + select.field()); + } + + break; + default: + // fall-through + } + + return Optional.empty(); } @AutoValue @@ -338,17 +397,22 @@ private static PlannerContext create(CelAbstractSyntaxTree ast) { } public static ProgramPlanner newPlanner( - CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) { - return new ProgramPlanner(typeProvider, valueProvider, dispatcher); + CelTypeProvider typeProvider, + CelValueProvider valueProvider, + DefaultDispatcher dispatcher, + CelContainer container) { + return new ProgramPlanner(typeProvider, valueProvider, dispatcher, container); } private ProgramPlanner( - CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) { + CelTypeProvider typeProvider, + CelValueProvider valueProvider, + DefaultDispatcher dispatcher, + CelContainer container) { this.typeProvider = typeProvider; this.valueProvider = valueProvider; - // TODO: Container support this.dispatcher = dispatcher; - this.attributeFactory = - AttributeFactory.newAttributeFactory(CelContainer.newBuilder().build(), typeProvider); + this.container = container; + this.attributeFactory = AttributeFactory.newAttributeFactory(container, typeProvider); } } diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel index 8c8c66369..9e3a79ed9 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel @@ -18,6 +18,7 @@ java_library( "//common:cel_descriptor_util", "//common:cel_source", "//common:compiler_common", + "//common:container", "//common:error_codes", "//common:operator", "//common:options", diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 5001e3559..205e2ef8b 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -31,6 +31,7 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelContainer; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelErrorCode; import dev.cel.common.CelOptions; @@ -99,9 +100,11 @@ public final class ProgramPlannerTest { DynamicProto.create(DefaultMessageFactory.create(DESCRIPTOR_POOL)); private static final CelValueProvider VALUE_PROVIDER = ProtoMessageValueProvider.newInstance(CelOptions.DEFAULT, DYNAMIC_PROTO); + private static final CelContainer CEL_CONTAINER = + CelContainer.newBuilder().setName("cel.expr.conformance.proto3").build(); private static final ProgramPlanner PLANNER = - ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher()); + ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher(), CEL_CONTAINER); private static final CelCompiler CEL_COMPILER = CelCompilerFactory.standardCelCompilerBuilder() .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.DYN)) @@ -114,6 +117,10 @@ public final class ProgramPlannerTest { "neg", newGlobalOverload("neg_int", SimpleType.INT, SimpleType.INT), newGlobalOverload("neg_double", SimpleType.DOUBLE, SimpleType.DOUBLE)), + newFunctionDeclaration( + "cel.expr.conformance.proto3.power", + newGlobalOverload( + "power_int_int", SimpleType.INT, SimpleType.INT, SimpleType.INT)), newFunctionDeclaration( "concat", newGlobalOverload( @@ -122,6 +129,7 @@ public final class ProgramPlannerTest { "bytes_concat_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES))) .addMessageTypes(TestAllTypes.getDescriptor()) .addLibraries(CelExtensions.optional()) + .setContainer(CEL_CONTAINER) .build(); /** @@ -174,6 +182,14 @@ private static DefaultDispatcher newDispatcher() { "neg", CelFunctionBinding.from("neg_int", Long.class, arg -> -arg), CelFunctionBinding.from("neg_double", Double.class, arg -> -arg)); + addBindings( + builder, + "cel.expr.conformance.proto3.power", + CelFunctionBinding.from( + "power_int_int", + Long.class, + Long.class, + (value, power) -> (long) Math.pow(value, power))); addBindings( builder, "concat", @@ -379,6 +395,16 @@ public void planCreateStruct_withFields() throws Exception { .isEqualTo(TestAllTypes.newBuilder().setSingleString("foo").setSingleBool(true).build()); } + @Test + public void plan_createStruct_withContainer() throws Exception { + CelAbstractSyntaxTree ast = compile("TestAllTypes{}"); + Program program = PLANNER.plan(ast); + + TestAllTypes result = (TestAllTypes) program.eval(); + + assertThat(result).isEqualTo(TestAllTypes.getDefaultInstance()); + } + @Test public void plan_call_zeroArgs() throws Exception { CelAbstractSyntaxTree ast = compile("zero()"); @@ -555,6 +581,21 @@ public void plan_call_conditional_throws(String expression) throws Exception { assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.DIVIDE_BY_ZERO); } + @Test + @TestParameters("{expression: 'power(2,3)'}") + @TestParameters("{expression: 'proto3.power(2,3)'}") + @TestParameters("{expression: 'conformance.proto3.power(2,3)'}") + @TestParameters("{expression: 'expr.conformance.proto3.power(2,3)'}") + @TestParameters("{expression: 'cel.expr.conformance.proto3.power(2,3)'}") + public void plan_call_withContainer(String expression) throws Exception { + CelAbstractSyntaxTree ast = compile(expression); // invokes cel.expr.conformance.proto3.power + Program program = PLANNER.plan(ast); + + Long result = (Long) program.eval(); + + assertThat(result).isEqualTo(8); + } + private CelAbstractSyntaxTree compile(String expression) throws Exception { CelAbstractSyntaxTree ast = CEL_COMPILER.parse(expression).getAst(); if (isParseOnly) {