Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class ResolveVisitor : public AstVisitorBase {
bool namespace_rewrite;
};

struct AttributeResolution {
const VariableDecl* decl;
bool requires_disambiguation;
};

ResolveVisitor(absl::string_view container,
NamespaceGenerator namespace_generator,
const TypeCheckEnv& env, const Ast& ast,
Expand Down Expand Up @@ -295,7 +300,7 @@ class ResolveVisitor : public AstVisitorBase {
return functions_;
}

const absl::flat_hash_map<const Expr*, const VariableDecl*>& attributes()
const absl::flat_hash_map<const Expr*, AttributeResolution>& attributes()
const {
return attributes_;
}
Expand Down Expand Up @@ -481,7 +486,7 @@ class ResolveVisitor : public AstVisitorBase {

// References that were resolved and may require AST rewrites.
absl::flat_hash_map<const Expr*, FunctionResolution> functions_;
absl::flat_hash_map<const Expr*, const VariableDecl*> attributes_;
absl::flat_hash_map<const Expr*, AttributeResolution> attributes_;
absl::flat_hash_map<const Expr*, std::string> struct_types_;

absl::flat_hash_map<const Expr*, Type> types_;
Expand Down Expand Up @@ -974,8 +979,9 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,

const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier(
absl::string_view name) {
// Container resolution doesn't apply for local vars so .foo is redundant but
// legal.
// Note: if we see a leading dot, this shouldn't resolve to a local variable,
// but we need to check whether we need to disambiguate against a global in
// the reference map.
if (absl::StartsWith(name, ".")) {
name = name.substr(1);
}
Expand Down Expand Up @@ -1012,14 +1018,16 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
absl::string_view name) {
// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* decl = LookupLocalIdentifier(name);
const VariableDecl* local_decl = LookupLocalIdentifier(name);

if (decl != nullptr) {
attributes_[&expr] = decl;
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
if (local_decl != nullptr && !absl::StartsWith(name, ".")) {
attributes_[&expr] = {local_decl, false};
types_[&expr] =
inference_context_->InstantiateTypeParams(local_decl->type());
return;
}

const VariableDecl* decl = nullptr;
namespace_generator_.GenerateCandidates(
name, [&decl, this](absl::string_view candidate) {
decl = LookupGlobalIdentifier(candidate);
Expand All @@ -1028,7 +1036,8 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
});

if (decl != nullptr) {
attributes_[&expr] = decl;
attributes_[&expr] = {decl,
/* requires_disambiguation= */ local_decl != nullptr};
types_[&expr] = inference_context_->InstantiateTypeParams(decl->type());
return;
}
Expand All @@ -1046,10 +1055,13 @@ void ResolveVisitor::ResolveQualifiedIdentifier(

// Local variables (comprehension, bind) are simple identifiers so we can
// skip generating the different namespace-qualified candidates.
const VariableDecl* decl = LookupLocalIdentifier(qualifiers[0]);
const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]);
const VariableDecl* decl = nullptr;

int matched_segment_index = -1;

if (decl != nullptr) {
if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) {
decl = local_decl;
matched_segment_index = 0;
} else {
namespace_generator_.GenerateCandidates(
Expand Down Expand Up @@ -1080,7 +1092,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
root = &root->select_expr().operand();
}

attributes_[root] = decl;
attributes_[root] = {decl,
/* requires_disambiguation= */ decl != local_decl &&
local_decl != nullptr};
types_[root] = inference_context_->InstantiateTypeParams(decl->type());

// fix-up select operations that were deferred.
Expand Down Expand Up @@ -1227,13 +1241,18 @@ class ResolveRewriter : public AstRewriterBase {
bool rewritten = false;
if (auto iter = visitor_.attributes().find(&expr);
iter != visitor_.attributes().end()) {
const VariableDecl* decl = iter->second;
const VariableDecl* decl = iter->second.decl;
auto& ast_ref = reference_map_[expr.id()];
ast_ref.set_name(decl->name());
std::string name = decl->name();
if (iter->second.requires_disambiguation &&
!absl::StartsWith(name, ".")) {
name = absl::StrCat(".", name);
}
ast_ref.set_name(name);
if (decl->has_value()) {
ast_ref.set_value(decl->value());
}
expr.mutable_ident_expr().set_name(decl->name());
expr.mutable_ident_expr().set_name(std::move(name));
rewritten = true;
} else if (auto iter = visitor_.functions().find(&expr);
iter != visitor_.functions().end()) {
Expand Down
82 changes: 82 additions & 0 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
"equals",
/*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A"))));

FunctionDecl ne_op;
ne_op.set_name("_!=_");
CEL_RETURN_IF_ERROR(ne_op.AddOverload(MakeOverloadDecl(
"not_equals",
/*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A"))));

FunctionDecl ternary_op;
ternary_op.set_name("_?_:_");
CEL_RETURN_IF_ERROR(ternary_op.AddOverload(MakeOverloadDecl(
Expand Down Expand Up @@ -276,6 +282,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena
env.InsertFunctionIfAbsent(std::move(gt_op));
env.InsertFunctionIfAbsent(std::move(to_int));
env.InsertFunctionIfAbsent(std::move(eq_op));
env.InsertFunctionIfAbsent(std::move(ne_op));
env.InsertFunctionIfAbsent(std::move(ternary_op));
env.InsertFunctionIfAbsent(std::move(index_op));
env.InsertFunctionIfAbsent(std::move(to_dyn));
Expand Down Expand Up @@ -793,6 +800,81 @@ TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdent) {
Not(Contains(Pair(_, IsVariableReference("x.y")))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsShadowsQualifiedIdentTypeError) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[0].all(x, x.y == 0)"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_FALSE(result.IsValid());

EXPECT_THAT(
result.FormatError(),
HasSubstr("type 'int' cannot be the operand of a select operation"));
}

TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdent) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast,
MakeTestParsedAst("[{'y': 0}].all(x, .x.y == 2)"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());

EXPECT_THAT(result.GetIssues(), IsEmpty());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
EXPECT_THAT(checked_ast->reference_map(),
Contains(Pair(_, IsVariableReference(".x.y"))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesQualifiedIdentMixed) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

env.InsertVariableIfAbsent(MakeVariableDecl("x.y", StringType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast,
MakeTestParsedAst("[{'y': 0}].all(x, .x.y != x.y)"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_FALSE(result.IsValid());
EXPECT_THAT(
result.FormatError(),
HasSubstr("no matching overload for '_!=_' applied to '(string, int)'"));
}

TEST(TypeCheckerImplTest, ComprehensionVarsDisamgiguatesIdent) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk());

env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType()));

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("['foo'].all(x, .x == 2)"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_TRUE(result.IsValid());

EXPECT_THAT(result.GetIssues(), IsEmpty());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
EXPECT_THAT(checked_ast->reference_map(),
Contains(Pair(_, IsVariableReference(".x"))));
}

TEST(TypeCheckerImplTest, ComprehensionVarsCyclicParamAssignability) {
TypeCheckEnv env(GetSharedTestingDescriptorPool());
google::protobuf::Arena arena;
Expand Down