Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntimeFactory;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -243,4 +244,76 @@ public void lazyBinding_withNestedBinds() throws Exception {
assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(2);
}

@Test
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
public void lazyBinding_boundAttributeInComprehension() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.setStandardMacros(CelStandardMacro.MAP)
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();

CelAbstractSyntaxTree ast =
celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst();

List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();

assertThat(result).containsExactly(true, true, true);
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings({"Immutable"}) // Test only
public void lazyBinding_boundAttributeInNestedComprehension() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.setStandardMacros(CelStandardMacro.EXISTS)
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();

CelAbstractSyntaxTree ast =
celCompiler
.compile(
"cel.bind(x, get_true(), [1,2,3].exists(unused, x && "
+ "['a','b','c'].exists(unused_2, x)))")
.getAst();

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.stream.Collectors.toCollection;

import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand All @@ -41,8 +43,11 @@
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelComprehension;
import dev.cel.common.ast.CelExpr.CelList;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.ast.CelMutableExpr;
import dev.cel.common.ast.CelMutableExpr.CelMutableComprehension;
import dev.cel.common.ast.CelMutableExprConverter;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableMutableAst;
Expand All @@ -60,6 +65,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

/**
* Performs Common Subexpression Elimination.
Expand Down Expand Up @@ -90,14 +96,15 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
private static final SubexpressionOptimizer INSTANCE =
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
private static final String BLOCK_INDEX_PREFIX = "@index";
private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);

@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
@VisibleForTesting static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
@VisibleForTesting static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";

private final SubexpressionOptimizerOptions cseOptions;
private final AstMutator astMutator;
private final ImmutableSet<String> cseEliminableFunctions;
Expand Down Expand Up @@ -269,6 +276,8 @@ static void verifyOptimizedAstCorrectness(CelAbstractSyntaxTree ast) {
Verify.verify(
resultHasAtLeastOneBlockIndex,
"Expected at least one reference of index in cel.block result");

verifyNoInvalidScopedMangledVariables(celBlockExpr);
}

private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
Expand All @@ -289,6 +298,67 @@ private static void verifyBlockIndex(CelExpr celExpr, int maxIndexValue) {
celExpr);
}

private static void verifyNoInvalidScopedMangledVariables(CelExpr celExpr) {
CelCall celBlockCall = celExpr.call();
CelExpr blockBody = celBlockCall.args().get(1);

ImmutableSet<String> allMangledVariablesInBlockBody =
CelNavigableExpr.fromExpr(blockBody)
.allNodes()
.map(CelNavigableExpr::expr)
.flatMap(SubexpressionOptimizer::extractMangledNames)
.collect(toImmutableSet());

CelList blockIndices = celBlockCall.args().get(0).list();
for (CelExpr blockIndex : blockIndices.elements()) {
ImmutableSet<String> indexDeclaredCompVariables =
CelNavigableExpr.fromExpr(blockIndex)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.getKind() == Kind.COMPREHENSION)
.map(CelExpr::comprehension)
.flatMap(comp -> Stream.of(comp.iterVar(), comp.iterVar2()))
.filter(iter -> !Strings.isNullOrEmpty(iter))
.collect(toImmutableSet());

boolean containsIllegalDeclaration =
CelNavigableExpr.fromExpr(blockIndex)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.getKind() == Kind.IDENT)
.map(expr -> expr.ident().name())
.filter(SubexpressionOptimizer::isMangled)
.anyMatch(
ident ->
!indexDeclaredCompVariables.contains(ident)
&& allMangledVariablesInBlockBody.contains(ident));

Verify.verify(
!containsIllegalDeclaration,
"Illegal declared reference to a comprehension variable found in block indices. Expr: %s",
celExpr);
}
}

private static Stream<String> extractMangledNames(CelExpr expr) {
if (expr.getKind().equals(Kind.IDENT)) {
String name = expr.ident().name();
return isMangled(name) ? Stream.of(name) : Stream.empty();
}
if (expr.getKind().equals(Kind.COMPREHENSION)) {
CelComprehension comp = expr.comprehension();
return Stream.of(comp.iterVar(), comp.iterVar2(), comp.accuVar())
.filter(x -> !Strings.isNullOrEmpty(x))
.filter(SubexpressionOptimizer::isMangled);
}
return Stream.empty();
}

private static boolean isMangled(String name) {
return name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX);
}

private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
// Tag the extension
CelSource.Builder celSourceBuilder =
Expand Down Expand Up @@ -355,8 +425,8 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
navAst
.getRoot()
.descendants(TraversalOrder.PRE_ORDER)
.filter(node -> canEliminate(node, ineligibleExprs))
.filter(node -> node.height() <= recursionLimit)
.filter(node -> canEliminate(node, ineligibleExprs))
.sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed())
.collect(toImmutableList());
if (descendants.isEmpty()) {
Expand Down Expand Up @@ -441,7 +511,45 @@ private boolean canEliminate(
&& navigableExpr.expr().list().elements().isEmpty())
&& containsEliminableFunctionOnly(navigableExpr)
&& !ineligibleExprs.contains(navigableExpr.expr())
&& containsComprehensionIdentInSubexpr(navigableExpr);
&& containsComprehensionIdentInSubexpr(navigableExpr)
&& containsProperScopedComprehensionIdents(navigableExpr);
}

private boolean containsProperScopedComprehensionIdents(CelNavigableMutableExpr navExpr) {
if (!navExpr.getKind().equals(Kind.COMPREHENSION)) {
return true;
}

// For nested comprehensions of form [1].exists(x, [2].exists(y, x == y)), the inner
// comprehension [2].exists(y, x == y)
// should not be extracted out into a block index, as it causes issues with scoping.
ImmutableSet<String> mangledIterVars =
navExpr
.descendants()
.filter(x -> x.getKind().equals(Kind.IDENT))
.map(x -> x.expr().ident().name())
.filter(
name ->
name.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
|| name.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX))
.collect(toImmutableSet());

CelNavigableMutableExpr parent = navExpr.parent().orElse(null);
while (parent != null) {
if (parent.getKind().equals(Kind.COMPREHENSION)) {
CelMutableComprehension comp = parent.expr().comprehension();
boolean containsParentIterReferences =
mangledIterVars.contains(comp.iterVar()) || mangledIterVars.contains(comp.iterVar2());

if (containsParentIterReferences) {
return false;
}
}

parent = parent.parent().orElse(null);
}

return true;
}

private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navExpr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import dev.cel.runtime.CelFunctionBinding;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntimeFactory;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -381,6 +382,31 @@ public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception {
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings({"Immutable", "unchecked"}) // Test only
public void lazyEval_withinComprehension_blockIndexEvaluatedOnlyOnce() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions("cel.block([get_true()], [1,2,3].map(x, x < 0 || index0))");

List<Boolean> result = (List<Boolean>) celRuntime.createProgram(ast).eval();

assertThat(result).containsExactly(true, true, true);
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception {
Expand Down Expand Up @@ -452,9 +478,9 @@ public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws
// Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions(
"cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
+ " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
+ " true, true]]]");
"cel.block([true, false, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [c0,"
+ " c1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true, true,"
+ " true]]]");

boolean result = (boolean) celRuntime.createProgram(ast).eval();

Expand Down
Loading