diff --git a/src/frontends/lean/decl_attributes.cpp b/src/frontends/lean/decl_attributes.cpp index a83259b9bd..f2e3298e99 100644 --- a/src/frontends/lean/decl_attributes.cpp +++ b/src/frontends/lean/decl_attributes.cpp @@ -151,7 +151,8 @@ bool decl_attributes::ok_for_inductive_type() const { for (entry const & e : m_entries) { name const & n = e.m_attr->get_name(); if (is_system_attribute(n)) { - if ((n != "class" && n != "vm_override" && !is_class_symbol_tracking_attribute(n)) || e.deleted()) + if ((n != "class" && n != "vm_override" && n != "elab_field_alternatives" + && !is_class_symbol_tracking_attribute(n)) || e.deleted()) return false; } } diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index a51fb7e4d5..950f4db63d 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -78,6 +78,7 @@ Author: Leonardo de Moura namespace lean { static name * g_elab_strategy = nullptr; static name * g_elaborator_coercions = nullptr; +static name * g_elab_field_alternatives = nullptr; bool get_elaborator_coercions(options const & opts) { return opts.get_bool(*g_elaborator_coercions, LEAN_DEFAULT_ELABORATOR_COERCIONS); @@ -130,6 +131,10 @@ elaborator_strategy get_elaborator_strategy(environment const & env, name const return elaborator_strategy::WithExpectedType; } +static names_attribute const & get_elab_field_alternatives_attribute() { + return static_cast(get_system_attribute(*g_elab_field_alternatives)); +} + #define trace_elab(CODE) lean_trace("elaborator", scope_trace_env _scope(m_env, m_ctx); CODE) #define trace_elab_detail(CODE) lean_trace("elaborator_detail", scope_trace_env _scope(m_env, m_ctx); CODE) #define trace_elab_debug(CODE) lean_trace("elaborator_debug", scope_trace_env _scope(m_env, m_ctx); CODE) @@ -1876,41 +1881,78 @@ expr elaborator::visit_app_core(expr fn, buffer const & args, optionalmk_ref()); - proj_type = field_res.m_ldecl->get_type(); - } else { - proj = copy_tag(fn, mk_constant(field_res.get_full_fname())); - proj_type = m_env.get(field_res.get_full_fname()).get_type(); + optional find_matching = {}; + switch (field_res.m_kind) { + case field_resolution::kind::ProjFn: { + auto fr = field_res.get_proj_fn(); + expr coerced_s = *mk_base_projections(m_env, fr.m_struct_name, fr.m_base_struct_name, mk_as_is(s)); + expr proj_app = mk_proj_app(m_env, fr.m_base_struct_name, fr.m_field_name, coerced_s, ref); + expr new_proj = visit_function(proj_app, has_args, has_args ? none_expr() : expected_type, ref); + return visit_base_app(new_proj, amask, args, expected_type, ref); + } + case field_resolution::kind::LocalRec: { + auto fr = field_res.get_local_rec(); + proj = copy_tag(fn, fr.m_ldecl.mk_ref()); + proj_type = fr.m_ldecl.get_type(); + find_matching = fr.m_base_name; + break; + } + case field_resolution::kind::Const: { + auto fr = field_res.get_const(); + expr coerced_s = *mk_base_projections(m_env, fr.m_struct_name, fr.m_base_struct_name, mk_as_is(s)); + s = copy_tag(s, std::move(coerced_s)); + proj = copy_tag(fn, mk_constant(fr.m_const_name)); + proj_type = m_env.get(field_res.get_full_name()).get_type(); + if (fr.m_find_matching) { + find_matching = fr.get_base_name(); + } + break; + } + default: lean_unreachable(); } + + //type_context_old::tmp_locals locals(m_ctx); + buffer fun_args; buffer new_args; - unsigned i = 0; + unsigned i = 0; while (is_pi(proj_type)) { if (is_explicit(binding_info(proj_type))) { - if (is_app_of(binding_domain(proj_type), field_res.m_base_S_name)) { + if (!find_matching || is_app_of(binding_domain(proj_type), *find_matching)) { /* found s location */ - expr coerced_s = *mk_base_projections(m_env, field_res.m_S_name, field_res.m_base_S_name, mk_as_is(s)); - new_args.push_back(copy_tag(fn, std::move(coerced_s))); + new_args.push_back(s); for (; i < args.size(); i++) new_args.push_back(args[i]); - expr new_proj = visit_function(proj, has_args, has_args ? none_expr() : expected_type, ref); - return visit_base_app(new_proj, amask, new_args, expected_type, ref); + + if (fun_args.empty()) { + expr new_proj = visit_function(proj, has_args, has_args ? none_expr() : expected_type, ref); + return visit_base_app(new_proj, amask, new_args, expected_type, ref); + } else { + expr new_proj = visit_function(proj, true, none_expr(), ref); + optional expected_type_f = expected_type ? some_expr(Pi(fun_args, *expected_type)) : none_expr(); + expr f = visit_base_app(new_proj, amask, new_args, expected_type_f, ref); + return copy_tag(ref, Fun(fun_args, f)); + } } else { - if (i >= args.size()) { - throw elaborator_exception(ref, sstream() << "invalid field notation, insufficient number of arguments for '" - << field_res.get_full_fname() << "'"); + if (i >= args.size()) { // TODO make this generate a lambda expression + auto funarg = mk_local(mk_fresh_name(), binding_name(proj_type), binding_domain(proj_type), binding_info(proj_type)); + fun_args.push_back(funarg); + new_args.push_back(funarg); + //throw elaborator_exception(ref, sstream() << "invalid field notation, insufficient number of arguments for '" + // << field_res.get_full_name() << "'"); + } else { + new_args.push_back(args[i]); } - new_args.push_back(args[i]); i++; } } proj_type = binding_body(proj_type); } throw elaborator_exception(ref, sstream() << "invalid field notation, function '" - << field_res.get_full_fname() << "' does not have explicit argument with type (" - << field_res.m_base_S_name << " ...)"); + << field_res.get_full_name() << "' does not have explicit argument with type (" + << field_res.get_base_name() << " ...)"); } else { expr new_fn = visit_function(fn, has_args, has_args ? none_expr() : expected_type, ref); /* Check if we should use a custom elaboration procedure for this application. */ @@ -2685,24 +2727,94 @@ expr elaborator::visit_inaccessible(expr const & e, optional const & expec return copy_tag(e, mk_inaccessible(new_a)); } -elaborator::field_resolution elaborator::field_to_decl(expr const & e, expr const & s, expr const & s_type) { +elaborator::field_resolution elaborator::resolve_field_notation_method(expr const & e, expr const & s, expr const & s_type, name const & struct_name, bool find_matching, + buffer const & extra_base_names) { + lean_assert(is_field_notation(e) && !is_anonymous_field_notation(e)); + name fname = get_field_notation_field_name(e); + + if (auto m = find_method(m_env, struct_name, fname)) { + return field_resolution_const(m->first, struct_name, m->second, find_matching); + } + + // Now try to look for extension methods + if (auto data = get_elab_field_alternatives_attribute().get(m_env, struct_name)) { + for (name const & alt : data->m_names) { + if (m_env.find(alt + fname)) { + return field_resolution_const(struct_name, struct_name, alt + fname, false); + } + } + } + // then do the same for the extra base_names + for (name const & alt : extra_base_names) { + if (m_env.find(alt + fname)) { + return field_resolution_const(struct_name, struct_name, alt + fname, false); + } + } + // prefer 'unknown identifier' error when lhs is a constant of non-value type - if (is_field_notation(e)) { - auto lhs = macro_arg(e, 0); - if (is_constant(lhs)) { - type_context_old::tmp_locals locals(m_ctx); - expr t = whnf(s_type); - while (is_pi(t)) { - t = whnf(instantiate(binding_body(t), locals.push_local_from_binding(t))); + auto lhs = macro_arg(e, 0); + if (is_constant(lhs)) { + type_context_old::tmp_locals locals(m_ctx); + expr t = whnf(s_type); + while (is_pi(t)) { + t = whnf(instantiate(binding_body(t), locals.push_local_from_binding(t))); + } + if (is_sort(t) && !is_anonymous_field_notation(e)) { + name fname = get_field_notation_field_name(e); + throw elaborator_exception(lhs, format("unknown identifier '") + format(const_name(lhs) + fname) + format("'")); + } + } + + auto pp_fn = mk_pp_ctx(); + throw elaborator_exception(e, format("invalid field notation, '") + format(fname) + format("'") + + format(" is not a valid \"field\" because environment does not contain ") + + format("'") + format(struct_name + fname) + format("'") + + pp_indent(pp_fn, s) + + line() + format("which has type") + + pp_indent(pp_fn, s_type)); +} + +elaborator::field_resolution elaborator::resolve_field_notation_aux(expr const & e, expr const & s, expr const & s_type) { + lean_assert(is_field_notation(e)); + + // If it's a function, resolve the field as a method in the function/pi/implies/forall namespaces. + if (is_pi(s_type)) { + type_context_old::tmp_locals locals(m_ctx); + expr t = s_type; + while (is_pi(t)) { + t = whnf(instantiate(binding_body(t), locals.push_local_from_binding(t))); + } + bool is_forall = false; + try { + expr t2 = m_ctx.relaxed_whnf(m_ctx.infer(t)); + is_forall = t2 == mk_Prop(); + } catch (exception &) {} + + name struct_name; + buffer extra; + if (is_forall) { + if (is_arrow(s_type)) { + struct_name = get_implies_name(); + extra.push_back(get_function_name()); + extra.push_back(get_forall_name()); + extra.push_back(get_pi_name()); + } else { + struct_name = get_forall_name(); + extra.push_back(get_pi_name()); } - if (is_sort(t) && !is_anonymous_field_notation(e)) { - name fname = get_field_notation_field_name(e); - throw elaborator_exception(lhs, format("unknown identifier '") + format(const_name(lhs)) + format(".") + - format(fname) + format("'")); + } else { + if (is_arrow(s_type)) { + struct_name = get_function_name(); + extra.push_back(get_pi_name()); + } else { + struct_name = get_pi_name(); } } + return resolve_field_notation_method(e, s, s_type, struct_name, false, extra); } - expr I = get_app_fn(s_type); + + expr I = get_app_fn(s_type); + if (!is_constant(I)) { auto pp_fn = mk_pp_ctx(); throw elaborator_exception(e, format("invalid field notation, type is not of the form (C ...) where C is a constant") + @@ -2710,17 +2822,20 @@ elaborator::field_resolution elaborator::field_to_decl(expr const & e, expr cons line() + format("has type") + pp_indent(pp_fn, s_type)); } + + auto struct_name = const_name(I); + if (is_anonymous_field_notation(e)) { - if (!is_structure(m_env, const_name(I))) { + if (!is_structure(m_env, struct_name)) { auto pp_fn = mk_pp_ctx(); throw elaborator_exception(e, format("invalid projection, structure expected") + pp_indent(pp_fn, s) + line() + format("has type") + pp_indent(pp_fn, s_type)); } - auto fnames = get_structure_fields(m_env, const_name(I)); + auto fnames = get_structure_fields(m_env, struct_name); unsigned fidx = get_field_notation_field_idx(e); - if (fidx == 0) { + if (fidx == 0) { throw elaborator_exception(e, "invalid projection, index must be greater than 0"); } if (fidx > fnames.size()) { @@ -2731,37 +2846,30 @@ elaborator::field_resolution elaborator::field_to_decl(expr const & e, expr cons line() + format("which has type") + pp_indent(pp_fn, s_type)); } - return const_name(I) + fnames[fidx-1]; + return field_resolution_proj_fn(struct_name, struct_name, fnames[fidx-1]); } else { - name fname = get_field_notation_field_name(e); - // search for "true" fields first, including in parent structures - if (is_structure_like(m_env, const_name(I))) - if (auto p = find_field(m_env, const_name(I), fname)) - return field_resolution(const_name(I), *p, fname); - name full_fname = const_name(I) + fname; - name local_name = full_fname.replace_prefix(get_namespace(env()), {}); - if (auto ldecl = m_ctx.lctx().find_if([&](local_decl const & decl) { - return decl.get_info().is_rec() && decl.get_pp_name() == local_name; - })) { - // projection is recursive call - return field_resolution(full_fname, ldecl); - } - if (!m_env.find(full_fname)) { - auto pp_fn = mk_pp_ctx(); - throw elaborator_exception(e, format("invalid field notation, '") + format(fname) + format("'") + - format(" is not a valid \"field\" because environment does not contain ") + - format("'") + format(full_fname) + format("'") + - pp_indent(pp_fn, s) + - line() + format("which has type") + - pp_indent(pp_fn, s_type)); + name fname = get_field_notation_field_name(e); + + // Search for "true" fields first, including in parent structures + if (is_structure_like(m_env, struct_name)) + if (auto p = find_field(m_env, struct_name, fname)) + return field_resolution_proj_fn(*p, struct_name, fname); + + // Check if field notation is being used to make a "local" recursive call. + name full_fname = struct_name + fname; + name local_name = full_fname.replace_prefix(get_namespace(m_env), {}); + if (auto ldecl = m_ctx.lctx().find_if([&](local_decl const & decl) { return decl.get_info().is_rec() && decl.get_pp_name() == local_name; })) { + return field_resolution_local_rec(struct_name, full_fname, *ldecl); } - return full_fname; + + return resolve_field_notation_method(e, s, s_type, struct_name); } } -elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr const & s, expr const & s_type) { +elaborator::field_resolution elaborator::resolve_field_notation(expr const & e, expr const & s, expr const & s_type) { + lean_assert(is_field_notation(e)); try { - return field_to_decl(e, s, s_type); + return resolve_field_notation_aux(e, s, s_type); } catch (elaborator_exception & ex1) { expr new_s_type = s_type; if (auto d = unfold_term(env(), new_s_type)) @@ -2770,7 +2878,7 @@ elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr cons if (new_s_type == s_type) throw; try { - return find_field_fn(e, s, new_s_type); + return resolve_field_notation(e, s, new_s_type); } catch (elaborator_exception & ex2) { throw nested_elaborator_exception(ex2.get_pos(), ex1, ex2.pp()); } @@ -2779,17 +2887,7 @@ elaborator::field_resolution elaborator::find_field_fn(expr const & e, expr cons expr elaborator::visit_field(expr const & e, optional const & expected_type) { lean_assert(is_field_notation(e)); - expr s = visit(macro_arg(e, 0), none_expr()); - expr s_type = head_beta_reduce(instantiate_mvars(infer_type(s))); - auto field_res = find_field_fn(e, s, s_type); - expr proj_app; - if (field_res.m_ldecl) { - proj_app = copy_tag(e, mk_app(field_res.m_ldecl->mk_ref(), mk_as_is(s))); - } else { - expr new_e = *mk_base_projections(m_env, field_res.m_S_name, field_res.m_base_S_name, mk_as_is(s)); - proj_app = mk_proj_app(m_env, field_res.m_base_S_name, field_res.m_fname, new_e, e); - } - return visit(proj_app, expected_type); + return visit_app_core(e, buffer(), expected_type, e); } class reduce_projections_visitor : public replace_visitor { @@ -4326,6 +4424,13 @@ void initialize_elaborator() { register_incompatible("elab_simple", "elab_as_eliminator"); register_incompatible("elab_with_expected_type", "elab_as_eliminator"); + g_elab_field_alternatives = new name("elab_field_alternatives"); + + register_system_attribute( + names_attribute( + *g_elab_field_alternatives, + "provides alternative prefixes to search when elaborating field notation")); + DECLARE_VM_BUILTIN(name({"environment", "add_defn_eqns"}), environment_add_defn_eqns); DECLARE_VM_BUILTIN(name({"tactic", "save_type_info"}), tactic_save_type_info); @@ -4338,6 +4443,7 @@ void initialize_elaborator() { void finalize_elaborator() { delete g_elab_strategy; + delete g_elab_field_alternatives; delete g_elaborator_coercions; } } diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index 9d40de292c..47f6e6ddac 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -246,23 +246,120 @@ class elaborator { expr visit_equation(expr const & eq, unsigned num_fns); expr visit_inaccessible(expr const & e, optional const & expected_type); + /** Field resolution: this is a field projection */ + struct field_resolution_proj_fn { + /** Name of the structure that is the source of the field. Is ancestor of \c m_struct_name */ + name m_base_struct_name; + /** Name of the structure for the projected expression */ + name m_struct_name; + /** The field name for the projection */ + name m_field_name; + + field_resolution_proj_fn(name const & base_struct_name, name const & struct_name, name const & field_name): + m_base_struct_name(base_struct_name), m_struct_name(struct_name), m_field_name(field_name) {} + + /** Get the name of the projection function */ + name get_full_name() const { return m_base_struct_name + m_field_name; } + }; + + /** Field resolution: this is a "method" (a.k.a. extended dot notation) */ + struct field_resolution_const { + /** If this is not equal to \c m_struct_name then we should insert parent projections, and this is + * the name of the structure that is the source of the field. Is ancestor of \c m_struct_name */ + name m_base_struct_name; + /** Generalized structure name for the method expression */ + name m_struct_name; + /** Name of the constant to use as a function */ + name m_const_name; + /** Whether to find the first explicit argument in \c m_const_name that accepts the expression. + * If false, just use the first explicit argument. */ + bool m_find_matching; + + field_resolution_const(name const & base_struct_name, name const & struct_name, name const & const_name, bool find_matching = true): + m_base_struct_name(base_struct_name), m_struct_name(struct_name), m_const_name(const_name), m_find_matching(find_matching) {} + + /** Get the structure name to search for when \c m_find_matching is true */ + name const & get_base_name() const { return m_base_struct_name; } + }; + + /** Field resolution: projection is being used to make a "local" recursive call */ + struct field_resolution_local_rec { + /** The generalized structure name for the argument to the call. */ + name m_base_name; + /** The resolved name for the recursive call (for error reporting). */ + name m_full_name; + /** The declaration for the "local" recursive call. */ + local_decl m_ldecl; + + field_resolution_local_rec(name const & base_name, name const & full_name, local_decl const & ldecl): + m_base_name(base_name), m_full_name(full_name), m_ldecl(ldecl) {} + }; + struct field_resolution { - name m_S_name; // structure name of field expression type - name m_base_S_name; // structure name of field - name m_fname; - optional m_ldecl; // projection is a local constant: recursive call - - field_resolution(name const & full_fname, optional ldecl = {}): - m_S_name(full_fname.get_prefix()), m_base_S_name(full_fname.get_prefix()), - m_fname(full_fname.get_string()), m_ldecl(ldecl) {} - field_resolution(const name & S_name, const name & base_S_name, const name & fname): - m_S_name(S_name), m_base_S_name(base_S_name), m_fname(fname) {} - - name get_full_fname() const { return m_base_S_name + m_fname; } + enum class kind { ProjFn, Const, LocalRec }; + + kind m_kind; + union { + field_resolution_proj_fn m_proj_fn; + field_resolution_const m_const; + field_resolution_local_rec m_local_rec; + }; + + field_resolution(field_resolution_proj_fn const & fr_proj_fn): + m_kind(kind::ProjFn), m_proj_fn(fr_proj_fn) {} + field_resolution(field_resolution_const const & fr_const): + m_kind(kind::Const), m_const(fr_const) {} + field_resolution(field_resolution_local_rec const & fr_local_rec): + m_kind(kind::LocalRec), m_local_rec(fr_local_rec) {} + + field_resolution(field_resolution const & fr):m_kind(fr.m_kind) { + switch (m_kind) { + case kind::ProjFn: new (&m_proj_fn) auto(fr.m_proj_fn); break; + case kind::Const: new (&m_const) auto(fr.m_const); break; + case kind::LocalRec: new (&m_local_rec) auto(fr.m_local_rec); break; + default: lean_unreachable(); + } + } + + ~field_resolution() { + switch (m_kind) { + case kind::ProjFn: m_proj_fn.~field_resolution_proj_fn(); break; + case kind::Const: m_const.~field_resolution_const(); break; + case kind::LocalRec: m_local_rec.~field_resolution_local_rec(); break; + } + } + + field_resolution_proj_fn const & get_proj_fn() { lean_assert(m_kind == kind::ProjFn); return m_proj_fn; } + field_resolution_const const & get_const() { lean_assert(m_kind == kind::Const); return m_const; } + field_resolution_local_rec const & get_local_rec() { lean_assert(m_kind == kind::LocalRec); return m_local_rec; } + + /** The function name to use when reporting errors associated to this field resolution. */ + name get_full_name() { + switch (m_kind) { + case kind::ProjFn: return get_proj_fn().get_full_name(); + case kind::Const: return get_const().m_const_name; + case kind::LocalRec: return get_local_rec().m_full_name; + default: lean_unreachable(); + } + } + + /** The structure name to use when reporting errors associated to this field resolution. */ + name get_base_name() { + switch (m_kind) { + case kind::ProjFn: return get_proj_fn().m_base_struct_name; + case kind::Const: return get_const().m_base_struct_name; + case kind::LocalRec: return get_local_rec().m_base_name; + default: lean_unreachable(); + } + } }; - field_resolution field_to_decl(expr const & e, expr const & s, expr const & s_type); - field_resolution find_field_fn(expr const & e, expr const & s, expr const & s_type); + field_resolution resolve_field_notation_method(expr const & e, expr const & s, expr const & s_type, name const & struct_name, bool find_matching = true, + buffer const & extra_base_names = buffer()); + field_resolution resolve_field_notation_aux(expr const & e, expr const & s, expr const & s_type); + /** \c e is the field notation expression, \c s is the elaborated expression, \c s_type is its type */ + field_resolution resolve_field_notation(expr const & e, expr const & s, expr const & s_type); + expr visit_field(expr const & e, optional const & expected_type); expr instantiate_mvars(expr const & e, std::function pred); // NOLINT expr visit_structure_instance(expr const & e, optional expected_type); diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index d85198c756..803254dd8b 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -177,6 +177,18 @@ optional find_field(environment const & env, name const & S_name, name con return {}; } +optional> find_method(environment const & env, name const & struct_name, name const & field_name) { + if (env.find(struct_name + field_name)) + return some(mk_pair(struct_name, struct_name + field_name)); + if (is_structure_like(env, struct_name)) { + for (auto const & p : get_parent_structures(env, struct_name)) { + if (auto m = find_method(env, p, field_name)) + return m; + } + } + return {}; +} + void get_structure_fields_flattened(environment const & env, name const & structure_name, buffer & full_fnames) { for (auto const & fname : get_structure_fields(env, structure_name)) { full_fnames.push_back(structure_name + fname); diff --git a/src/frontends/lean/structure_cmd.h b/src/frontends/lean/structure_cmd.h index 946da1a89b..08e381e322 100644 --- a/src/frontends/lean/structure_cmd.h +++ b/src/frontends/lean/structure_cmd.h @@ -27,6 +27,9 @@ optional find_field(environment const & env, name const & S_name, name con optional mk_base_projections(environment const & env, name const & S_name, name const & base_S_name, expr const & e); /** \brief Return an unelaborated expression applying a field projection */ expr mk_proj_app(environment const & env, name const & S_name, name const & fname, expr const & e, expr const & ref = {}); +/** \brief Searches for `struct_name.field_name` in the environment, and if `struct_name` is a structure, recursively does the same for parent structures. + * Returns (S', n) where S' is the name of the (generalized) structure and n is the name corresponding to \c field_name */ +optional> find_method(environment const & env, name const & struct_name, name const & field_name); /* Default value support */ optional has_default_value(environment const & env, name const & S_name, name const & fname); diff --git a/src/library/attribute_manager.cpp b/src/library/attribute_manager.cpp index 1bbb7cd436..53568770b7 100644 --- a/src/library/attribute_manager.cpp +++ b/src/library/attribute_manager.cpp @@ -16,6 +16,7 @@ Author: Leonardo de Moura namespace lean { template class typed_attribute; +template class typed_attribute; template class typed_attribute; ast_id key_value_data::parse(abstract_parser & p) { @@ -291,6 +292,21 @@ ast_id indices_attribute_data::parse(abstract_parser & p) { return data.m_id; } +ast_id names_attribute_data::parse(abstract_parser & p) { + buffer names; + lean_assert(dynamic_cast(&p)); + auto& p2 = *static_cast(&p); + auto& data = p2.new_ast("names", p2.pos()); + while (p2.curr_is_identifier()) { + name n = p2.get_name_val(); + data.push(p2.new_ast("ident", p2.pos(), n).m_id); + names.push_back(n); + p2.next(); + } + m_names = to_list(names); + return data.m_id; +} + void register_incompatible(char const * attr1, char const * attr2) { lean_assert(is_system_attribute(attr1)); lean_assert(is_system_attribute(attr2)); diff --git a/src/library/attribute_manager.h b/src/library/attribute_manager.h index ec5d59f338..7dfe63c23f 100644 --- a/src/library/attribute_manager.h +++ b/src/library/attribute_manager.h @@ -220,6 +220,31 @@ struct indices_attribute_data : public attr_data { } }; +struct names_attribute_data : public attr_data { + list m_names; + names_attribute_data(list const & names) : m_names(names) {} + names_attribute_data() : names_attribute_data(list()) {} + + virtual unsigned hash() const override { + unsigned h = 0; + for (name n : m_names) + h = ::lean::hash(h, n.hash()); + return h; + } + void write(serializer & s) const { + write_list(s, m_names); + } + void read(deserializer & d) { + m_names = read_list(d); + } + ast_id parse(abstract_parser & p) override; + virtual void print(std::ostream & out) override { + for (auto n : m_names) { + out << " " << n; + } + } +}; + struct key_value_data : public attr_data { // generalize: name_map m_pairs; std::string m_symbol; @@ -260,7 +285,10 @@ struct key_value_data : public attr_data { /** \brief Attribute that represents a list of indices. input and output are 1-indexed for convenience. */ typedef typed_attribute indices_attribute; -/** \brief Attribute that represents a list of indices. input and output are 1-indexed for convenience. */ +/** \brief Attribute that represents a list of names. */ +typedef typed_attribute names_attribute; + +/** \brief Attribute that represents a single key/value pair. */ typedef typed_attribute key_value_attribute; class user_attribute_ext { diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 82430027c4..a7f060e479 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -63,10 +63,12 @@ name const * g_false_rec = nullptr; name const * g_false_of_true_eq_false = nullptr; name const * g_fin_mk = nullptr; name const * g_fin_ne_of_vne = nullptr; +name const * g_forall = nullptr; name const * g_forall_congr = nullptr; name const * g_forall_congr_eq = nullptr; name const * g_forall_not_of_not_exists = nullptr; name const * g_format = nullptr; +name const * g_function = nullptr; name const * g_funext = nullptr; name const * g_has_add = nullptr; name const * g_has_add_add = nullptr; @@ -216,6 +218,7 @@ name const * g_of_eq_true = nullptr; name const * g_opt_param = nullptr; name const * g_or = nullptr; name const * g_out_param = nullptr; +name const * g_pi = nullptr; name const * g_pprod = nullptr; name const * g_pprod_fst = nullptr; name const * g_pprod_mk = nullptr; @@ -344,10 +347,12 @@ void initialize_constants() { g_false_of_true_eq_false = new name{"false_of_true_eq_false"}; g_fin_mk = new name{"fin", "mk"}; g_fin_ne_of_vne = new name{"fin", "ne_of_vne"}; + g_forall = new name{"forall"}; g_forall_congr = new name{"forall_congr"}; g_forall_congr_eq = new name{"forall_congr_eq"}; g_forall_not_of_not_exists = new name{"forall_not_of_not_exists"}; g_format = new name{"format"}; + g_function = new name{"function"}; g_funext = new name{"funext"}; g_has_add = new name{"has_add"}; g_has_add_add = new name{"has_add", "add"}; @@ -497,6 +502,7 @@ void initialize_constants() { g_opt_param = new name{"opt_param"}; g_or = new name{"or"}; g_out_param = new name{"out_param"}; + g_pi = new name{"pi"}; g_pprod = new name{"pprod"}; g_pprod_fst = new name{"pprod", "fst"}; g_pprod_mk = new name{"pprod", "mk"}; @@ -626,10 +632,12 @@ void finalize_constants() { delete g_false_of_true_eq_false; delete g_fin_mk; delete g_fin_ne_of_vne; + delete g_forall; delete g_forall_congr; delete g_forall_congr_eq; delete g_forall_not_of_not_exists; delete g_format; + delete g_function; delete g_funext; delete g_has_add; delete g_has_add_add; @@ -779,6 +787,7 @@ void finalize_constants() { delete g_opt_param; delete g_or; delete g_out_param; + delete g_pi; delete g_pprod; delete g_pprod_fst; delete g_pprod_mk; @@ -907,10 +916,12 @@ name const & get_false_rec_name() { return *g_false_rec; } name const & get_false_of_true_eq_false_name() { return *g_false_of_true_eq_false; } name const & get_fin_mk_name() { return *g_fin_mk; } name const & get_fin_ne_of_vne_name() { return *g_fin_ne_of_vne; } +name const & get_forall_name() { return *g_forall; } name const & get_forall_congr_name() { return *g_forall_congr; } name const & get_forall_congr_eq_name() { return *g_forall_congr_eq; } name const & get_forall_not_of_not_exists_name() { return *g_forall_not_of_not_exists; } name const & get_format_name() { return *g_format; } +name const & get_function_name() { return *g_function; } name const & get_funext_name() { return *g_funext; } name const & get_has_add_name() { return *g_has_add; } name const & get_has_add_add_name() { return *g_has_add_add; } @@ -1060,6 +1071,7 @@ name const & get_of_eq_true_name() { return *g_of_eq_true; } name const & get_opt_param_name() { return *g_opt_param; } name const & get_or_name() { return *g_or; } name const & get_out_param_name() { return *g_out_param; } +name const & get_pi_name() { return *g_pi; } name const & get_pprod_name() { return *g_pprod; } name const & get_pprod_fst_name() { return *g_pprod_fst; } name const & get_pprod_mk_name() { return *g_pprod_mk; } diff --git a/src/library/constants.h b/src/library/constants.h index c3e1f9acb2..78a17e7202 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -65,10 +65,12 @@ name const & get_false_rec_name(); name const & get_false_of_true_eq_false_name(); name const & get_fin_mk_name(); name const & get_fin_ne_of_vne_name(); +name const & get_forall_name(); name const & get_forall_congr_name(); name const & get_forall_congr_eq_name(); name const & get_forall_not_of_not_exists_name(); name const & get_format_name(); +name const & get_function_name(); name const & get_funext_name(); name const & get_has_add_name(); name const & get_has_add_add_name(); @@ -218,6 +220,7 @@ name const & get_of_eq_true_name(); name const & get_opt_param_name(); name const & get_or_name(); name const & get_out_param_name(); +name const & get_pi_name(); name const & get_pprod_name(); name const & get_pprod_fst_name(); name const & get_pprod_mk_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 61cd0d5232..e84a45f96e 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -58,10 +58,12 @@ false.rec false_of_true_eq_false fin.mk fin.ne_of_vne +forall forall_congr forall_congr_eq forall_not_of_not_exists format +function funext has_add has_add.add @@ -211,6 +213,7 @@ of_eq_true opt_param or out_param +pi pprod pprod.fst pprod.mk