Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/frontends/lean/decl_attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ bool decl_attributes::ok_for_inductive_type() const {
for (entry const & e : m_entries) {
name const & n = e.m_attr->get_name();
if (is_system_attribute(n)) {
if ((n != "class" && n != "vm_override" && !is_class_symbol_tracking_attribute(n)) || e.deleted())
if ((n != "class" && n != "vm_override" && n != "elab_field_alternatives"
&& !is_class_symbol_tracking_attribute(n)) || e.deleted())
return false;
}
}
Expand Down
250 changes: 178 additions & 72 deletions src/frontends/lean/elaborator.cpp

Large diffs are not rendered by default.

125 changes: 111 additions & 14 deletions src/frontends/lean/elaborator.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,120 @@ class elaborator {
expr visit_equation(expr const & eq, unsigned num_fns);
expr visit_inaccessible(expr const & e, optional<expr> const & expected_type);

/** Field resolution: this is a field projection */
struct field_resolution_proj_fn {
/** Name of the structure that is the source of the field. Is ancestor of \c m_struct_name */
name m_base_struct_name;
/** Name of the structure for the projected expression */
name m_struct_name;
/** The field name for the projection */
name m_field_name;

field_resolution_proj_fn(name const & base_struct_name, name const & struct_name, name const & field_name):
m_base_struct_name(base_struct_name), m_struct_name(struct_name), m_field_name(field_name) {}

/** Get the name of the projection function */
name get_full_name() const { return m_base_struct_name + m_field_name; }
};

/** Field resolution: this is a "method" (a.k.a. extended dot notation) */
struct field_resolution_const {
/** If this is not equal to \c m_struct_name then we should insert parent projections, and this is
* the name of the structure that is the source of the field. Is ancestor of \c m_struct_name */
name m_base_struct_name;
/** Generalized structure name for the method expression */
name m_struct_name;
/** Name of the constant to use as a function */
name m_const_name;
/** Whether to find the first explicit argument in \c m_const_name that accepts the expression.
* If false, just use the first explicit argument. */
bool m_find_matching;

field_resolution_const(name const & base_struct_name, name const & struct_name, name const & const_name, bool find_matching = true):
m_base_struct_name(base_struct_name), m_struct_name(struct_name), m_const_name(const_name), m_find_matching(find_matching) {}

/** Get the structure name to search for when \c m_find_matching is true */
name const & get_base_name() const { return m_base_struct_name; }
};

/** Field resolution: projection is being used to make a "local" recursive call */
struct field_resolution_local_rec {
/** The generalized structure name for the argument to the call. */
name m_base_name;
/** The resolved name for the recursive call (for error reporting). */
name m_full_name;
/** The declaration for the "local" recursive call. */
local_decl m_ldecl;

field_resolution_local_rec(name const & base_name, name const & full_name, local_decl const & ldecl):
m_base_name(base_name), m_full_name(full_name), m_ldecl(ldecl) {}
};

struct field_resolution {
name m_S_name; // structure name of field expression type
name m_base_S_name; // structure name of field
name m_fname;
optional<local_decl> m_ldecl; // projection is a local constant: recursive call

field_resolution(name const & full_fname, optional<local_decl> ldecl = {}):
m_S_name(full_fname.get_prefix()), m_base_S_name(full_fname.get_prefix()),
m_fname(full_fname.get_string()), m_ldecl(ldecl) {}
field_resolution(const name & S_name, const name & base_S_name, const name & fname):
m_S_name(S_name), m_base_S_name(base_S_name), m_fname(fname) {}

name get_full_fname() const { return m_base_S_name + m_fname; }
enum class kind { ProjFn, Const, LocalRec };

kind m_kind;
union {
field_resolution_proj_fn m_proj_fn;
field_resolution_const m_const;
field_resolution_local_rec m_local_rec;
};

field_resolution(field_resolution_proj_fn const & fr_proj_fn):
m_kind(kind::ProjFn), m_proj_fn(fr_proj_fn) {}
field_resolution(field_resolution_const const & fr_const):
m_kind(kind::Const), m_const(fr_const) {}
field_resolution(field_resolution_local_rec const & fr_local_rec):
m_kind(kind::LocalRec), m_local_rec(fr_local_rec) {}

field_resolution(field_resolution const & fr):m_kind(fr.m_kind) {
switch (m_kind) {
case kind::ProjFn: new (&m_proj_fn) auto(fr.m_proj_fn); break;
case kind::Const: new (&m_const) auto(fr.m_const); break;
case kind::LocalRec: new (&m_local_rec) auto(fr.m_local_rec); break;
default: lean_unreachable();
}
}

~field_resolution() {
switch (m_kind) {
case kind::ProjFn: m_proj_fn.~field_resolution_proj_fn(); break;
case kind::Const: m_const.~field_resolution_const(); break;
case kind::LocalRec: m_local_rec.~field_resolution_local_rec(); break;
}
}

field_resolution_proj_fn const & get_proj_fn() { lean_assert(m_kind == kind::ProjFn); return m_proj_fn; }
field_resolution_const const & get_const() { lean_assert(m_kind == kind::Const); return m_const; }
field_resolution_local_rec const & get_local_rec() { lean_assert(m_kind == kind::LocalRec); return m_local_rec; }

/** The function name to use when reporting errors associated to this field resolution. */
name get_full_name() {
switch (m_kind) {
case kind::ProjFn: return get_proj_fn().get_full_name();
case kind::Const: return get_const().m_const_name;
case kind::LocalRec: return get_local_rec().m_full_name;
default: lean_unreachable();
}
}

/** The structure name to use when reporting errors associated to this field resolution. */
name get_base_name() {
switch (m_kind) {
case kind::ProjFn: return get_proj_fn().m_base_struct_name;
case kind::Const: return get_const().m_base_struct_name;
case kind::LocalRec: return get_local_rec().m_base_name;
default: lean_unreachable();
}
}
};

field_resolution field_to_decl(expr const & e, expr const & s, expr const & s_type);
field_resolution find_field_fn(expr const & e, expr const & s, expr const & s_type);
field_resolution resolve_field_notation_method(expr const & e, expr const & s, expr const & s_type, name const & struct_name, bool find_matching = true,
buffer<name> const & extra_base_names = buffer<name>());
field_resolution resolve_field_notation_aux(expr const & e, expr const & s, expr const & s_type);
/** \c e is the field notation expression, \c s is the elaborated expression, \c s_type is its type */
field_resolution resolve_field_notation(expr const & e, expr const & s, expr const & s_type);

expr visit_field(expr const & e, optional<expr> const & expected_type);
expr instantiate_mvars(expr const & e, std::function<bool(expr const &)> pred); // NOLINT
expr visit_structure_instance(expr const & e, optional<expr> expected_type);
Expand Down
12 changes: 12 additions & 0 deletions src/frontends/lean/structure_cmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,18 @@ optional<name> find_field(environment const & env, name const & S_name, name con
return {};
}

optional<pair<name, name>> find_method(environment const & env, name const & struct_name, name const & field_name) {
if (env.find(struct_name + field_name))
return some(mk_pair(struct_name, struct_name + field_name));
if (is_structure_like(env, struct_name)) {
for (auto const & p : get_parent_structures(env, struct_name)) {
if (auto m = find_method(env, p, field_name))
return m;
}
}
return {};
}

void get_structure_fields_flattened(environment const & env, name const & structure_name, buffer<name> & full_fnames) {
for (auto const & fname : get_structure_fields(env, structure_name)) {
full_fnames.push_back(structure_name + fname);
Expand Down
3 changes: 3 additions & 0 deletions src/frontends/lean/structure_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ optional<name> find_field(environment const & env, name const & S_name, name con
optional<expr> mk_base_projections(environment const & env, name const & S_name, name const & base_S_name, expr const & e);
/** \brief Return an unelaborated expression applying a field projection */
expr mk_proj_app(environment const & env, name const & S_name, name const & fname, expr const & e, expr const & ref = {});
/** \brief Searches for `struct_name.field_name` in the environment, and if `struct_name` is a structure, recursively does the same for parent structures.
* Returns (S', n) where S' is the name of the (generalized) structure and n is the name corresponding to \c field_name */
optional<pair<name, name>> find_method(environment const & env, name const & struct_name, name const & field_name);

/* Default value support */
optional<name> has_default_value(environment const & env, name const & S_name, name const & fname);
Expand Down
16 changes: 16 additions & 0 deletions src/library/attribute_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Author: Leonardo de Moura

namespace lean {
template class typed_attribute<indices_attribute_data>;
template class typed_attribute<names_attribute_data>;
template class typed_attribute<key_value_data>;

ast_id key_value_data::parse(abstract_parser & p) {
Expand Down Expand Up @@ -291,6 +292,21 @@ ast_id indices_attribute_data::parse(abstract_parser & p) {
return data.m_id;
}

ast_id names_attribute_data::parse(abstract_parser & p) {
buffer<name> names;
lean_assert(dynamic_cast<parser *>(&p));
auto& p2 = *static_cast<parser *>(&p);
auto& data = p2.new_ast("names", p2.pos());
while (p2.curr_is_identifier()) {
name n = p2.get_name_val();
data.push(p2.new_ast("ident", p2.pos(), n).m_id);
names.push_back(n);
p2.next();
}
m_names = to_list(names);
return data.m_id;
}

void register_incompatible(char const * attr1, char const * attr2) {
lean_assert(is_system_attribute(attr1));
lean_assert(is_system_attribute(attr2));
Expand Down
30 changes: 29 additions & 1 deletion src/library/attribute_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,31 @@ struct indices_attribute_data : public attr_data {
}
};

struct names_attribute_data : public attr_data {
list<name> m_names;
names_attribute_data(list<name> const & names) : m_names(names) {}
names_attribute_data() : names_attribute_data(list<name>()) {}

virtual unsigned hash() const override {
unsigned h = 0;
for (name n : m_names)
h = ::lean::hash(h, n.hash());
return h;
}
void write(serializer & s) const {
write_list(s, m_names);
}
void read(deserializer & d) {
m_names = read_list<name>(d);
}
ast_id parse(abstract_parser & p) override;
virtual void print(std::ostream & out) override {
for (auto n : m_names) {
out << " " << n;
}
}
};

struct key_value_data : public attr_data {
// generalize: name_map<std::string> m_pairs;
std::string m_symbol;
Expand Down Expand Up @@ -260,7 +285,10 @@ struct key_value_data : public attr_data {
/** \brief Attribute that represents a list of indices. input and output are 1-indexed for convenience. */
typedef typed_attribute<indices_attribute_data> indices_attribute;

/** \brief Attribute that represents a list of indices. input and output are 1-indexed for convenience. */
/** \brief Attribute that represents a list of names. */
typedef typed_attribute<names_attribute_data> names_attribute;

/** \brief Attribute that represents a single key/value pair. */
typedef typed_attribute<key_value_data> key_value_attribute;

class user_attribute_ext {
Expand Down
12 changes: 12 additions & 0 deletions src/library/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ name const * g_false_rec = nullptr;
name const * g_false_of_true_eq_false = nullptr;
name const * g_fin_mk = nullptr;
name const * g_fin_ne_of_vne = nullptr;
name const * g_forall = nullptr;
name const * g_forall_congr = nullptr;
name const * g_forall_congr_eq = nullptr;
name const * g_forall_not_of_not_exists = nullptr;
name const * g_format = nullptr;
name const * g_function = nullptr;
name const * g_funext = nullptr;
name const * g_has_add = nullptr;
name const * g_has_add_add = nullptr;
Expand Down Expand Up @@ -216,6 +218,7 @@ name const * g_of_eq_true = nullptr;
name const * g_opt_param = nullptr;
name const * g_or = nullptr;
name const * g_out_param = nullptr;
name const * g_pi = nullptr;
name const * g_pprod = nullptr;
name const * g_pprod_fst = nullptr;
name const * g_pprod_mk = nullptr;
Expand Down Expand Up @@ -344,10 +347,12 @@ void initialize_constants() {
g_false_of_true_eq_false = new name{"false_of_true_eq_false"};
g_fin_mk = new name{"fin", "mk"};
g_fin_ne_of_vne = new name{"fin", "ne_of_vne"};
g_forall = new name{"forall"};
g_forall_congr = new name{"forall_congr"};
g_forall_congr_eq = new name{"forall_congr_eq"};
g_forall_not_of_not_exists = new name{"forall_not_of_not_exists"};
g_format = new name{"format"};
g_function = new name{"function"};
g_funext = new name{"funext"};
g_has_add = new name{"has_add"};
g_has_add_add = new name{"has_add", "add"};
Expand Down Expand Up @@ -497,6 +502,7 @@ void initialize_constants() {
g_opt_param = new name{"opt_param"};
g_or = new name{"or"};
g_out_param = new name{"out_param"};
g_pi = new name{"pi"};
g_pprod = new name{"pprod"};
g_pprod_fst = new name{"pprod", "fst"};
g_pprod_mk = new name{"pprod", "mk"};
Expand Down Expand Up @@ -626,10 +632,12 @@ void finalize_constants() {
delete g_false_of_true_eq_false;
delete g_fin_mk;
delete g_fin_ne_of_vne;
delete g_forall;
delete g_forall_congr;
delete g_forall_congr_eq;
delete g_forall_not_of_not_exists;
delete g_format;
delete g_function;
delete g_funext;
delete g_has_add;
delete g_has_add_add;
Expand Down Expand Up @@ -779,6 +787,7 @@ void finalize_constants() {
delete g_opt_param;
delete g_or;
delete g_out_param;
delete g_pi;
delete g_pprod;
delete g_pprod_fst;
delete g_pprod_mk;
Expand Down Expand Up @@ -907,10 +916,12 @@ name const & get_false_rec_name() { return *g_false_rec; }
name const & get_false_of_true_eq_false_name() { return *g_false_of_true_eq_false; }
name const & get_fin_mk_name() { return *g_fin_mk; }
name const & get_fin_ne_of_vne_name() { return *g_fin_ne_of_vne; }
name const & get_forall_name() { return *g_forall; }
name const & get_forall_congr_name() { return *g_forall_congr; }
name const & get_forall_congr_eq_name() { return *g_forall_congr_eq; }
name const & get_forall_not_of_not_exists_name() { return *g_forall_not_of_not_exists; }
name const & get_format_name() { return *g_format; }
name const & get_function_name() { return *g_function; }
name const & get_funext_name() { return *g_funext; }
name const & get_has_add_name() { return *g_has_add; }
name const & get_has_add_add_name() { return *g_has_add_add; }
Expand Down Expand Up @@ -1060,6 +1071,7 @@ name const & get_of_eq_true_name() { return *g_of_eq_true; }
name const & get_opt_param_name() { return *g_opt_param; }
name const & get_or_name() { return *g_or; }
name const & get_out_param_name() { return *g_out_param; }
name const & get_pi_name() { return *g_pi; }
name const & get_pprod_name() { return *g_pprod; }
name const & get_pprod_fst_name() { return *g_pprod_fst; }
name const & get_pprod_mk_name() { return *g_pprod_mk; }
Expand Down
3 changes: 3 additions & 0 deletions src/library/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ name const & get_false_rec_name();
name const & get_false_of_true_eq_false_name();
name const & get_fin_mk_name();
name const & get_fin_ne_of_vne_name();
name const & get_forall_name();
name const & get_forall_congr_name();
name const & get_forall_congr_eq_name();
name const & get_forall_not_of_not_exists_name();
name const & get_format_name();
name const & get_function_name();
name const & get_funext_name();
name const & get_has_add_name();
name const & get_has_add_add_name();
Expand Down Expand Up @@ -218,6 +220,7 @@ name const & get_of_eq_true_name();
name const & get_opt_param_name();
name const & get_or_name();
name const & get_out_param_name();
name const & get_pi_name();
name const & get_pprod_name();
name const & get_pprod_fst_name();
name const & get_pprod_mk_name();
Expand Down
3 changes: 3 additions & 0 deletions src/library/constants.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ false.rec
false_of_true_eq_false
fin.mk
fin.ne_of_vne
forall
forall_congr
forall_congr_eq
forall_not_of_not_exists
format
function
funext
has_add
has_add.add
Expand Down Expand Up @@ -211,6 +213,7 @@ of_eq_true
opt_param
or
out_param
pi
pprod
pprod.fst
pprod.mk
Expand Down