diff --git a/checker/src/main/java/dev/cel/checker/BUILD.bazel b/checker/src/main/java/dev/cel/checker/BUILD.bazel index 6c486bd92..8c6a8b89d 100644 --- a/checker/src/main/java/dev/cel/checker/BUILD.bazel +++ b/checker/src/main/java/dev/cel/checker/BUILD.bazel @@ -179,12 +179,14 @@ java_library( "//common:cel_ast", "//common:compiler_common", "//common:container", + "//common:mutable_ast", "//common:operator", "//common:options", "//common:proto_ast", "//common/annotations", "//common/ast", "//common/ast:expr_converter", + "//common/ast:mutable_expr", "//common/internal:errors", "//common/internal:file_descriptor_converter", "//common/types", diff --git a/checker/src/main/java/dev/cel/checker/Env.java b/checker/src/main/java/dev/cel/checker/Env.java index 7029781e5..a7cc467b3 100644 --- a/checker/src/main/java/dev/cel/checker/Env.java +++ b/checker/src/main/java/dev/cel/checker/Env.java @@ -38,6 +38,7 @@ import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExprConverter; +import dev.cel.common.ast.CelMutableExpr; import dev.cel.common.ast.CelReference; import dev.cel.common.internal.Errors; import dev.cel.common.types.CelKind; @@ -304,6 +305,14 @@ public CelType getType(CelExpr expr) { return Preconditions.checkNotNull(typeMap.get(expr.id()), "expression has no type"); } + /** + * Returns the type associated with a mutable expression by expression id. It's an error to call this + * method if the type is not present. + */ + CelType getType(CelMutableExpr expr) { + return Preconditions.checkNotNull(typeMap.get(expr.id()), "expression has no type"); + } + /** * Sets the type associated with an expression by id. It's an error if the type is already set and * is different than the provided one. Returns the expression parameter. @@ -319,6 +328,21 @@ public CelExpr setType(CelExpr expr, CelType type) { return expr; } + /** + * Sets the type associated with a mutable expression by id. It's an error if the type is already set and + * is different than the provided one. Returns the expression parameter. + */ + @CanIgnoreReturnValue + CelMutableExpr setType(CelMutableExpr expr, CelType type) { + CelType oldType = typeMap.put(expr.id(), type); + Preconditions.checkState( + oldType == null || oldType.equals(type), + "expression already has a type which is incompatible.\n old: %s\n new: %s", + oldType, + type); + return expr; + } + /** * Sets the reference associated with an expression. It's an error if the reference is already set * and is different. @@ -330,6 +354,17 @@ public void setRef(CelExpr expr, CelReference reference) { "expression already has a reference which is incompatible"); } + /** + * Sets the reference associated with a mutable expression. It's an error if the reference is already set + * and is different. + */ + void setRef(CelMutableExpr expr, CelReference reference) { + CelReference oldReference = referenceMap.put(expr.id(), reference); + Preconditions.checkState( + oldReference == null || oldReference.equals(reference), + "expression already has a reference which is incompatible"); + } + /** * Adds a declaration to the environment, based on the Decl proto. Will report errors if the * declaration overlaps with an existing one, or clashes with a macro. diff --git a/checker/src/main/java/dev/cel/checker/ExprChecker.java b/checker/src/main/java/dev/cel/checker/ExprChecker.java index 37b692ecf..8cf45e2ae 100644 --- a/checker/src/main/java/dev/cel/checker/ExprChecker.java +++ b/checker/src/main/java/dev/cel/checker/ExprChecker.java @@ -29,12 +29,28 @@ import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelMutableAst; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.Operator; import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelMutableExpr; +import dev.cel.common.ast.CelMutableExpr.CelMutableCall; +import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension; +import dev.cel.common.ast.CelMutableExpr.CelMutableIdent; +import dev.cel.common.ast.CelMutableExpr.CelMutableList; +import dev.cel.common.ast.CelMutableExpr.CelMutableMap; +import dev.cel.common.ast.CelMutableExpr.CelMutableSelect; +import dev.cel.common.ast.CelMutableExpr.CelMutableStruct; +import dev.cel.common.ast.CelMutableExpr.CelMutableCall; +import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension; +import dev.cel.common.ast.CelMutableExpr.CelMutableIdent; +import dev.cel.common.ast.CelMutableExpr.CelMutableList; +import dev.cel.common.ast.CelMutableExpr.CelMutableMap; +import dev.cel.common.ast.CelMutableExpr.CelMutableSelect; +import dev.cel.common.ast.CelMutableExpr.CelMutableStruct; import dev.cel.common.ast.CelReference; import dev.cel.common.types.CelKind; import dev.cel.common.types.CelProtoTypes; @@ -130,16 +146,20 @@ public static CelAbstractSyntaxTree typecheck( env.enableCompileTimeOverloadResolution(), env.enableHomogeneousLiterals(), env.enableNamespacedDeclarations()); - CelExpr expr = checker.visit(ast.getExpr()); + + CelMutableAst mutableAst = CelMutableAst.fromCelAst(ast); + checker.visit(mutableAst.expr()); if (expectedResultType.isPresent()) { - checker.assertType(expr, expectedResultType.get()); + checker.assertType(mutableAst.expr(), expectedResultType.get()); } // Walk over the final type map substituting any type parameters either by their bound value or // by DYN. Map typeMap = Maps.transformValues(env.getTypeMap(), checker.inferenceContext::finalize); - return CelAbstractSyntaxTree.newCheckedAst(expr, ast.getSource(), env.getRefMap(), typeMap); + CelAbstractSyntaxTree parsedAst = mutableAst.toParsedAst(); + return CelAbstractSyntaxTree.newCheckedAst( + parsedAst.getExpr(), parsedAst.getSource(), env.getRefMap(), typeMap); } private final Env env; @@ -170,32 +190,38 @@ private ExprChecker( } /** Visit the {@code expr} value, routing to overloads based on the kind of expression. */ - @CheckReturnValue - public CelExpr visit(CelExpr expr) { - switch (expr.exprKind().getKind()) { + public void visit(CelMutableExpr expr) { + switch (expr.getKind()) { case CONSTANT: - return visit(expr, expr.constant()); + visit(expr, expr.constant()); + break; case IDENT: - return visit(expr, expr.ident()); + visit(expr, expr.ident()); + break; case SELECT: - return visit(expr, expr.select()); + visit(expr, expr.select()); + break; case CALL: - return visit(expr, expr.call()); + visit(expr, expr.call()); + break; case LIST: - return visit(expr, expr.list()); + visit(expr, expr.list()); + break; case STRUCT: - return visit(expr, expr.struct()); + visit(expr, expr.struct()); + break; case MAP: - return visit(expr, expr.map()); + visit(expr, expr.map()); + break; case COMPREHENSION: - return visit(expr, expr.comprehension()); + visit(expr, expr.comprehension()); + break; default: throw new IllegalArgumentException("unexpected expr kind"); } } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelConstant constant) { + private void visit(CelMutableExpr expr, CelConstant constant) { switch (constant.getKind()) { case INT64_VALUE: env.setType(expr, SimpleType.INT); @@ -227,33 +253,29 @@ private CelExpr visit(CelExpr expr, CelConstant constant) { default: throw new IllegalArgumentException("unexpected constant case: " + constant.getKind()); } - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelIdent ident) { + private void visit(CelMutableExpr expr, CelMutableIdent ident) { CelIdentDecl decl = env.lookupIdent(expr.id(), getPosition(expr), container, ident.name()); checkNotNull(decl); if (decl.equals(Env.ERROR_IDENT_DECL)) { // error reported env.setType(expr, SimpleType.ERROR); env.setRef(expr, makeReference(decl.name(), decl)); - return expr; + return; } String refName = maybeDisambiguate(ident.name(), decl.name()); if (!refName.equals(ident.name())) { // Overwrite the identifier with its fully qualified name. - expr = replaceIdentSubtree(expr, refName); + expr.setIdent(CelMutableIdent.create(refName)); } env.setType(expr, decl.type()); env.setRef(expr, makeReference(refName, decl)); - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelSelect select) { + private void visit(CelMutableExpr expr, CelMutableSelect select) { // Before traversing down the tree, try to interpret as qualified name. String qname = asQualifiedName(expr); if (qname != null) { @@ -268,44 +290,34 @@ private CelExpr visit(CelExpr expr, CelExpr.CelSelect select) { if (namespacedDeclarations) { // Rewrite the node to be a variable reference to the resolved fully-qualified // variable name. - expr = replaceIdentSubtree(expr, refName); + expr.setIdent(CelMutableIdent.create(refName)); } env.setType(expr, decl.type()); env.setRef(expr, makeReference(refName, decl)); } - return expr; + return; } } // Interpret as field selection, first traversing down the operand. - CelExpr visitedOperand = visit(select.operand()); - if (namespacedDeclarations && !select.operand().equals(visitedOperand)) { - // Subtree has been rewritten. Replace the operand. - expr = replaceSelectOperandSubtree(expr, visitedOperand); - } - CelType resultType = visitSelectField(expr, visitedOperand, select.field(), false); + visit(select.operand()); + + CelType resultType = visitSelectField(expr, select.operand(), select.field(), false); if (select.testOnly()) { resultType = SimpleType.BOOL; } env.setType(expr, resultType); - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelCall call) { + private void visit(CelMutableExpr expr, CelMutableCall call) { String functionName = call.function(); if (Operator.OPTIONAL_SELECT.getFunction().equals(functionName)) { - return visitOptionalCall(expr, call); + visitOptionalCall(expr, call); + return; } // Traverse arguments. - ImmutableList argsList = call.args(); - for (int i = 0; i < argsList.size(); i++) { - CelExpr arg = argsList.get(i); - CelExpr visitedArg = visit(arg); - if (namespacedDeclarations && !visitedArg.equals(arg)) { - // Argument has been overwritten. - expr = replaceCallArgumentSubtree(expr, visitedArg, i); - } + for (CelMutableExpr arg : call.args()) { + visit(arg); } int position = getPosition(expr); @@ -319,7 +331,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelCall call) { if (!decl.name().equals(call.function())) { if (namespacedDeclarations) { // Overwrite the function name with its fully qualified resolved name. - expr = replaceCallSubtree(expr, decl.name()); + expr.setCall(CelMutableCall.create(decl.name(), call.args())); } } } else { @@ -334,17 +346,12 @@ private CelExpr visit(CelExpr expr, CelExpr.CelCall call) { // The function name is namespaced and so preserving the target operand would // be an inaccurate representation of the desired evaluation behavior. // Overwrite with fully-qualified resolved function name sans receiver target. - expr = replaceCallSubtree(expr, decl.name()); + expr.setCall(CelMutableCall.create(decl.name(), call.args())); } } else { // Regular instance call. - CelExpr target = call.target().get(); - CelExpr visitedTargetExpr = visit(target); - if (namespacedDeclarations && !visitedTargetExpr.equals(target)) { - // Visiting target contained a namespaced function. Rewrite the call expression here by - // setting the target to the new subtree. - expr = replaceCallSubtree(expr, visitedTargetExpr); - } + CelMutableExpr target = call.target().get(); + visit(target); resolution = resolveOverload( expr.id(), @@ -357,21 +364,15 @@ private CelExpr visit(CelExpr expr, CelExpr.CelCall call) { env.setType(expr, resolution.type()); env.setRef(expr, resolution.reference()); - - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelStruct struct) { + private void visit(CelMutableExpr expr, CelMutableStruct struct) { // Determine the type of the message. CelType messageType = SimpleType.ERROR; CelIdentDecl decl = env.lookupIdent(expr.id(), getPosition(expr), container, struct.messageName()); if (!struct.messageName().equals(decl.name())) { - expr = - expr.toBuilder() - .setStruct(struct.toBuilder().setMessageName(decl.name()).build()) - .build(); + struct.setMessageName(decl.name()); } env.setRef(expr, CelReference.newBuilder().setName(decl.name()).build()); @@ -401,24 +402,21 @@ private CelExpr visit(CelExpr expr, CelExpr.CelStruct struct) { } // Check the field initializers. - ImmutableList entriesList = struct.entries(); - for (int i = 0; i < entriesList.size(); i++) { - CelExpr.CelStruct.Entry entry = entriesList.get(i); - CelExpr visitedValueExpr = visit(entry.value()); - if (namespacedDeclarations && !visitedValueExpr.equals(entry.value())) { - // Subtree has been rewritten. Replace the struct value. - expr = replaceStructEntryValueSubtree(expr, visitedValueExpr, i); - } + List entriesList = struct.entries(); + for (CelMutableStruct.Entry entry : entriesList) { + CelMutableExpr value = entry.value(); + visit(value); + CelType fieldType = getFieldType(entry.id(), getPosition(entry), messageType, entry.fieldKey()).celType(); - CelType valueType = env.getType(visitedValueExpr); + CelType valueType = env.getType(value); if (entry.optionalEntry()) { if (valueType instanceof OptionalType) { valueType = unwrapOptional(valueType); } else { assertIsAssignable( - visitedValueExpr.id(), - getPosition(visitedValueExpr), + value.id(), + getPosition(value), valueType, OptionalType.create(valueType)); } @@ -433,48 +431,41 @@ private CelExpr visit(CelExpr expr, CelExpr.CelStruct struct) { CelTypes.format(valueType)); } } - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelMap map) { + private void visit(CelMutableExpr expr, CelMutableMap map) { CelType mapKeyType = null; CelType mapValueType = null; - ImmutableList entriesList = map.entries(); - for (int i = 0; i < entriesList.size(); i++) { - CelExpr.CelMap.Entry entry = entriesList.get(i); - CelExpr visitedMapKeyExpr = visit(entry.key()); - if (namespacedDeclarations && !visitedMapKeyExpr.equals(entry.key())) { - // Subtree has been rewritten. Replace the map key. - expr = replaceMapEntryKeySubtree(expr, visitedMapKeyExpr, i); - } + List entriesList = map.entries(); + for (CelMutableMap.Entry entry : entriesList) { + CelMutableExpr key = entry.key(); + visit(key); + mapKeyType = joinTypes( - visitedMapKeyExpr.id(), - getPosition(visitedMapKeyExpr), + key.id(), + getPosition(key), mapKeyType, - env.getType(visitedMapKeyExpr)); + env.getType(key)); - CelExpr visitedValueExpr = visit(entry.value()); - if (namespacedDeclarations && !visitedValueExpr.equals(entry.value())) { - // Subtree has been rewritten. Replace the map value. - expr = replaceMapEntryValueSubtree(expr, visitedValueExpr, i); - } - CelType valueType = env.getType(visitedValueExpr); + CelMutableExpr value = entry.value(); + visit(value); + + CelType valueType = env.getType(value); if (entry.optionalEntry()) { if (valueType instanceof OptionalType) { valueType = unwrapOptional(valueType); } else { assertIsAssignable( - visitedValueExpr.id(), - getPosition(visitedValueExpr), + value.id(), + getPosition(value), valueType, OptionalType.create(valueType)); } } mapValueType = - joinTypes(visitedValueExpr.id(), getPosition(visitedValueExpr), mapValueType, valueType); + joinTypes(value.id(), getPosition(value), mapValueType, valueType); } if (mapKeyType == null) { // If the map is empty, assign free type variables to key and value type. @@ -482,46 +473,40 @@ private CelExpr visit(CelExpr expr, CelExpr.CelMap map) { mapValueType = inferenceContext.newTypeVar("value"); } env.setType(expr, MapType.create(mapKeyType, mapValueType)); - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelList list) { + private void visit(CelMutableExpr expr, CelMutableList list) { CelType elemsType = null; - ImmutableList elementsList = list.elements(); + List elementsList = list.elements(); HashSet optionalIndices = new HashSet<>(list.optionalIndices()); for (int i = 0; i < elementsList.size(); i++) { - CelExpr visitedElem = visit(elementsList.get(i)); - if (namespacedDeclarations && !visitedElem.equals(elementsList.get(i))) { - // Subtree has been rewritten. Replace the list element - expr = replaceListElementSubtree(expr, visitedElem, i); - } - CelType elemType = env.getType(visitedElem); + CelMutableExpr elem = elementsList.get(i); + visit(elem); + + CelType elemType = env.getType(elem); if (optionalIndices.contains(i)) { if (elemType instanceof OptionalType) { elemType = unwrapOptional(elemType); } else { assertIsAssignable( - visitedElem.id(), getPosition(visitedElem), elemType, OptionalType.create(elemType)); + elem.id(), getPosition(elem), elemType, OptionalType.create(elemType)); } } - elemsType = joinTypes(visitedElem.id(), getPosition(visitedElem), elemsType, elemType); + elemsType = joinTypes(elem.id(), getPosition(elem), elemsType, elemType); } if (elemsType == null) { // If the list is empty, assign free type var to elem type. elemsType = inferenceContext.newTypeVar("elem"); } env.setType(expr, ListType.create(elemsType)); - return expr; } - @CheckReturnValue - private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) { - CelExpr visitedRange = visit(compre.iterRange()); - CelExpr visitedInit = visit(compre.accuInit()); - CelType accuType = env.getType(visitedInit); - CelType rangeType = inferenceContext.specialize(env.getType(visitedRange)); + private void visit(CelMutableExpr expr, CelMutableComprehension compre) { + visit(compre.iterRange()); + visit(compre.accuInit()); + CelType accuType = env.getType(compre.accuInit()); + CelType rangeType = inferenceContext.specialize(env.getType(compre.iterRange())); CelType varType; CelType varType2 = null; switch (rangeType.kind()) { @@ -556,7 +541,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) { default: env.reportError( expr.id(), - getPosition(visitedRange), + getPosition(compre.iterRange()), "expression of type '%s' cannot be range of a comprehension " + "(must be list, map, or dynamic)", CelTypes.format(rangeType)); @@ -574,30 +559,16 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) { if (!Strings.isNullOrEmpty(compre.iterVar2())) { env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar2(), varType2)); } - CelExpr condition = visit(compre.loopCondition()); - assertType(condition, SimpleType.BOOL); - CelExpr visitedStep = visit(compre.loopStep()); - assertType(visitedStep, accuType); + visit(compre.loopCondition()); + assertType(compre.loopCondition(), SimpleType.BOOL); + visit(compre.loopStep()); + assertType(compre.loopStep(), accuType); // Forget iteration variable, as result expression must only depend on accu. env.exitScope(); - CelExpr visitedResult = visit(compre.result()); + visit(compre.result()); env.exitScope(); - if (namespacedDeclarations) { - if (!visitedRange.equals(compre.iterRange())) { - expr = replaceComprehensionRangeSubtree(expr, visitedRange); - } - if (!visitedInit.equals(compre.accuInit())) { - expr = replaceComprehensionAccuInitSubtree(expr, visitedInit); - } - if (!visitedStep.equals(compre.loopStep())) { - expr = replaceComprehensionStepSubtree(expr, visitedStep); - } - if (!visitedResult.equals(compre.result())) { - expr = replaceComprehensionResultSubtree(expr, visitedResult); - } - } - env.setType(expr, inferenceContext.specialize(env.getType(visitedResult))); - return expr; + + env.setType(expr, inferenceContext.specialize(env.getType(compre.result()))); } private CelReference makeReference(String name, CelIdentDecl decl) { @@ -630,8 +601,8 @@ private OverloadResolution resolveOverload( long callExprId, int position, @Nullable CelFunctionDecl function, - @Nullable CelExpr target, - List args) { + @Nullable CelMutableExpr target, + List args) { if (function == null || function.equals(Env.ERROR_FUNCTION_DECL)) { // Error reported, just return error value. return OverloadResolution.of(CelReference.newBuilder().build(), SimpleType.ERROR); @@ -640,7 +611,7 @@ private OverloadResolution resolveOverload( if (target != null) { argTypes.add(env.getType(target)); } - for (CelExpr arg : args) { + for (CelMutableExpr arg : args) { argTypes.add(env.getType(arg)); } CelType resultType = null; // For most common result type. @@ -716,7 +687,7 @@ private OverloadResolution resolveOverload( // Return value from visit is not needed as the subtree is not rewritten here. @SuppressWarnings("CheckReturnValue") private CelType visitSelectField( - CelExpr expr, CelExpr operand, String field, boolean isOptional) { + CelMutableExpr expr, CelMutableExpr operand, String field, boolean isOptional) { CelType operandType = inferenceContext.specialize(env.getType(operand)); CelType resultType = SimpleType.ERROR; @@ -763,25 +734,20 @@ private CelType visitSelectField( return resultType; } - private CelExpr visitOptionalCall(CelExpr expr, CelExpr.CelCall call) { - CelExpr operand = call.args().get(0); - CelExpr field = call.args().get(1); - if (!field.exprKind().getKind().equals(CelExpr.ExprKind.Kind.CONSTANT) + private void visitOptionalCall(CelMutableExpr expr, CelMutableCall call) { + CelMutableExpr operand = call.args().get(0); + CelMutableExpr field = call.args().get(1); + if (field.getKind() != CelExpr.ExprKind.Kind.CONSTANT || field.constant().getKind() != CelConstant.Kind.STRING_VALUE) { env.reportError(expr.id(), getPosition(field), "unsupported optional field selection"); - return expr; + return; } - CelExpr visitedOperand = visit(operand); - if (namespacedDeclarations && !operand.equals(visitedOperand)) { - // Subtree has been rewritten. Replace the operand. - expr = replaceCallArgumentSubtree(expr, visitedOperand, 0); - } + visit(operand); + CelType resultType = visitSelectField(expr, operand, field.constant().stringValue(), true); env.setType(expr, resultType); env.setRef(expr, CelReference.newBuilder().addOverloadIds("select_optional_field").build()); - - return expr; } /** @@ -789,8 +755,8 @@ private CelExpr visitOptionalCall(CelExpr expr, CelExpr.CelCall call) { * expression and returns the name they constitute, or null if the expression cannot be * interpreted like this. */ - private @Nullable String asQualifiedName(CelExpr expr) { - switch (expr.exprKind().getKind()) { + private @Nullable String asQualifiedName(CelMutableExpr expr) { + switch (expr.getKind()) { case IDENT: return expr.ident().name(); case SELECT: @@ -862,16 +828,16 @@ private CelType unwrapOptional(CelType type) { return type.parameters().get(0); } - private void assertType(CelExpr expr, CelType type) { + private void assertType(CelMutableExpr expr, CelType type) { assertIsAssignable(expr.id(), getPosition(expr), env.getType(expr), type); } - private int getPosition(CelExpr expr) { + private int getPosition(CelMutableExpr expr) { Integer pos = positionMap.get(expr.id()); return pos == null ? 0 : pos; } - private int getPosition(CelExpr.CelStruct.Entry entry) { + private int getPosition(CelMutableStruct.Entry entry) { Integer pos = positionMap.get(entry.id()); return pos == null ? 0 : pos; } @@ -894,81 +860,4 @@ public static OverloadResolution of(CelReference reference, CelType type) { /** Helper object to represent a {@link TypeProvider.FieldType} lookup failure. */ private static final TypeProvider.FieldType ERROR = TypeProvider.FieldType.of(Types.ERROR); - - private static CelExpr replaceIdentSubtree(CelExpr expr, String name) { - CelExpr.CelIdent newIdent = CelExpr.CelIdent.newBuilder().setName(name).build(); - return expr.toBuilder().setIdent(newIdent).build(); - } - - private static CelExpr replaceSelectOperandSubtree(CelExpr expr, CelExpr operand) { - CelExpr.CelSelect newSelect = expr.select().toBuilder().setOperand(operand).build(); - return expr.toBuilder().setSelect(newSelect).build(); - } - - private static CelExpr replaceCallArgumentSubtree(CelExpr expr, CelExpr newArg, int index) { - CelExpr.CelCall newCall = expr.call().toBuilder().setArg(index, newArg).build(); - return expr.toBuilder().setCall(newCall).build(); - } - - private static CelExpr replaceCallSubtree(CelExpr expr, String functionName) { - CelExpr.CelCall newCall = - expr.call().toBuilder().setFunction(functionName).clearTarget().build(); - return expr.toBuilder().setCall(newCall).build(); - } - - private static CelExpr replaceCallSubtree(CelExpr expr, CelExpr target) { - CelExpr.CelCall newCall = expr.call().toBuilder().setTarget(target).build(); - return expr.toBuilder().setCall(newCall).build(); - } - - private static CelExpr replaceListElementSubtree(CelExpr expr, CelExpr element, int index) { - CelExpr.CelList newList = expr.list().toBuilder().setElement(index, element).build(); - return expr.toBuilder().setList(newList).build(); - } - - private static CelExpr replaceStructEntryValueSubtree(CelExpr expr, CelExpr newValue, int index) { - CelExpr.CelStruct struct = expr.struct(); - CelExpr.CelStruct.Entry newEntry = - struct.entries().get(index).toBuilder().setValue(newValue).build(); - struct = struct.toBuilder().setEntry(index, newEntry).build(); - return expr.toBuilder().setStruct(struct).build(); - } - - private static CelExpr replaceMapEntryKeySubtree(CelExpr expr, CelExpr newKey, int index) { - CelExpr.CelMap map = expr.map(); - CelExpr.CelMap.Entry newEntry = map.entries().get(index).toBuilder().setKey(newKey).build(); - map = map.toBuilder().setEntry(index, newEntry).build(); - return expr.toBuilder().setMap(map).build(); - } - - private static CelExpr replaceMapEntryValueSubtree(CelExpr expr, CelExpr newValue, int index) { - CelExpr.CelMap map = expr.map(); - CelExpr.CelMap.Entry newEntry = map.entries().get(index).toBuilder().setValue(newValue).build(); - map = map.toBuilder().setEntry(index, newEntry).build(); - return expr.toBuilder().setMap(map).build(); - } - - private static CelExpr replaceComprehensionAccuInitSubtree(CelExpr expr, CelExpr newAccuInit) { - CelExpr.CelComprehension newComprehension = - expr.comprehension().toBuilder().setAccuInit(newAccuInit).build(); - return expr.toBuilder().setComprehension(newComprehension).build(); - } - - private static CelExpr replaceComprehensionRangeSubtree(CelExpr expr, CelExpr newRange) { - CelExpr.CelComprehension newComprehension = - expr.comprehension().toBuilder().setIterRange(newRange).build(); - return expr.toBuilder().setComprehension(newComprehension).build(); - } - - private static CelExpr replaceComprehensionStepSubtree(CelExpr expr, CelExpr newStep) { - CelExpr.CelComprehension newComprehension = - expr.comprehension().toBuilder().setLoopStep(newStep).build(); - return expr.toBuilder().setComprehension(newComprehension).build(); - } - - private static CelExpr replaceComprehensionResultSubtree(CelExpr expr, CelExpr newResult) { - CelExpr.CelComprehension newComprehension = - expr.comprehension().toBuilder().setResult(newResult).build(); - return expr.toBuilder().setComprehension(newComprehension).build(); - } } diff --git a/common/src/main/java/dev/cel/common/BUILD.bazel b/common/src/main/java/dev/cel/common/BUILD.bazel index ba223d213..62ba8db78 100644 --- a/common/src/main/java/dev/cel/common/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/BUILD.bazel @@ -194,6 +194,7 @@ java_library( deps = [ ":cel_source", "//common/ast:mutable_expr", + "//common/internal", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], diff --git a/common/src/main/java/dev/cel/common/CelMutableSource.java b/common/src/main/java/dev/cel/common/CelMutableSource.java index 459042c6d..0f35b2b32 100644 --- a/common/src/main/java/dev/cel/common/CelMutableSource.java +++ b/common/src/main/java/dev/cel/common/CelMutableSource.java @@ -33,11 +33,19 @@ * Represents the mutable portion of the {@link CelSource}. This is intended for the purposes of * augmenting an AST through CEL optimizers. */ +import com.google.common.collect.ImmutableList; +import dev.cel.common.internal.CelCodePointArray; + +// ... + public final class CelMutableSource { private String description; private final Map macroCalls; private final Set extensions; + private final CelCodePointArray codePoints; + private final ImmutableList lineOffsets; + private final Map positions; @CanIgnoreReturnValue public CelMutableSource addMacroCalls(long exprId, CelMutableExpr expr) { @@ -89,9 +97,10 @@ public Set getExtensions() { } public CelSource toCelSource() { - return CelSource.newBuilder() + return CelSource.newBuilder(codePoints, lineOffsets) .setDescription(description) .addAllExtensions(extensions) + .addPositionsMap(positions) .addAllMacroCalls( macroCalls.entrySet().stream() .collect( @@ -101,7 +110,13 @@ public CelSource toCelSource() { } public static CelMutableSource newInstance() { - return new CelMutableSource("", new HashMap<>(), new HashSet<>()); + return new CelMutableSource( + "", + new HashMap<>(), + new HashSet<>(), + CelCodePointArray.fromString(""), + ImmutableList.of(), + new HashMap<>()); } public static CelMutableSource fromCelSource(CelSource source) { @@ -117,13 +132,24 @@ public static CelMutableSource fromCelSource(CelSource source) { "Unexpected source collision at ID: " + prev.id()); }, HashMap::new)), - source.getExtensions()); + source.getExtensions(), + source.getContent(), + source.getLineOffsets(), + source.getPositionsMap()); } CelMutableSource( - String description, Map macroCalls, Set extensions) { + String description, + Map macroCalls, + Set extensions, + CelCodePointArray codePoints, + ImmutableList lineOffsets, + Map positions) { this.description = checkNotNull(description); this.macroCalls = checkNotNull(macroCalls); this.extensions = checkNotNull(extensions); + this.codePoints = checkNotNull(codePoints); + this.lineOffsets = checkNotNull(lineOffsets); + this.positions = checkNotNull(positions); } } diff --git a/common/src/main/java/dev/cel/common/CelSource.java b/common/src/main/java/dev/cel/common/CelSource.java index 2678a0a2c..dcbf54771 100644 --- a/common/src/main/java/dev/cel/common/CelSource.java +++ b/common/src/main/java/dev/cel/common/CelSource.java @@ -148,8 +148,12 @@ public static Builder newBuilder(String text) { return newBuilder(CelCodePointArray.fromString(text)); } - public static Builder newBuilder(CelCodePointArray codePointArray) { - return new Builder(codePointArray, codePointArray.lineOffsets()); + public static Builder newBuilder(CelCodePointArray codePoints) { + return new Builder(codePoints, codePoints.lineOffsets()); + } + + public static Builder newBuilder(CelCodePointArray codePoints, List lineOffsets) { + return new Builder(codePoints, lineOffsets); } /** Builder for {@link CelSource}. */