diff --git a/affinescript.opam b/affinescript.opam index 1e0fdb8..1bc717a 100644 --- a/affinescript.opam +++ b/affinescript.opam @@ -23,7 +23,7 @@ doc: "https://github.com/hyperpolymath/affinescript" bug-reports: "https://github.com/hyperpolymath/affinescript/issues" depends: [ "ocaml" {>= "5.1"} - "dune" {>= "3.14" & >= "3.14"} + "dune" {>= "3.14"} "menhir" {>= "20231231"} "sedlex" {>= "3.2"} "ppx_deriving" {>= "5.2"} diff --git a/lib/borrow.ml b/lib/borrow.ml index ed8f01d..90bd26a 100644 --- a/lib/borrow.ml +++ b/lib/borrow.ml @@ -60,6 +60,9 @@ type borrow_error = type 'a result = ('a, borrow_error) Result.t +(* Result bind - define before use *) +let ( let* ) = Result.bind + (** Create a new borrow checker state *) let create () : state = { @@ -99,7 +102,7 @@ let find_conflicting_borrow (state : state) (new_borrow : borrow) : borrow optio ) state.borrows (** Record a move *) -let record_move (state : state) (place : place) (span : Span.t) : unit result = +let record_move (state : state) (place : place) (_span : Span.t) : unit result = (* Check for active borrows *) match List.find_opt (fun b -> places_overlap place b.b_place) state.borrows with | Some borrow -> Error (MoveWhileBorrowed (place, borrow)) @@ -137,105 +140,291 @@ let check_use (state : state) (place : place) (span : Span.t) : unit result = else Ok () +(** Get span from an expression *) +let rec expr_span (expr : expr) : Span.t = + match expr with + | ExprSpan (_, span) -> span + | ExprLit lit -> lit_span lit + | ExprVar id -> id.span + | ExprLet { el_pat; _ } -> pattern_span el_pat + | ExprIf { ei_cond; _ } -> expr_span ei_cond + | ExprMatch { em_scrutinee; _ } -> expr_span em_scrutinee + | ExprLambda { elam_params; _ } -> + begin match elam_params with + | p :: _ -> p.p_name.span + | [] -> Span.dummy + end + | ExprApp (f, _) -> expr_span f + | ExprField (e, _) -> expr_span e + | ExprTupleIndex (e, _) -> expr_span e + | ExprIndex (e, _) -> expr_span e + | ExprTuple exprs -> + begin match exprs with + | e :: _ -> expr_span e + | [] -> Span.dummy + end + | ExprArray exprs -> + begin match exprs with + | e :: _ -> expr_span e + | [] -> Span.dummy + end + | ExprRecord { er_fields; _ } -> + begin match er_fields with + | (id, _) :: _ -> id.span + | [] -> Span.dummy + end + | ExprRowRestrict (e, _) -> expr_span e + | ExprBinary (e, _, _) -> expr_span e + | ExprUnary (_, e) -> expr_span e + | ExprBlock { blk_stmts; blk_expr } -> + begin match blk_stmts with + | StmtLet { sl_pat; _ } :: _ -> pattern_span sl_pat + | StmtExpr e :: _ -> expr_span e + | StmtAssign (e, _, _) :: _ -> expr_span e + | StmtWhile (e, _) :: _ -> expr_span e + | StmtFor (p, _, _) :: _ -> pattern_span p + | [] -> match blk_expr with Some e -> expr_span e | None -> Span.dummy + end + | ExprReturn _ -> Span.dummy + | ExprTry _ -> Span.dummy + | ExprHandle { eh_body; _ } -> expr_span eh_body + | ExprResume _ -> Span.dummy + | ExprUnsafe _ -> Span.dummy + | ExprVariant (id, _) -> id.span + +and lit_span (lit : literal) : Span.t = + match lit with + | LitInt (_, span) -> span + | LitFloat (_, span) -> span + | LitBool (_, span) -> span + | LitChar (_, span) -> span + | LitString (_, span) -> span + | LitUnit span -> span + +and pattern_span (pat : pattern) : Span.t = + match pat with + | PatWildcard span -> span + | PatVar id -> id.span + | PatLit lit -> lit_span lit + | PatCon (id, _) -> id.span + | PatTuple pats -> + begin match pats with + | p :: _ -> pattern_span p + | [] -> Span.dummy + end + | PatRecord ((id, _) :: _, _) -> id.span + | PatRecord ([], _) -> Span.dummy + | PatOr (p1, _) -> pattern_span p1 + | PatAs (id, _) -> id.span + (** Convert an expression to a place (if it's an l-value) *) let rec expr_to_place (symbols : Symbol.t) (expr : expr) : place option = match expr with - | EVar id -> - begin match Symbol.lookup symbols id.id_name with + | ExprVar id -> + begin match Symbol.lookup symbols id.name with | Some sym -> Some (PlaceVar sym.sym_id) | None -> None end - | ERecordAccess (base, field, _) -> + | ExprField (base, field) -> begin match expr_to_place symbols base with - | Some base_place -> Some (PlaceField (base_place, field.id_name)) + | Some base_place -> Some (PlaceField (base_place, field.name)) | None -> None end - | EIndex (base, _, _) -> + | ExprIndex (base, _) -> begin match expr_to_place symbols base with | Some base_place -> Some (PlaceIndex (base_place, None)) | None -> None end + | ExprSpan (e, _) -> + expr_to_place symbols e | _ -> None (** Check borrows in an expression *) let rec check_expr (state : state) (symbols : Symbol.t) (expr : expr) : unit result = match expr with - | EVar id -> + | ExprVar id -> begin match expr_to_place symbols expr with - | Some place -> check_use state place id.id_span + | Some place -> check_use state place id.span | None -> Ok () end - | ELit _ -> Ok () + | ExprLit _ -> Ok () - | EApp (func, arg, _) -> + | ExprApp (func, args) -> let* () = check_expr state symbols func in - check_expr state symbols arg - - | ELam lam -> - check_expr state symbols lam.lam_body + List.fold_left (fun acc arg -> + let* () = acc in + check_expr state symbols arg + ) (Ok ()) args + + | ExprLambda lam -> + check_expr state symbols lam.elam_body + + | ExprLet lb -> + let* () = check_expr state symbols lb.el_value in + match lb.el_body with + | Some body -> check_expr state symbols body + | None -> Ok () + + | ExprIf ei -> + let* () = check_expr state symbols ei.ei_cond in + (* TODO: Proper branch handling - save/restore state *) + let* () = check_expr state symbols ei.ei_then in + begin match ei.ei_else with + | Some e -> check_expr state symbols e + | None -> Ok () + end - | ELet lb -> - let* () = check_expr state symbols lb.lb_rhs in - check_expr state symbols lb.lb_body + | ExprMatch em -> + let* () = check_expr state symbols em.em_scrutinee in + List.fold_left (fun acc arm -> + let* () = acc in + let* () = match arm.ma_guard with + | Some g -> check_expr state symbols g + | None -> Ok () + in + check_expr state symbols arm.ma_body + ) (Ok ()) em.em_arms + + | ExprTuple exprs -> + List.fold_left (fun acc e -> + let* () = acc in + check_expr state symbols e + ) (Ok ()) exprs - | EIf (cond, then_, else_, _) -> - let* () = check_expr state symbols cond in - (* TODO: Proper branch handling - save/restore state *) - let* () = check_expr state symbols then_ in - check_expr state symbols else_ - - | ECase (scrut, branches, _) -> - let* () = check_expr state symbols scrut in - List.fold_left (fun acc branch -> - match acc with - | Error e -> Error e - | Ok () -> check_expr state symbols branch.cb_body - ) (Ok ()) branches - - | ETuple (exprs, _) -> + | ExprArray exprs -> List.fold_left (fun acc e -> - match acc with - | Error e -> Error e - | Ok () -> check_expr state symbols e + let* () = acc in + check_expr state symbols e ) (Ok ()) exprs - | ERecord (fields, _) -> - List.fold_left (fun acc (_, e) -> - match acc with - | Error e -> Error e - | Ok () -> check_expr state symbols e - ) (Ok ()) fields + | ExprRecord er -> + let* () = List.fold_left (fun acc (_id, e_opt) -> + let* () = acc in + match e_opt with + | Some e -> check_expr state symbols e + | None -> Ok () + ) (Ok ()) er.er_fields in + begin match er.er_spread with + | Some e -> check_expr state symbols e + | None -> Ok () + end - | ERecordAccess (base, _, _) -> + | ExprField (base, _) -> check_expr state symbols base - | ERecordUpdate (base, _, value, _) -> - let* () = check_expr state symbols base in - check_expr state symbols value + | ExprTupleIndex (base, _) -> + check_expr state symbols base - | EBlock (exprs, _) -> - List.fold_left (fun acc e -> - match acc with - | Error e -> Error e - | Ok () -> check_expr state symbols e - ) (Ok ()) exprs + | ExprIndex (arr, idx) -> + let* () = check_expr state symbols arr in + check_expr state symbols idx + + | ExprRowRestrict (base, _) -> + check_expr state symbols base + + | ExprBlock blk -> + check_block state symbols blk - | EBinOp (left, _, right, _) -> + | ExprBinary (left, _, right) -> let* () = check_expr state symbols left in check_expr state symbols right - | _ -> Ok () + | ExprUnary (_, e) -> + check_expr state symbols e -(* Result bind *) -let ( let* ) = Result.bind + | ExprReturn e_opt -> + begin match e_opt with + | Some e -> check_expr state symbols e + | None -> Ok () + end + + | ExprHandle eh -> + let* () = check_expr state symbols eh.eh_body in + List.fold_left (fun acc arm -> + let* () = acc in + match arm with + | HandlerReturn (_pat, body) -> check_expr state symbols body + | HandlerOp (_op, _pats, body) -> check_expr state symbols body + ) (Ok ()) eh.eh_handlers + + | ExprResume e_opt -> + begin match e_opt with + | Some e -> check_expr state symbols e + | None -> Ok () + end + + | ExprTry et -> + let* () = check_block state symbols et.et_body in + let* () = match et.et_catch with + | Some arms -> + List.fold_left (fun acc arm -> + let* () = acc in + let* () = match arm.ma_guard with + | Some g -> check_expr state symbols g + | None -> Ok () + in + check_expr state symbols arm.ma_body + ) (Ok ()) arms + | None -> Ok () + in + begin match et.et_finally with + | Some blk -> check_block state symbols blk + | None -> Ok () + end + + | ExprUnsafe ops -> + List.fold_left (fun acc op -> + let* () = acc in + match op with + | UnsafeRead e -> check_expr state symbols e + | UnsafeWrite (e1, e2) -> + let* () = check_expr state symbols e1 in + check_expr state symbols e2 + | UnsafeOffset (e1, e2) -> + let* () = check_expr state symbols e1 in + check_expr state symbols e2 + | UnsafeTransmute (_, _, e) -> check_expr state symbols e + | UnsafeForget e -> check_expr state symbols e + | UnsafeAssume _ -> Ok () + ) (Ok ()) ops + + | ExprVariant _ -> Ok () + + | ExprSpan (e, _) -> + check_expr state symbols e + +and check_block (state : state) (symbols : Symbol.t) (blk : block) : unit result = + let* () = List.fold_left (fun acc stmt -> + let* () = acc in + check_stmt state symbols stmt + ) (Ok ()) blk.blk_stmts in + match blk.blk_expr with + | Some e -> check_expr state symbols e + | None -> Ok () + +and check_stmt (state : state) (symbols : Symbol.t) (stmt : stmt) : unit result = + match stmt with + | StmtLet sl -> + check_expr state symbols sl.sl_value + | StmtExpr e -> + check_expr state symbols e + | StmtAssign (lhs, _, rhs) -> + let* () = check_expr state symbols lhs in + check_expr state symbols rhs + | StmtWhile (cond, body) -> + let* () = check_expr state symbols cond in + check_block state symbols body + | StmtFor (_pat, iter, body) -> + let* () = check_expr state symbols iter in + check_block state symbols body (** Check a function *) -let check_function (symbols : Symbol.t) (fd : fun_decl) : unit result = +let check_function (symbols : Symbol.t) (fd : fn_decl) : unit result = let state = create () in match fd.fd_body with - | Some body -> check_expr state symbols body - | None -> Ok () + | FnBlock blk -> check_block state symbols blk + | FnExpr e -> check_expr state symbols e (** Check a program *) let check_program (symbols : Symbol.t) (program : program) : unit result = @@ -244,10 +433,15 @@ let check_program (symbols : Symbol.t) (program : program) : unit result = | Error e -> Error e | Ok () -> match decl with - | DFun fd -> check_function symbols fd + | TopFn fd -> check_function symbols fd | _ -> Ok () ) (Ok ()) program.prog_decls +(* Silence unused warnings for functions that will be used in later phases *) +let _ = record_move +let _ = record_borrow +let _ = end_borrow + (* TODO: Phase 3 implementation - [ ] Non-lexical lifetimes - [ ] Dataflow analysis for precise tracking diff --git a/lib/error.ml b/lib/error.ml index 2c417c9..663e09c 100644 --- a/lib/error.ml +++ b/lib/error.ml @@ -169,7 +169,7 @@ let format_diagnostic ~source diag = (match source with | Some src -> let lines = String.split_on_char '\n' src in - if span.start_pos.line <= List.length lines then begin + if span.start_pos.line > 0 && span.start_pos.line <= List.length lines then begin let line = List.nth lines (span.start_pos.line - 1) in let line_num = string_of_int span.start_pos.line in let padding = String.make (String.length line_num) ' ' in @@ -177,7 +177,7 @@ let format_diagnostic ~source diag = Buffer.add_string buf (Printf.sprintf " %s | %s\n" line_num line); Buffer.add_string buf (Printf.sprintf " %s | %s%s" padding - (String.make (span.start_pos.col - 1) ' ') + (String.make (max 0 (span.start_pos.col - 1)) ' ') (String.make (max 1 (span.end_pos.col - span.start_pos.col)) '^')); if label.message <> "" then Buffer.add_string buf (Printf.sprintf " %s" label.message); diff --git a/lib/parse_driver.ml b/lib/parse_driver.ml index 766ac19..0726447 100644 --- a/lib/parse_driver.ml +++ b/lib/parse_driver.ml @@ -164,9 +164,11 @@ let parse_string ~file content = (** Parse a program from a file *) let parse_file filename = let chan = open_in_bin filename in - let content = really_input_string chan (in_channel_length chan) in - close_in chan; - parse_string ~file:filename content + Fun.protect + ~finally:(fun () -> close_in chan) + (fun () -> + let content = really_input_string chan (in_channel_length chan) in + parse_string ~file:filename content) (** Parse a single expression from a string *) let parse_expr ~file content = diff --git a/lib/quantity.ml b/lib/quantity.ml index de17998..d03e544 100644 --- a/lib/quantity.ml +++ b/lib/quantity.ml @@ -74,110 +74,172 @@ let check_variable (ctx : context) (sym : Symbol.symbol) (id : ident) match (q, u) with (* Erased: must not be used *) | (QZero, UZero) -> Ok () - | (QZero, _) -> Error (ErasedVariableUsed id, id.id_span) + | (QZero, _) -> Error (ErasedVariableUsed id, id.span) (* Linear: must be used exactly once (or zero for affine) *) | (QOne, UZero) -> Ok () (* Affine: can drop *) | (QOne, UOne) -> Ok () - | (QOne, UMany) -> Error (LinearVariableUsedMultiple id, id.id_span) + | (QOne, UMany) -> Error (LinearVariableUsedMultiple id, id.span) (* Unrestricted: any usage is fine *) | (QOmega, _) -> Ok () (** Analyze usage in an expression *) let rec analyze_expr (ctx : context) (symbols : Symbol.t) (expr : expr) : unit = match expr with - | EVar id -> - begin match Symbol.lookup symbols id.id_name with + | ExprVar id -> + begin match Symbol.lookup symbols id.name with | Some sym -> use ctx sym | None -> () end - | ELit _ -> () + | ExprLit _ -> () - | EApp (func, arg, _) -> + | ExprApp (func, args) -> analyze_expr ctx symbols func; - analyze_expr ctx symbols arg + List.iter (analyze_expr ctx symbols) args - | ELam lam -> + | ExprLambda lam -> (* Parameters are bound; analyze body *) - analyze_expr ctx symbols lam.lam_body + analyze_expr ctx symbols lam.elam_body - | ELet lb -> - analyze_expr ctx symbols lb.lb_rhs; - analyze_expr ctx symbols lb.lb_body + | ExprLet lb -> + analyze_expr ctx symbols lb.el_value; + Option.iter (analyze_expr ctx symbols) lb.el_body - | EIf (cond, then_, else_, _) -> - analyze_expr ctx symbols cond; + | ExprIf ei -> + analyze_expr ctx symbols ei.ei_cond; (* For branches, we need to join usages *) (* TODO: Proper branch handling *) - analyze_expr ctx symbols then_; - analyze_expr ctx symbols else_ + analyze_expr ctx symbols ei.ei_then; + Option.iter (analyze_expr ctx symbols) ei.ei_else + + | ExprMatch em -> + analyze_expr ctx symbols em.em_scrutinee; + List.iter (fun arm -> + Option.iter (analyze_expr ctx symbols) arm.ma_guard; + analyze_expr ctx symbols arm.ma_body + ) em.em_arms - | ECase (scrut, branches, _) -> - analyze_expr ctx symbols scrut; - List.iter (fun branch -> - analyze_expr ctx symbols branch.cb_body - ) branches + | ExprTuple exprs -> + List.iter (analyze_expr ctx symbols) exprs - | ETuple (exprs, _) -> + | ExprArray exprs -> List.iter (analyze_expr ctx symbols) exprs - | ERecord (fields, _) -> - List.iter (fun (_, e) -> analyze_expr ctx symbols e) fields + | ExprRecord er -> + List.iter (fun (_id, e_opt) -> + Option.iter (analyze_expr ctx symbols) e_opt + ) er.er_fields; + Option.iter (analyze_expr ctx symbols) er.er_spread - | ERecordAccess (e, _, _) -> + | ExprField (e, _) -> analyze_expr ctx symbols e - | ERecordUpdate (base, _, value, _) -> - analyze_expr ctx symbols base; - analyze_expr ctx symbols value + | ExprTupleIndex (e, _) -> + analyze_expr ctx symbols e - | EBlock (exprs, _) -> - List.iter (analyze_expr ctx symbols) exprs + | ExprIndex (arr, idx) -> + analyze_expr ctx symbols arr; + analyze_expr ctx symbols idx + + | ExprRowRestrict (e, _) -> + analyze_expr ctx symbols e - | EBinOp (left, _, right, _) -> + | ExprBlock blk -> + analyze_block ctx symbols blk + + | ExprBinary (left, _, right) -> analyze_expr ctx symbols left; analyze_expr ctx symbols right - | EUnaryOp (_, e, _) -> + | ExprUnary (_, e) -> analyze_expr ctx symbols e - | EHandle (body, handler, _) -> - analyze_expr ctx symbols body; - begin match handler.h_return with - | Some (_, e) -> analyze_expr ctx symbols e - | None -> () - end; - List.iter (fun clause -> - analyze_expr ctx symbols clause.oc_body - ) handler.h_ops + | ExprHandle eh -> + analyze_expr ctx symbols eh.eh_body; + List.iter (fun arm -> + match arm with + | HandlerReturn (_pat, body) -> analyze_expr ctx symbols body + | HandlerOp (_op, _pats, body) -> analyze_expr ctx symbols body + ) eh.eh_handlers + + | ExprResume e_opt -> + Option.iter (analyze_expr ctx symbols) e_opt + + | ExprReturn e_opt -> + Option.iter (analyze_expr ctx symbols) e_opt + + | ExprTry et -> + analyze_block ctx symbols et.et_body; + Option.iter (fun arms -> + List.iter (fun arm -> + Option.iter (analyze_expr ctx symbols) arm.ma_guard; + analyze_expr ctx symbols arm.ma_body + ) arms + ) et.et_catch; + Option.iter (analyze_block ctx symbols) et.et_finally + + | ExprUnsafe ops -> + List.iter (fun op -> + match op with + | UnsafeRead e -> analyze_expr ctx symbols e + | UnsafeWrite (e1, e2) -> + analyze_expr ctx symbols e1; + analyze_expr ctx symbols e2 + | UnsafeOffset (e1, e2) -> + analyze_expr ctx symbols e1; + analyze_expr ctx symbols e2 + | UnsafeTransmute (_, _, e) -> analyze_expr ctx symbols e + | UnsafeForget e -> analyze_expr ctx symbols e + | UnsafeAssume _ -> () + ) ops + + | ExprVariant _ -> () + + | ExprSpan (e, _) -> + analyze_expr ctx symbols e - | EPerform (_, arg, _) -> - analyze_expr ctx symbols arg +and analyze_block (ctx : context) (symbols : Symbol.t) (blk : block) : unit = + List.iter (analyze_stmt ctx symbols) blk.blk_stmts; + Option.iter (analyze_expr ctx symbols) blk.blk_expr - | _ -> () +and analyze_stmt (ctx : context) (symbols : Symbol.t) (stmt : stmt) : unit = + match stmt with + | StmtLet sl -> + analyze_expr ctx symbols sl.sl_value + | StmtExpr e -> + analyze_expr ctx symbols e + | StmtAssign (lhs, _, rhs) -> + analyze_expr ctx symbols lhs; + analyze_expr ctx symbols rhs + | StmtWhile (cond, body) -> + analyze_expr ctx symbols cond; + analyze_block ctx symbols body + | StmtFor (_pat, iter, body) -> + analyze_expr ctx symbols iter; + analyze_block ctx symbols body (** Check quantities for a function *) -let check_function (symbols : Symbol.t) (fd : fun_decl) : unit result = +let check_function (symbols : Symbol.t) (fd : fn_decl) : unit result = let ctx = create () in (* Declare parameter quantities *) - List.iter (fun (id, _, q_opt) -> - let q = Option.value q_opt ~default:QOmega in - match Symbol.lookup symbols id.id_name with + List.iter (fun param -> + let q = Option.value param.p_quantity ~default:QOmega in + match Symbol.lookup symbols param.p_name.name with | Some sym -> declare ctx sym q | None -> () ) fd.fd_params; (* Analyze body *) begin match fd.fd_body with - | Some body -> analyze_expr ctx symbols body - | None -> () + | FnBlock blk -> analyze_block ctx symbols blk + | FnExpr e -> analyze_expr ctx symbols e end; (* Check all parameters *) - List.fold_left (fun acc (id, _, _) -> + List.fold_left (fun acc param -> match acc with | Error e -> Error e | Ok () -> - match Symbol.lookup symbols id.id_name with - | Some sym -> check_variable ctx sym id + match Symbol.lookup symbols param.p_name.name with + | Some sym -> check_variable ctx sym param.p_name | None -> Ok () ) (Ok ()) fd.fd_params @@ -188,7 +250,7 @@ let check_program (symbols : Symbol.t) (program : program) : unit result = | Error e -> Error e | Ok () -> match decl with - | DFun fd -> check_function symbols fd + | TopFn fd -> check_function symbols fd | _ -> Ok () ) (Ok ()) program.prog_decls diff --git a/lib/resolve.ml b/lib/resolve.ml index ffbc328..dc406ce 100644 --- a/lib/resolve.ml +++ b/lib/resolve.ml @@ -30,6 +30,9 @@ type context = { imports : (string * Symbol.symbol) list; } +(* Helper for Result bind *) +let ( let* ) = Result.bind + (** Create a new resolution context *) let create_context () : context = { @@ -40,298 +43,350 @@ let create_context () : context = (** Resolve an identifier *) let resolve_ident (ctx : context) (id : ident) : Symbol.symbol result = - let name = id.id_name in + let name = id.name in match Symbol.lookup ctx.symbols name with | Some sym -> Ok sym - | None -> Error (UndefinedVariable id, id.id_span) + | None -> Error (UndefinedVariable id, id.span) (** Resolve a type identifier *) let resolve_type_ident (ctx : context) (id : ident) : Symbol.symbol result = - let name = id.id_name in + let name = id.name in match Symbol.lookup ctx.symbols name with | Some sym when sym.sym_kind = Symbol.SKType -> Ok sym | Some sym when sym.sym_kind = Symbol.SKTypeVar -> Ok sym - | Some _ -> Error (UndefinedType id, id.id_span) - | None -> Error (UndefinedType id, id.id_span) + | Some _ -> Error (UndefinedType id, id.span) + | None -> Error (UndefinedType id, id.span) (** Resolve an effect identifier *) let resolve_effect_ident (ctx : context) (id : ident) : Symbol.symbol result = - let name = id.id_name in + let name = id.name in match Symbol.lookup ctx.symbols name with | Some sym when sym.sym_kind = Symbol.SKEffect -> Ok sym - | Some _ -> Error (UndefinedEffect id, id.id_span) - | None -> Error (UndefinedEffect id, id.id_span) + | Some _ -> Error (UndefinedEffect id, id.span) + | None -> Error (UndefinedEffect id, id.span) (** Resolve a pattern, binding variables *) let rec resolve_pattern (ctx : context) (pat : pattern) : context result = match pat with - | PWild _ -> Ok ctx - | PVar id -> - if Symbol.is_defined_locally ctx.symbols id.id_name then - Error (DuplicateDefinition id, id.id_span) + | PatWildcard _ -> Ok ctx + | PatVar id -> + if Symbol.is_defined_locally ctx.symbols id.name then + Error (DuplicateDefinition id, id.span) else begin - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in + let _ = Symbol.define ctx.symbols id.name + Symbol.SKVariable id.span Private in Ok ctx end - | PLit _ -> Ok ctx - | PTuple (pats, _) -> + | PatLit _ -> Ok ctx + | PatTuple pats -> List.fold_left (fun acc pat -> match acc with | Error e -> Error e | Ok ctx -> resolve_pattern ctx pat ) (Ok ctx) pats - | PRecord (fields, _, _) -> - List.fold_left (fun acc (_, pat) -> + | PatRecord (fields, _has_rest) -> + List.fold_left (fun acc (_id, pat_opt) -> match acc with | Error e -> Error e - | Ok ctx -> resolve_pattern ctx pat + | Ok ctx -> + match pat_opt with + | Some pat -> resolve_pattern ctx pat + | None -> Ok ctx ) (Ok ctx) fields - | PConstructor (_, pats, _) -> + | PatCon (_con, pats) -> List.fold_left (fun acc pat -> match acc with | Error e -> Error e | Ok ctx -> resolve_pattern ctx pat ) (Ok ctx) pats - | POr (p1, p2, _) -> + | PatOr (p1, p2) -> (* Both branches must bind the same variables *) let* ctx1 = resolve_pattern ctx p1 in let* _ctx2 = resolve_pattern ctx p2 in Ok ctx1 - | PAs (pat, id, _) -> + | PatAs (id, pat) -> let* ctx = resolve_pattern ctx pat in - if Symbol.is_defined_locally ctx.symbols id.id_name then - Error (DuplicateDefinition id, id.id_span) + if Symbol.is_defined_locally ctx.symbols id.name then + Error (DuplicateDefinition id, id.span) else begin - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in + let _ = Symbol.define ctx.symbols id.name + Symbol.SKVariable id.span Private in Ok ctx end (** Resolve an expression *) let rec resolve_expr (ctx : context) (expr : expr) : unit result = match expr with - | EVar id -> + | ExprVar id -> let* _ = resolve_ident ctx id in Ok () - | ELit _ -> Ok () + | ExprLit _ -> Ok () - | EApp (func, arg, _) -> + | ExprApp (func, args) -> let* () = resolve_expr ctx func in - resolve_expr ctx arg + List.fold_left (fun acc arg -> + match acc with + | Error e -> Error e + | Ok () -> resolve_expr ctx arg + ) (Ok ()) args - | ELam lam -> + | ExprLambda lam -> Symbol.enter_scope ctx.symbols (Symbol.ScopeFunction "lambda"); (* Bind parameters *) - List.iter (fun (id, _, _) -> - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in + List.iter (fun param -> + let _ = Symbol.define ctx.symbols param.p_name.name + Symbol.SKVariable param.p_name.span Private in () - ) lam.lam_params; - let result = resolve_expr ctx lam.lam_body in + ) lam.elam_params; + let result = resolve_expr ctx lam.elam_body in Symbol.exit_scope ctx.symbols; result - | ELet lb -> - let* () = resolve_expr ctx lb.lb_rhs in + | ExprLet lb -> + let* () = resolve_expr ctx lb.el_value in Symbol.enter_scope ctx.symbols Symbol.ScopeBlock; - let* _ = resolve_pattern ctx lb.lb_pat in - let result = resolve_expr ctx lb.lb_body in + let* _ = resolve_pattern ctx lb.el_pat in + let result = match lb.el_body with + | Some body -> resolve_expr ctx body + | None -> Ok () + in Symbol.exit_scope ctx.symbols; result - | EIf (cond, then_, else_, _) -> - let* () = resolve_expr ctx cond in - let* () = resolve_expr ctx then_ in - resolve_expr ctx else_ + | ExprIf ei -> + let* () = resolve_expr ctx ei.ei_cond in + let* () = resolve_expr ctx ei.ei_then in + (match ei.ei_else with + | Some e -> resolve_expr ctx e + | None -> Ok ()) - | ECase (scrut, branches, _) -> - let* () = resolve_expr ctx scrut in - List.fold_left (fun acc branch -> + | ExprMatch em -> + let* () = resolve_expr ctx em.em_scrutinee in + List.fold_left (fun acc arm -> match acc with | Error e -> Error e | Ok () -> Symbol.enter_scope ctx.symbols Symbol.ScopeMatch; let result = - let* _ = resolve_pattern ctx branch.cb_pat in - let* () = match branch.cb_guard with + let* _ = resolve_pattern ctx arm.ma_pat in + let* () = match arm.ma_guard with | Some g -> resolve_expr ctx g | None -> Ok () in - resolve_expr ctx branch.cb_body + resolve_expr ctx arm.ma_body in Symbol.exit_scope ctx.symbols; result - ) (Ok ()) branches + ) (Ok ()) em.em_arms - | ETuple (exprs, _) -> + | ExprTuple exprs -> List.fold_left (fun acc e -> match acc with | Error e -> Error e | Ok () -> resolve_expr ctx e ) (Ok ()) exprs - | ERecord (fields, _) -> - List.fold_left (fun acc (_, e) -> + | ExprArray exprs -> + List.fold_left (fun acc e -> match acc with | Error e -> Error e | Ok () -> resolve_expr ctx e - ) (Ok ()) fields - - | ERecordAccess (e, _, _) -> - resolve_expr ctx e - - | ERecordUpdate (base, _, value, _) -> - let* () = resolve_expr ctx base in - resolve_expr ctx value + ) (Ok ()) exprs - | EArray (elems, _) -> - List.fold_left (fun acc e -> + | ExprRecord er -> + let* () = List.fold_left (fun acc (_id, e_opt) -> match acc with | Error e -> Error e - | Ok () -> resolve_expr ctx e - ) (Ok ()) elems + | Ok () -> + match e_opt with + | Some e -> resolve_expr ctx e + | None -> Ok () + ) (Ok ()) er.er_fields in + (match er.er_spread with + | Some e -> resolve_expr ctx e + | None -> Ok ()) + + | ExprField (e, _field) -> + resolve_expr ctx e + + | ExprTupleIndex (e, _idx) -> + resolve_expr ctx e - | EIndex (arr, idx, _) -> + | ExprIndex (arr, idx) -> let* () = resolve_expr ctx arr in resolve_expr ctx idx - | EHandle (body, handler, _) -> - let* () = resolve_expr ctx body in - resolve_handler ctx handler + | ExprBinary (left, _op, right) -> + let* () = resolve_expr ctx left in + resolve_expr ctx right + + | ExprUnary (_op, e) -> + resolve_expr ctx e - | EPerform (_, arg, _) -> - resolve_expr ctx arg + | ExprBlock blk -> + resolve_block ctx blk - | EResume (arg, _) -> - resolve_expr ctx arg + | ExprReturn e_opt -> + (match e_opt with + | Some e -> resolve_expr ctx e + | None -> Ok ()) - | EBlock (exprs, _) -> - Symbol.enter_scope ctx.symbols Symbol.ScopeBlock; - let result = List.fold_left (fun acc e -> + | ExprHandle eh -> + let* () = resolve_expr ctx eh.eh_body in + List.fold_left (fun acc arm -> match acc with | Error e -> Error e - | Ok () -> resolve_expr ctx e - ) (Ok ()) exprs in - Symbol.exit_scope ctx.symbols; - result + | Ok () -> + Symbol.enter_scope ctx.symbols Symbol.ScopeHandler; + let result = match arm with + | HandlerReturn (pat, body) -> + let* _ = resolve_pattern ctx pat in + resolve_expr ctx body + | HandlerOp (_op, pats, body) -> + let* _ = List.fold_left (fun acc pat -> + match acc with + | Error e -> Error e + | Ok ctx -> resolve_pattern ctx pat + ) (Ok ctx) pats in + resolve_expr ctx body + in + Symbol.exit_scope ctx.symbols; + result + ) (Ok ()) eh.eh_handlers - | EBinOp (left, _, right, _) -> - let* () = resolve_expr ctx left in - resolve_expr ctx right + | ExprResume e_opt -> + (match e_opt with + | Some e -> resolve_expr ctx e + | None -> Ok ()) - | EUnaryOp (_, e, _) -> + | ExprRowRestrict (e, _field) -> resolve_expr ctx e - | ETyApp (e, _, _) -> - resolve_expr ctx e + | ExprTry et -> + let* () = resolve_block ctx et.et_body in + let* () = match et.et_catch with + | Some arms -> + List.fold_left (fun acc arm -> + match acc with + | Error e -> Error e + | Ok () -> + Symbol.enter_scope ctx.symbols Symbol.ScopeMatch; + let* _ = resolve_pattern ctx arm.ma_pat in + let result = resolve_expr ctx arm.ma_body in + Symbol.exit_scope ctx.symbols; + result + ) (Ok ()) arms + | None -> Ok () + in + (match et.et_finally with + | Some blk -> resolve_block ctx blk + | None -> Ok ()) - | EUnsafe (e, _) -> - resolve_expr ctx e + | ExprUnsafe _ops -> + (* TODO: Resolve unsafe operations *) + Ok () - | EUnsafeCoerce (e, _, _) -> + | ExprVariant (_ty, _variant) -> + Ok () + + | ExprSpan (e, _span) -> resolve_expr ctx e -and resolve_handler (ctx : context) (handler : handler) : unit result = - (* Resolve return clause *) - let* () = match handler.h_return with - | Some (id, body) -> - Symbol.enter_scope ctx.symbols Symbol.ScopeHandler; - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in - let result = resolve_expr ctx body in - Symbol.exit_scope ctx.symbols; - result +and resolve_block (ctx : context) (blk : block) : unit result = + Symbol.enter_scope ctx.symbols Symbol.ScopeBlock; + let result = + let* () = List.fold_left (fun acc stmt -> + match acc with + | Error e -> Error e + | Ok () -> resolve_stmt ctx stmt + ) (Ok ()) blk.blk_stmts in + match blk.blk_expr with + | Some e -> resolve_expr ctx e | None -> Ok () in - (* Resolve operation clauses *) - List.fold_left (fun acc clause -> - match acc with - | Error e -> Error e - | Ok () -> - Symbol.enter_scope ctx.symbols Symbol.ScopeHandler; - (* Bind operation parameters and continuation *) - List.iter (fun (id, _) -> - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in - () - ) clause.oc_params; - let _ = Symbol.define ctx.symbols clause.oc_resume.id_name - Symbol.SKVariable clause.oc_resume.id_span Private in - let result = resolve_expr ctx clause.oc_body in - Symbol.exit_scope ctx.symbols; - result - ) (Ok ()) handler.h_ops + Symbol.exit_scope ctx.symbols; + result + +and resolve_stmt (ctx : context) (stmt : stmt) : unit result = + match stmt with + | StmtLet sl -> + let* () = resolve_expr ctx sl.sl_value in + let* _ = resolve_pattern ctx sl.sl_pat in + Ok () + | StmtExpr e -> + resolve_expr ctx e + | StmtAssign (lhs, _op, rhs) -> + let* () = resolve_expr ctx lhs in + resolve_expr ctx rhs + | StmtWhile (cond, body) -> + let* () = resolve_expr ctx cond in + resolve_block ctx body + | StmtFor (pat, iter, body) -> + let* () = resolve_expr ctx iter in + Symbol.enter_scope ctx.symbols Symbol.ScopeBlock; + let* _ = resolve_pattern ctx pat in + let result = resolve_block ctx body in + Symbol.exit_scope ctx.symbols; + result (** Resolve a top-level declaration *) -let resolve_decl (ctx : context) (decl : decl) : unit result = +let resolve_decl (ctx : context) (decl : top_level) : unit result = match decl with - | DFun fd -> + | TopFn fd -> (* First, define the function itself for recursion *) - let _ = Symbol.define ctx.symbols fd.fd_name.id_name - Symbol.SKFunction fd.fd_name.id_span fd.fd_vis in + let _ = Symbol.define ctx.symbols fd.fd_name.name + Symbol.SKFunction fd.fd_name.span fd.fd_vis in (* Then resolve the body *) - Symbol.enter_scope ctx.symbols (Symbol.ScopeFunction fd.fd_name.id_name); + Symbol.enter_scope ctx.symbols (Symbol.ScopeFunction fd.fd_name.name); (* Bind type parameters *) List.iter (fun tp -> - let _ = Symbol.define ctx.symbols tp.tp_name.id_name - Symbol.SKTypeVar tp.tp_name.id_span Private in + let _ = Symbol.define ctx.symbols tp.tp_name.name + Symbol.SKTypeVar tp.tp_name.span Private in () - ) fd.fd_ty_params; + ) fd.fd_type_params; (* Bind parameters *) - List.iter (fun (id, _, _) -> - let _ = Symbol.define ctx.symbols id.id_name - Symbol.SKVariable id.id_span Private in + List.iter (fun param -> + let _ = Symbol.define ctx.symbols param.p_name.name + Symbol.SKVariable param.p_name.span Private in () ) fd.fd_params; let result = match fd.fd_body with - | Some body -> resolve_expr ctx body - | None -> Ok () + | FnBlock blk -> resolve_block ctx blk + | FnExpr e -> resolve_expr ctx e in Symbol.exit_scope ctx.symbols; result - | DType td -> - let _ = Symbol.define ctx.symbols td.td_name.id_name - Symbol.SKType td.td_name.id_span td.td_vis in + | TopType td -> + let _ = Symbol.define ctx.symbols td.td_name.name + Symbol.SKType td.td_name.span td.td_vis in Ok () - | DEffect ed -> - let _ = Symbol.define ctx.symbols ed.ed_name.id_name - Symbol.SKEffect ed.ed_name.id_span ed.ed_vis in + | TopEffect ed -> + let _ = Symbol.define ctx.symbols ed.ed_name.name + Symbol.SKEffect ed.ed_name.span ed.ed_vis in (* Define each operation *) List.iter (fun op -> - let _ = Symbol.define ctx.symbols op.eo_name.id_name - Symbol.SKEffectOp op.eo_name.id_span ed.ed_vis in + let _ = Symbol.define ctx.symbols op.eod_name.name + Symbol.SKEffectOp op.eod_name.span ed.ed_vis in () ) ed.ed_ops; Ok () - | DTrait td -> - let _ = Symbol.define ctx.symbols td.trd_name.id_name - Symbol.SKTrait td.trd_name.id_span td.trd_vis in + | TopTrait td -> + let _ = Symbol.define ctx.symbols td.trd_name.name + Symbol.SKTrait td.trd_name.span td.trd_vis in Ok () - | DImpl _ -> + | TopImpl _ -> (* TODO: Resolve impl blocks *) Ok () - | DModule (name, decls, _) -> - let _ = Symbol.define ctx.symbols name.id_name - Symbol.SKModule name.id_span Private in - Symbol.enter_scope ctx.symbols (Symbol.ScopeModule name.id_name); - let result = List.fold_left (fun acc d -> - match acc with - | Error e -> Error e - | Ok () -> resolve_decl ctx d - ) (Ok ()) decls in - Symbol.exit_scope ctx.symbols; - result - - | DImport _ -> - (* TODO: Handle imports *) - Ok () + | TopConst tc -> + let _ = Symbol.define ctx.symbols tc.tc_name.name + Symbol.SKVariable tc.tc_name.span tc.tc_vis in + resolve_expr ctx tc.tc_value (** Resolve an entire program *) let resolve_program (program : program) : (context, resolve_error * Span.t) Result.t = @@ -344,9 +399,6 @@ let resolve_program (program : program) : (context, resolve_error * Span.t) Resu | Ok () -> Ok ctx | Error e -> Error e -(* Helper for Result bind *) -let ( let* ) = Result.bind - (* TODO: Phase 1 implementation - [ ] Module qualified lookups - [ ] Import resolution (use, use as, use *) diff --git a/lib/typecheck.ml b/lib/typecheck.ml index 8bf2579..ec85953 100644 --- a/lib/typecheck.ml +++ b/lib/typecheck.ml @@ -44,6 +44,9 @@ type context = { current_effect : effect; } +(* Result bind - define before use *) +let ( let* ) = Result.bind + (** Create a new type checking context *) let create_context (symbols : Symbol.t) : context = { @@ -138,20 +141,20 @@ let instantiate (ctx : context) (scheme : scheme) : ty = (** Look up a variable's type *) let lookup_var (ctx : context) (id : ident) : ty result = - match Symbol.lookup ctx.symbols id.id_name with + match Symbol.lookup ctx.symbols id.name with | Some sym -> begin match Hashtbl.find_opt ctx.var_types sym.sym_id with | Some scheme -> Ok (instantiate ctx scheme) | None -> (* Variable exists but not yet typed - this shouldn't happen after resolve *) - Error (CannotInfer id.id_span) + Error (CannotInfer id.span) end | None -> - Error (CannotInfer id.id_span) + Error (CannotInfer id.span) (** Bind a variable with a type *) let bind_var (ctx : context) (id : ident) (ty : ty) : unit = - match Symbol.lookup ctx.symbols id.id_name with + match Symbol.lookup ctx.symbols id.name with | Some sym -> let scheme = { sc_tyvars = []; sc_effvars = []; sc_rowvars = []; sc_body = ty } in Hashtbl.replace ctx.var_types sym.sym_id scheme @@ -159,7 +162,7 @@ let bind_var (ctx : context) (id : ident) (ty : ty) : unit = (** Bind a variable with a scheme (polymorphic) *) let bind_var_scheme (ctx : context) (id : ident) (scheme : scheme) : unit = - match Symbol.lookup ctx.symbols id.id_name with + match Symbol.lookup ctx.symbols id.name with | Some sym -> Hashtbl.replace ctx.var_types sym.sym_id scheme | None -> () @@ -167,9 +170,9 @@ let bind_var_scheme (ctx : context) (id : ident) (scheme : scheme) : unit = (** Convert AST type to internal type *) let rec ast_to_ty (ctx : context) (ty : type_expr) : ty = match ty with - | TyVar id -> fresh_tyvar ctx.level (* TODO: Look up type variable *) + | TyVar _id -> fresh_tyvar ctx.level (* TODO: Look up type variable *) | TyCon id -> - begin match id.id_name with + begin match id.name with | "Unit" -> ty_unit | "Bool" -> ty_bool | "Int" -> ty_int @@ -180,7 +183,7 @@ let rec ast_to_ty (ctx : context) (ty : type_expr) : ty = | name -> TCon name end | TyApp (id, args) -> - TApp (TCon id.id_name, List.map (ast_to_ty_arg ctx) args) + TApp (TCon id.name, List.map (ast_to_ty_arg ctx) args) | TyArrow (a, b, eff) -> let eff' = match eff with | Some e -> ast_to_eff ctx e @@ -191,7 +194,7 @@ let rec ast_to_ty (ctx : context) (ty : type_expr) : ty = TTuple (List.map (ast_to_ty ctx) tys) | TyRecord (fields, rest) -> let row = List.fold_right (fun field acc -> - RExtend (field.rf_name.id_name, ast_to_ty ctx field.rf_ty, acc) + RExtend (field.rf_name.name, ast_to_ty ctx field.rf_ty, acc) ) fields (match rest with | Some _ -> fresh_rowvar ctx.level | None -> REmpty @@ -203,178 +206,380 @@ let rec ast_to_ty (ctx : context) (ty : type_expr) : ty = | TyRefined (t, _pred) -> (* TODO: Convert predicate *) TRefined (ast_to_ty ctx t, PTrue) - | _ -> fresh_tyvar ctx.level (* TODO: Handle other cases *) + | TyDepArrow _ -> fresh_tyvar ctx.level (* TODO: Handle dependent arrows *) + | TyHole -> fresh_tyvar ctx.level and ast_to_ty_arg (ctx : context) (arg : type_arg) : ty = match arg with - | TaType ty -> ast_to_ty ctx ty - | TaNat _ -> TNat (NLit 0) (* TODO: Convert nat expr *) + | TyArg ty -> ast_to_ty ctx ty + | NatArg _ -> TNat (NLit 0) (* TODO: Convert nat expr *) and ast_to_eff (ctx : context) (eff : effect_expr) : effect = match eff with - | EffNamed id -> ESingleton id.id_name - | EffVar id -> fresh_effvar ctx.level - | EffUnion effs -> EUnion (List.map (ast_to_eff ctx) effs) - | EffApp (id, _) -> ESingleton id.id_name + | EffCon (id, _) -> ESingleton id.name + | EffVar _id -> fresh_effvar ctx.level + | EffUnion (e1, e2) -> EUnion [ast_to_eff ctx e1; ast_to_eff ctx e2] + +(** Get span from an expression *) +let rec expr_span (expr : expr) : Span.t = + match expr with + | ExprSpan (_, span) -> span + | ExprLit lit -> lit_span lit + | ExprVar id -> id.span + | ExprLet { el_pat; _ } -> pattern_span el_pat + | ExprIf { ei_cond; _ } -> expr_span ei_cond + | ExprMatch { em_scrutinee; _ } -> expr_span em_scrutinee + | ExprLambda { elam_params; _ } -> + begin match elam_params with + | p :: _ -> p.p_name.span + | [] -> Span.dummy + end + | ExprApp (f, _) -> expr_span f + | ExprField (e, _) -> expr_span e + | ExprTupleIndex (e, _) -> expr_span e + | ExprIndex (e, _) -> expr_span e + | ExprTuple exprs -> + begin match exprs with + | e :: _ -> expr_span e + | [] -> Span.dummy + end + | ExprArray exprs -> + begin match exprs with + | e :: _ -> expr_span e + | [] -> Span.dummy + end + | ExprRecord { er_fields; _ } -> + begin match er_fields with + | (id, _) :: _ -> id.span + | [] -> Span.dummy + end + | ExprRowRestrict (e, _) -> expr_span e + | ExprBinary (e, _, _) -> expr_span e + | ExprUnary (_, e) -> expr_span e + | ExprBlock { blk_stmts; blk_expr } -> + begin match blk_stmts with + | StmtLet { sl_pat; _ } :: _ -> pattern_span sl_pat + | StmtExpr e :: _ -> expr_span e + | StmtAssign (e, _, _) :: _ -> expr_span e + | StmtWhile (e, _) :: _ -> expr_span e + | StmtFor (p, _, _) :: _ -> pattern_span p + | [] -> match blk_expr with Some e -> expr_span e | None -> Span.dummy + end + | ExprReturn _ -> Span.dummy + | ExprTry _ -> Span.dummy + | ExprHandle { eh_body; _ } -> expr_span eh_body + | ExprResume _ -> Span.dummy + | ExprUnsafe _ -> Span.dummy + | ExprVariant (id, _) -> id.span + +and lit_span (lit : literal) : Span.t = + match lit with + | LitInt (_, span) -> span + | LitFloat (_, span) -> span + | LitBool (_, span) -> span + | LitChar (_, span) -> span + | LitString (_, span) -> span + | LitUnit span -> span + +and pattern_span (pat : pattern) : Span.t = + match pat with + | PatWildcard span -> span + | PatVar id -> id.span + | PatLit lit -> lit_span lit + | PatCon (id, _) -> id.span + | PatTuple pats -> + begin match pats with + | p :: _ -> pattern_span p + | [] -> Span.dummy + end + | PatRecord ((id, _) :: _, _) -> id.span + | PatRecord ([], _) -> Span.dummy + | PatOr (p1, _) -> pattern_span p1 + | PatAs (id, _) -> id.span (** Synthesize (infer) the type of an expression *) let rec synth (ctx : context) (expr : expr) : (ty * effect) result = match expr with - | EVar id -> + | ExprVar id -> let* ty = lookup_var ctx id in Ok (ty, EPure) - | ELit lit -> + | ExprLit lit -> let ty = synth_literal lit in Ok (ty, EPure) - | EApp (func, arg, span) -> + | ExprApp (func, args) -> + let span = expr_span expr in let* (func_ty, func_eff) = synth ctx func in - begin match repr func_ty with - | TArrow (param_ty, ret_ty, call_eff) -> - let* arg_eff = check ctx arg param_ty in - Ok (ret_ty, union_eff [func_eff; arg_eff; call_eff]) - | TVar _ as tv -> - (* Infer function type *) - let param_ty = fresh_tyvar ctx.level in - let ret_ty = fresh_tyvar ctx.level in - let call_eff = fresh_effvar ctx.level in - begin match Unify.unify tv (TArrow (param_ty, ret_ty, call_eff)) with - | Ok () -> - let* arg_eff = check ctx arg param_ty in - Ok (ret_ty, union_eff [func_eff; arg_eff; call_eff]) - | Error e -> - Error (UnificationFailed (e, span)) - end - | _ -> - Error (ExpectedFunction (func_ty, span)) - end + synth_app ctx func_ty func_eff args span - | ELam lam -> + | ExprLambda lam -> (* For lambdas, we need annotations or we infer fresh variables *) - let param_tys = List.map (fun (id, ty_opt, _q) -> - match ty_opt with - | Some ty -> (id, ast_to_ty ctx ty) - | None -> (id, fresh_tyvar ctx.level) - ) lam.lam_params in + let param_tys = List.map (fun param -> + (param.p_name, ast_to_ty ctx param.p_ty) + ) lam.elam_params in (* Bind parameters *) List.iter (fun (id, ty) -> bind_var ctx id ty) param_tys; (* Infer body *) - let* (body_ty, body_eff) = synth ctx lam.lam_body in + let* (body_ty, body_eff) = synth ctx lam.elam_body in (* Build arrow type *) let ty = List.fold_right (fun (_, param_ty) acc -> TArrow (param_ty, acc, body_eff) ) param_tys body_ty in Ok (ty, EPure) - | ELet lb -> + | ExprLet lb -> (* Infer RHS at higher level for generalization *) let ctx' = enter_level ctx in - let* (rhs_ty, rhs_eff) = synth ctx' lb.lb_rhs in + let* (rhs_ty, rhs_eff) = synth ctx' lb.el_value in (* Generalize *) let scheme = generalize ctx rhs_ty in (* Bind pattern *) - let* () = bind_pattern ctx lb.lb_pat scheme in - (* Infer body *) - let* (body_ty, body_eff) = synth ctx lb.lb_body in - Ok (body_ty, union_eff [rhs_eff; body_eff]) + let* () = bind_pattern ctx lb.el_pat scheme in + (* Infer body if present *) + begin match lb.el_body with + | Some body -> + let* (body_ty, body_eff) = synth ctx body in + Ok (body_ty, union_eff [rhs_eff; body_eff]) + | None -> + Ok (ty_unit, rhs_eff) + end - | EIf (cond, then_, else_, span) -> - let* cond_eff = check ctx cond ty_bool in - let* (then_ty, then_eff) = synth ctx then_ in - let* else_eff = check ctx else_ then_ty in - Ok (then_ty, union_eff [cond_eff; then_eff; else_eff]) + | ExprIf ei -> + let* cond_eff = check ctx ei.ei_cond ty_bool in + let* (then_ty, then_eff) = synth ctx ei.ei_then in + begin match ei.ei_else with + | Some else_expr -> + let* else_eff = check ctx else_expr then_ty in + Ok (then_ty, union_eff [cond_eff; then_eff; else_eff]) + | None -> + Ok (ty_unit, union_eff [cond_eff; then_eff]) + end + + | ExprMatch em -> + let* (scrut_ty, scrut_eff) = synth ctx em.em_scrutinee in + begin match em.em_arms with + | [] -> Error (CannotInfer (expr_span expr)) + | first_arm :: rest_arms -> + let* () = check_pattern ctx first_arm.ma_pat scrut_ty in + let* (arm_ty, arm_eff) = synth ctx first_arm.ma_body in + let* effs = List.fold_left (fun acc arm -> + let* effs = acc in + let* () = check_pattern ctx arm.ma_pat scrut_ty in + let* eff = check ctx arm.ma_body arm_ty in + Ok (eff :: effs) + ) (Ok [arm_eff]) rest_arms in + Ok (arm_ty, union_eff (scrut_eff :: effs)) + end - | ETuple (exprs, _) -> + | ExprTuple exprs -> let* results = synth_list ctx exprs in let tys = List.map fst results in let effs = List.map snd results in Ok (TTuple tys, union_eff effs) - | ERecord (fields, _) -> - let* field_results = synth_fields ctx fields in + | ExprArray exprs -> + begin match exprs with + | [] -> Ok (TApp (TCon "Array", [fresh_tyvar ctx.level]), EPure) + | first :: rest -> + let* (elem_ty, first_eff) = synth ctx first in + let* effs = List.fold_left (fun acc e -> + let* effs = acc in + let* eff = check ctx e elem_ty in + Ok (eff :: effs) + ) (Ok [first_eff]) rest in + Ok (TApp (TCon "Array", [elem_ty]), union_eff effs) + end + + | ExprRecord er -> + let* field_results = synth_record_fields ctx er.er_fields in let row = List.fold_right (fun (name, ty, _eff) acc -> RExtend (name, ty, acc) ) field_results REmpty in let effs = List.map (fun (_, _, eff) -> eff) field_results in Ok (TRecord row, union_eff effs) - | ERecordAccess (expr, field, span) -> - let* (expr_ty, expr_eff) = synth ctx expr in - begin match repr expr_ty with + | ExprField (base, field) -> + let span = expr_span expr in + let* (base_ty, base_eff) = synth ctx base in + begin match repr base_ty with | TRecord row -> - begin match find_field field.id_name row with - | Some ty -> Ok (ty, expr_eff) - | None -> Error (UndefinedField (field.id_name, span)) + begin match find_field field.name row with + | Some ty -> Ok (ty, base_eff) + | None -> Error (UndefinedField (field.name, span)) end | TVar _ as tv -> let field_ty = fresh_tyvar ctx.level in let rest = fresh_rowvar ctx.level in - let row = RExtend (field.id_name, field_ty, rest) in + let row = RExtend (field.name, field_ty, rest) in begin match Unify.unify tv (TRecord row) with - | Ok () -> Ok (field_ty, expr_eff) + | Ok () -> Ok (field_ty, base_eff) + | Error e -> Error (UnificationFailed (e, span)) + end + | _ -> + Error (ExpectedRecord (base_ty, span)) + end + + | ExprTupleIndex (base, idx) -> + let span = expr_span expr in + let* (base_ty, base_eff) = synth ctx base in + begin match repr base_ty with + | TTuple tys when idx >= 0 && idx < List.length tys -> + Ok (List.nth tys idx, base_eff) + | TTuple _ -> + Error (ArityMismatch (idx + 1, 0, span)) + | _ -> + Error (ExpectedTuple (base_ty, span)) + end + + | ExprIndex (arr, idx_expr) -> + let span = expr_span expr in + let* (arr_ty, arr_eff) = synth ctx arr in + let* idx_eff = check ctx idx_expr ty_int in + begin match repr arr_ty with + | TApp (TCon "Array", [elem_ty]) -> + Ok (elem_ty, union_eff [arr_eff; idx_eff]) + | TVar _ as tv -> + let elem_ty = fresh_tyvar ctx.level in + begin match Unify.unify tv (TApp (TCon "Array", [elem_ty])) with + | Ok () -> Ok (elem_ty, union_eff [arr_eff; idx_eff]) | Error e -> Error (UnificationFailed (e, span)) end | _ -> - Error (ExpectedRecord (expr_ty, span)) + Error (CannotInfer span) end - | EBlock (exprs, _) -> - synth_block ctx exprs + | ExprBlock blk -> + synth_block ctx blk - | EBinOp (left, op, right, span) -> + | ExprBinary (left, op, right) -> + let span = expr_span expr in synth_binop ctx left op right span - | EPerform (op, arg, span) -> - (* TODO: Look up effect operation type *) - let* (_arg_ty, arg_eff) = synth ctx arg in - let eff = ESingleton op.id_name in - let ret_ty = fresh_tyvar ctx.level in - Ok (ret_ty, union_eff [arg_eff; eff]) - - | EHandle (body, handler, span) -> - let* (body_ty, body_eff) = synth ctx body in - (* TODO: Check handler and compute resulting effect *) - let result_ty = match handler.h_return with - | Some _ -> fresh_tyvar ctx.level - | None -> body_ty - in - Ok (result_ty, EPure) (* TODO: Proper effect computation *) + | ExprUnary (op, operand) -> + synth_unary ctx op operand - | _ -> - Error (CannotInfer (Span.dummy)) + | ExprReturn e_opt -> + (* Return types need context from enclosing function *) + begin match e_opt with + | Some e -> + let* (ty, eff) = synth ctx e in + Ok (ty, eff) + | None -> + Ok (ty_unit, EPure) + end + + | ExprHandle eh -> + let* (body_ty, _body_eff) = synth ctx eh.eh_body in + (* TODO: Check handlers and compute resulting effect *) + Ok (body_ty, EPure) + + | ExprResume e_opt -> + begin match e_opt with + | Some e -> + let* (ty, eff) = synth ctx e in + Ok (ty, eff) + | None -> + Ok (ty_unit, EPure) + end + + | ExprTry et -> + let* (body_ty, body_eff) = synth_block ctx et.et_body in + (* TODO: Check catch arms and finally block *) + Ok (body_ty, body_eff) + + | ExprRowRestrict (base, _field) -> + let* (base_ty, base_eff) = synth ctx base in + (* Row restriction removes a field from a record type *) + Ok (base_ty, base_eff) (* TODO: Proper row restriction *) + + | ExprUnsafe _ -> + Ok (fresh_tyvar ctx.level, EPure) + + | ExprVariant (ty_id, _variant_id) -> + Ok (TCon ty_id.name, EPure) + + | ExprSpan (e, _span) -> + synth ctx e + +and synth_app (ctx : context) (func_ty : ty) (func_eff : effect) + (args : expr list) (span : Span.t) : (ty * effect) result = + match args with + | [] -> Ok (func_ty, func_eff) + | arg :: rest -> + begin match repr func_ty with + | TArrow (param_ty, ret_ty, call_eff) -> + let* arg_eff = check ctx arg param_ty in + synth_app ctx ret_ty (union_eff [func_eff; arg_eff; call_eff]) rest span + | TVar _ as tv -> + let param_ty = fresh_tyvar ctx.level in + let ret_ty = fresh_tyvar ctx.level in + let call_eff = fresh_effvar ctx.level in + begin match Unify.unify tv (TArrow (param_ty, ret_ty, call_eff)) with + | Ok () -> + let* arg_eff = check ctx arg param_ty in + synth_app ctx ret_ty (union_eff [func_eff; arg_eff; call_eff]) rest span + | Error e -> + Error (UnificationFailed (e, span)) + end + | _ -> + Error (ExpectedFunction (func_ty, span)) + end (** Check an expression against an expected type *) and check (ctx : context) (expr : expr) (expected : ty) : effect result = match (expr, repr expected) with (* Lambda checking *) - | (ELam lam, TArrow (param_ty, ret_ty, arr_eff)) -> - (* Bind parameters with expected types *) - begin match lam.lam_params with - | [(id, _, _)] -> - bind_var ctx id param_ty; - let* body_eff = check ctx lam.lam_body ret_ty in + | (ExprLambda lam, TArrow (param_ty, ret_ty, arr_eff)) -> + begin match lam.elam_params with + | [param] -> + bind_var ctx param.p_name param_ty; + let* body_eff = check ctx lam.elam_body ret_ty in begin match Unify.unify_eff body_eff arr_eff with | Ok () -> Ok EPure | Error e -> Error (UnificationFailed (e, Span.dummy)) end | _ -> - (* TODO: Multi-param lambdas *) + (* Multi-param lambdas: fall through to subsumption *) check_subsumption ctx expr expected end (* If checking *) - | (EIf (cond, then_, else_, _), _) -> - let* cond_eff = check ctx cond ty_bool in - let* then_eff = check ctx then_ expected in - let* else_eff = check ctx else_ expected in - Ok (union_eff [cond_eff; then_eff; else_eff]) + | (ExprIf ei, _) -> + let* cond_eff = check ctx ei.ei_cond ty_bool in + let* then_eff = check ctx ei.ei_then expected in + begin match ei.ei_else with + | Some else_expr -> + let* else_eff = check ctx else_expr expected in + Ok (union_eff [cond_eff; then_eff; else_eff]) + | None -> + (* If without else must have unit type *) + begin match Unify.unify expected ty_unit with + | Ok () -> Ok (union_eff [cond_eff; then_eff]) + | Error e -> Error (UnificationFailed (e, expr_span expr)) + end + end (* Tuple checking *) - | (ETuple (exprs, _), TTuple tys) when List.length exprs = List.length tys -> + | (ExprTuple exprs, TTuple tys) when List.length exprs = List.length tys -> let* effs = check_list ctx exprs tys in Ok (union_eff effs) + (* Match checking *) + | (ExprMatch em, _) -> + let* (scrut_ty, scrut_eff) = synth ctx em.em_scrutinee in + let* effs = List.fold_left (fun acc arm -> + let* effs = acc in + let* () = check_pattern ctx arm.ma_pat scrut_ty in + let* eff = check ctx arm.ma_body expected in + Ok (eff :: effs) + ) (Ok [scrut_eff]) em.em_arms in + Ok (union_eff effs) + + (* Block checking *) + | (ExprBlock blk, _) -> + check_block ctx blk expected + (* Subsumption: synth and unify *) | _ -> check_subsumption ctx expr expected @@ -383,7 +588,7 @@ and check_subsumption (ctx : context) (expr : expr) (expected : ty) : effect res let* (actual, eff) = synth ctx expr in match Unify.unify actual expected with | Ok () -> Ok eff - | Error e -> Error (UnificationFailed (e, Span.dummy)) + | Error e -> Error (UnificationFailed (e, expr_span expr)) and synth_list (ctx : context) (exprs : expr list) : ((ty * effect) list) result = List.fold_right (fun expr acc -> @@ -405,58 +610,148 @@ and check_list (ctx : context) (exprs : expr list) (tys : ty list) : (effect lis | Ok eff -> Ok (eff :: effs) ) exprs tys (Ok []) -and synth_fields (ctx : context) (fields : (ident * expr) list) +and synth_record_fields (ctx : context) (fields : (ident * expr option) list) : ((string * ty * effect) list) result = - List.fold_right (fun (id, expr) acc -> + List.fold_right (fun (id, expr_opt) acc -> match acc with | Error e -> Error e | Ok results -> - match synth ctx expr with - | Error e -> Error e - | Ok (ty, eff) -> Ok ((id.id_name, ty, eff) :: results) + match expr_opt with + | Some expr -> + begin match synth ctx expr with + | Error e -> Error e + | Ok (ty, eff) -> Ok ((id.name, ty, eff) :: results) + end + | None -> + (* Punning: {x} is short for {x: x} *) + begin match lookup_var ctx id with + | Error e -> Error e + | Ok ty -> Ok ((id.name, ty, EPure) :: results) + end ) fields (Ok []) -and synth_block (ctx : context) (exprs : expr list) : (ty * effect) result = - match exprs with - | [] -> Ok (ty_unit, EPure) - | [e] -> synth ctx e - | e :: rest -> - let* (_, eff1) = synth ctx e in - let* (ty, eff2) = synth_block ctx rest in - Ok (ty, union_eff [eff1; eff2]) +and synth_block (ctx : context) (blk : block) : (ty * effect) result = + let* effs = List.fold_left (fun acc stmt -> + let* effs = acc in + let* eff = synth_stmt ctx stmt in + Ok (eff :: effs) + ) (Ok []) blk.blk_stmts in + match blk.blk_expr with + | Some e -> + let* (ty, eff) = synth ctx e in + Ok (ty, union_eff (eff :: effs)) + | None -> + Ok (ty_unit, union_eff effs) + +and check_block (ctx : context) (blk : block) (expected : ty) : effect result = + let* effs = List.fold_left (fun acc stmt -> + let* effs = acc in + let* eff = synth_stmt ctx stmt in + Ok (eff :: effs) + ) (Ok []) blk.blk_stmts in + match blk.blk_expr with + | Some e -> + let* eff = check ctx e expected in + Ok (union_eff (eff :: effs)) + | None -> + begin match Unify.unify expected ty_unit with + | Ok () -> Ok (union_eff effs) + | Error e -> Error (UnificationFailed (e, Span.dummy)) + end -and synth_binop (ctx : context) (left : expr) (op : binop) (right : expr) +and synth_stmt (ctx : context) (stmt : stmt) : effect result = + match stmt with + | StmtLet sl -> + let ctx' = enter_level ctx in + let* (rhs_ty, rhs_eff) = synth ctx' sl.sl_value in + let scheme = generalize ctx rhs_ty in + let* () = bind_pattern ctx sl.sl_pat scheme in + Ok rhs_eff + | StmtExpr e -> + let* (_, eff) = synth ctx e in + Ok eff + | StmtAssign (lhs, _op, rhs) -> + let* (lhs_ty, lhs_eff) = synth ctx lhs in + let* rhs_eff = check ctx rhs lhs_ty in + Ok (union_eff [lhs_eff; rhs_eff]) + | StmtWhile (cond, body) -> + let* cond_eff = check ctx cond ty_bool in + let* (_, body_eff) = synth_block ctx body in + Ok (union_eff [cond_eff; body_eff]) + | StmtFor (pat, iter, body) -> + let* (iter_ty, iter_eff) = synth ctx iter in + (* Assume iterator yields element type *) + let elem_ty = fresh_tyvar ctx.level in + let* () = check_pattern ctx pat elem_ty in + let* (_, body_eff) = synth_block ctx body in + let _ = iter_ty in (* Silence unused warning for now *) + Ok (union_eff [iter_eff; body_eff]) + +and synth_binop (ctx : context) (left : expr) (op : binary_op) (right : expr) (span : Span.t) : (ty * effect) result = let* (left_ty, left_eff) = synth ctx left in let* (right_ty, right_eff) = synth ctx right in let eff = union_eff [left_eff; right_eff] in match op with - | BAdd | BSub | BMul | BDiv | BMod -> + | OpAdd | OpSub | OpMul | OpDiv | OpMod -> begin match Unify.unify left_ty ty_int, Unify.unify right_ty ty_int with | Ok (), Ok () -> Ok (ty_int, eff) | Error e, _ | _, Error e -> Error (UnificationFailed (e, span)) end - | BEq | BNe | BLt | BLe | BGt | BGe -> + | OpEq | OpNe | OpLt | OpLe | OpGt | OpGe -> begin match Unify.unify left_ty right_ty with | Ok () -> Ok (ty_bool, eff) | Error e -> Error (UnificationFailed (e, span)) end - | BAnd | BOr -> + | OpAnd | OpOr -> begin match Unify.unify left_ty ty_bool, Unify.unify right_ty ty_bool with | Ok (), Ok () -> Ok (ty_bool, eff) | Error e, _ | _, Error e -> Error (UnificationFailed (e, span)) end - | _ -> - (* TODO: Other operators *) - Ok (fresh_tyvar ctx.level, eff) + | OpBitAnd | OpBitOr | OpBitXor | OpShl | OpShr -> + begin match Unify.unify left_ty ty_int, Unify.unify right_ty ty_int with + | Ok (), Ok () -> Ok (ty_int, eff) + | Error e, _ | _, Error e -> Error (UnificationFailed (e, span)) + end + +and synth_unary (ctx : context) (op : unary_op) (operand : expr) : (ty * effect) result = + let* (operand_ty, operand_eff) = synth ctx operand in + match op with + | OpNeg -> + begin match Unify.unify operand_ty ty_int with + | Ok () -> Ok (ty_int, operand_eff) + | Error _ -> + begin match Unify.unify operand_ty ty_float with + | Ok () -> Ok (ty_float, operand_eff) + | Error e -> Error (UnificationFailed (e, expr_span operand)) + end + end + | OpNot -> + begin match Unify.unify operand_ty ty_bool with + | Ok () -> Ok (ty_bool, operand_eff) + | Error e -> Error (UnificationFailed (e, expr_span operand)) + end + | OpBitNot -> + begin match Unify.unify operand_ty ty_int with + | Ok () -> Ok (ty_int, operand_eff) + | Error e -> Error (UnificationFailed (e, expr_span operand)) + end + | OpRef -> + Ok (TRef operand_ty, operand_eff) + | OpDeref -> + begin match repr operand_ty with + | TRef t | TMut t | TOwn t -> Ok (t, operand_eff) + | _ -> Error (CannotInfer (expr_span operand)) + end and bind_pattern (ctx : context) (pat : pattern) (scheme : scheme) : unit result = match pat with - | PVar id -> + | PatVar id -> bind_var_scheme ctx id scheme; Ok () - | PWild _ -> Ok () - | PTuple (pats, _) -> + | PatWildcard _ -> Ok () + | PatLit _ -> Ok () (* Literal patterns don't bind *) + | PatTuple pats -> begin match scheme.sc_body with | TTuple tys when List.length pats = List.length tys -> List.fold_left2 (fun acc pat ty -> @@ -468,18 +763,56 @@ and bind_pattern (ctx : context) (pat : pattern) (scheme : scheme) : unit result ) (Ok ()) pats tys | _ -> Error (InvalidPattern Span.dummy) end - | _ -> - (* TODO: Other patterns *) - Ok () + | PatRecord (fields, _has_rest) -> + begin match scheme.sc_body with + | TRecord row -> + List.fold_left (fun acc (field_id, pat_opt) -> + match acc with + | Error e -> Error e + | Ok () -> + match find_field field_id.name row with + | Some ty -> + begin match pat_opt with + | Some p -> + let sc = { scheme with sc_body = ty } in + bind_pattern ctx p sc + | None -> + bind_var ctx field_id ty; + Ok () + end + | None -> Error (InvalidPattern field_id.span) + ) (Ok ()) fields + | _ -> Error (InvalidPattern Span.dummy) + end + | PatCon (_con, pats) -> + (* TODO: Look up constructor type and bind subpatterns *) + List.fold_left (fun acc pat -> + match acc with + | Error e -> Error e + | Ok () -> + let sc = { scheme with sc_body = fresh_tyvar ctx.level } in + bind_pattern ctx pat sc + ) (Ok ()) pats + | PatOr (p1, p2) -> + (* Both branches must bind the same variables with same types *) + let* () = bind_pattern ctx p1 scheme in + bind_pattern ctx p2 scheme + | PatAs (id, pat) -> + bind_var_scheme ctx id scheme; + bind_pattern ctx pat scheme + +and check_pattern (ctx : context) (pat : pattern) (expected : ty) : unit result = + let scheme = { sc_tyvars = []; sc_effvars = []; sc_rowvars = []; sc_body = expected } in + bind_pattern ctx pat scheme and synth_literal (lit : literal) : ty = match lit with - | LUnit _ -> ty_unit - | LBool _ -> ty_bool - | LInt _ -> ty_int - | LFloat _ -> ty_float - | LChar _ -> ty_char - | LString _ -> ty_string + | LitUnit _ -> ty_unit + | LitBool _ -> ty_bool + | LitInt _ -> ty_int + | LitFloat _ -> ty_float + | LitChar _ -> ty_char + | LitString _ -> ty_string and find_field (name : string) (row : row) : ty option = match repr_row row with @@ -496,18 +829,13 @@ and union_eff (effs : effect list) : effect = | [e] -> e | es -> EUnion es -(* Result bind *) -let ( let* ) = Result.bind - (** Type check a declaration *) -let check_decl (ctx : context) (decl : decl) : unit result = +let check_decl (ctx : context) (decl : top_level) : unit result = match decl with - | DFun fd -> + | TopFn fd -> (* Create function type from signature *) - let param_tys = List.map (fun (id, ty_opt, _q) -> - match ty_opt with - | Some ty -> (id, ast_to_ty ctx ty) - | None -> (id, fresh_tyvar ctx.level) + let param_tys = List.map (fun param -> + (param.p_name, ast_to_ty ctx param.p_ty) ) fd.fd_params in let ret_ty = match fd.fd_ret_ty with | Some ty -> ast_to_ty ctx ty @@ -521,38 +849,35 @@ let check_decl (ctx : context) (decl : decl) : unit result = bind_var ctx fd.fd_name func_ty; (* Bind parameters *) List.iter (fun (id, ty) -> bind_var ctx id ty) param_tys; - (* Check body if present *) + (* Check body *) begin match fd.fd_body with - | Some body -> - let* _ = check ctx body ret_ty in + | FnBlock blk -> + let* _ = check_block ctx blk ret_ty in + Ok () + | FnExpr e -> + let* _ = check ctx e ret_ty in Ok () - | None -> Ok () end - | DType _ -> + | TopType _ -> (* TODO: Check type definitions *) Ok () - | DEffect _ -> + | TopEffect _ -> (* TODO: Register effect *) Ok () - | DTrait _ -> + | TopTrait _ -> (* TODO: Check trait definitions *) Ok () - | DImpl _ -> + | TopImpl _ -> (* TODO: Check implementations *) Ok () - | DModule (_, decls, _) -> - List.fold_left (fun acc d -> - match acc with - | Error e -> Error e - | Ok () -> check_decl ctx d - ) (Ok ()) decls - - | DImport _ -> + | TopConst tc -> + let expected = ast_to_ty ctx tc.tc_ty in + let* _ = check ctx tc.tc_value expected in Ok () (** Type check a program *) diff --git a/runtime/src/gc.rs b/runtime/src/gc.rs index 5bceb3d..63191a0 100644 --- a/runtime/src/gc.rs +++ b/runtime/src/gc.rs @@ -206,8 +206,11 @@ pub extern "C" fn gc_alloc(size: usize, type_tag: u32) -> *mut GcHeader { } } - // Allocate header + data - let total_size = core::mem::size_of::() + size; + // Allocate header + data (with overflow check) + let total_size = match size.checked_add(core::mem::size_of::()) { + Some(s) => s, + None => return core::ptr::null_mut(), // Size overflow + }; let ptr = crate::alloc::allocate(total_size, core::mem::align_of::()); if ptr.is_null() { diff --git a/tools/affine-doc/assets/search.js b/tools/affine-doc/assets/search.js index 0601e10..e5580b3 100644 --- a/tools/affine-doc/assets/search.js +++ b/tools/affine-doc/assets/search.js @@ -4,6 +4,29 @@ (function() { 'use strict'; + // HTML escape function to prevent XSS + function escapeHtml(text) { + if (typeof text !== 'string') return ''; + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + // Safe URL validation + function isValidUrl(url) { + if (typeof url !== 'string') return false; + // Only allow relative URLs or same-origin URLs + try { + if (url.startsWith('/') || url.startsWith('./') || url.startsWith('../')) { + return true; + } + const parsed = new URL(url, window.location.origin); + return parsed.origin === window.location.origin; + } catch { + return false; + } + } + // Wait for DOM and search index to load document.addEventListener('DOMContentLoaded', function() { const searchInput = document.getElementById('search'); @@ -26,8 +49,7 @@ // Search function function search(query) { if (!query || query.length < 2) { - searchResults.innerHTML = ''; - searchResults.style.display = 'none'; + clearResults(); return; } @@ -52,8 +74,8 @@ } function computeScore(entry, query) { - const nameLower = entry.name.toLowerCase(); - const pathLower = entry.path.toLowerCase(); + const nameLower = (entry.name || '').toLowerCase(); + const pathLower = (entry.path || '').toLowerCase(); if (nameLower === query) return 100; if (nameLower.startsWith(query)) return 50; @@ -63,25 +85,62 @@ return 0; } + function clearResults() { + while (searchResults.firstChild) { + searchResults.removeChild(searchResults.firstChild); + } + searchResults.style.display = 'none'; + } + function renderResults(results) { + clearResults(); + if (results.length === 0) { - searchResults.innerHTML = '
No results found
'; + const noResults = document.createElement('div'); + noResults.className = 'no-results'; + noResults.textContent = 'No results found'; + searchResults.appendChild(noResults); searchResults.style.display = 'block'; return; } - const html = results.map(function(r) { - return ` - - ${r.entry.kind} - ${r.entry.name} - ${r.entry.path} - ${r.entry.description ? `${r.entry.description}` : ''} - - `; - }).join(''); - - searchResults.innerHTML = html; + for (const r of results) { + const link = document.createElement('a'); + link.className = 'search-result'; + + // Validate URL before setting href + const url = r.entry.url; + if (isValidUrl(url)) { + link.href = url; + } else { + link.href = '#'; + } + + const kindSpan = document.createElement('span'); + kindSpan.className = 'result-kind'; + kindSpan.textContent = r.entry.kind || ''; + link.appendChild(kindSpan); + + const nameSpan = document.createElement('span'); + nameSpan.className = 'result-name'; + nameSpan.textContent = r.entry.name || ''; + link.appendChild(nameSpan); + + const pathSpan = document.createElement('span'); + pathSpan.className = 'result-path'; + pathSpan.textContent = r.entry.path || ''; + link.appendChild(pathSpan); + + if (r.entry.description) { + const descSpan = document.createElement('span'); + descSpan.className = 'result-desc'; + descSpan.textContent = r.entry.description; + link.appendChild(descSpan); + } + + searchResults.appendChild(link); + } + searchResults.style.display = 'block'; }