diff --git a/lib/parse.ml b/lib/parse.ml index 03165f82..1fedbc23 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -304,6 +304,8 @@ module Reader = struct let consumed = match t.parse_state with | Fail _ -> 0 + (* Don't feed empty input when we're at a request boundary *) + | Done when len = 0 -> 0 | Done -> start t (AU.parse t.parser); read_with_more t bs ~off ~len more; @@ -311,8 +313,8 @@ module Reader = struct transition t (continue bs more ~off ~len) in begin match more with - | Complete -> t.closed <- true; - | Incomplete -> () + | Complete when consumed = len -> t.closed <- true; + | Complete | Incomplete -> () end; consumed; ;; @@ -322,13 +324,11 @@ module Reader = struct ;; let next t = - if t.closed - then `Close - else ( - match t.parse_state with - | Fail err -> `Error err - | Done -> `Read - | Partial _ -> `Read - ) + match t.parse_state with + | Fail err -> `Error err + | Done | Partial _ -> + if t.closed + then `Close + else `Read ;; end diff --git a/lib/server_connection.ml b/lib/server_connection.ml index e48a25c3..8a6e4ea9 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -64,15 +64,15 @@ type t = ; request_queue : Reqd.t Queue.t (* invariant: If [request_queue] is not empty, then the head of the queue has already had [request_handler] called on it. *) + ; mutable is_errored : bool + (* if there is a parse or connection error, we invoke the [error_handler] + and set [is_errored] to indicate we should not close the writer yet. *) ; mutable wakeup_reader : Optional_thunk.t } let is_closed t = Reader.is_closed t.reader && Writer.is_closed t.writer -let is_waiting t = - not (is_closed t) && Queue.is_empty t.request_queue - let is_active t = not (Queue.is_empty t.request_queue) @@ -134,6 +134,7 @@ let create ?(config=Config.default) ?(error_handler=default_error_handler) reque ; request_handler = request_handler ; error_handler = error_handler ; request_queue + ; is_errored = false ; wakeup_reader = Optional_thunk.none } @@ -166,6 +167,7 @@ let set_error_and_handle ?request t error = let reqd = current_reqd_exn t in Reqd.report_error reqd error end else begin + t.is_errored <- true; let status = match (error :> [error | Status.standard]) with | `Exn _ -> `Internal_server_error @@ -191,8 +193,11 @@ let advance_request_queue t = let rec _next_read_operation t = if not (is_active t) then ( - if Reader.is_closed t.reader - then shutdown t; + (* If the request queue is empty, there is no connection error, and the + reader is closed, then we can assume that no more user code will be able + to write. *) + if Reader.is_closed t.reader && not t.is_errored + then shutdown_writer t; Reader.next t.reader ) else ( let reqd = current_reqd_exn t in @@ -289,6 +294,8 @@ and _final_write_operation_for t reqd = _next_write_operation t; ) in + (* The only reason the reader yields is to wait for the writer, so we need to + notify it that we've completed. *) wakeup_reader t; next ;; diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index 472ebcb3..8b3c954b 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -1,7 +1,7 @@ open Httpaf open Helpers -let trace fmt = Format.ksprintf (Format.printf "%s\n") fmt +let trace fmt = Format.ksprintf (Format.printf "%s\n%!") fmt let request_error_pp_hum fmt = function | `Bad_request -> Format.fprintf fmt "Bad_request" @@ -94,6 +94,15 @@ end = struct ;; let create ?config ?error_handler request_handler = + let request_handler r = + trace "invoked: request_handler"; + request_handler r + in + let error_handler = + Option.map (fun error_handler ?request -> + trace "invoked: request_handler"; + error_handler ?request) error_handler + in let rec t = lazy ( { server_connection = create ?config ?error_handler request_handler @@ -126,23 +135,27 @@ end = struct let do_read t f = match current_read_operation t with | `Read -> + trace "read: start"; let res = f t.server_connection in + trace "read: finished"; t.read_loop (); res | `Yield | `Close as op -> - Alcotest.failf "Read attempted during operation: %a" - Read_operation.pp_hum op + Alcotest.failf "Read attempted during operation: %a" + Read_operation.pp_hum op ;; let do_write t f = match current_write_operation t with | `Write bufs -> - let res = f t.server_connection bufs in - t.write_loop (); - res + trace "write: start"; + let res = f t.server_connection bufs in + trace "write: finished"; + t.write_loop (); + res | `Yield | `Close _ as op -> - Alcotest.failf "Write attempted during operation: %a" - Write_operation.pp_hum op + Alcotest.failf "Write attempted during operation: %a" + Write_operation.pp_hum op ;; let on_reader_unyield t f = @@ -263,6 +276,12 @@ let connection_is_shutdown t = writer_closed t; ;; +let raises_writer_closed f = + (* This is raised when you write to a closed [Faraday.t] *) + Alcotest.check_raises "raises because writer is closed" + (Failure "cannot write to closed writer") f +;; + let request_handler_with_body body reqd = Body.Reader.close (Reqd.request_body reqd); Reqd.respond_with_string reqd (Response.create `OK) body @@ -279,7 +298,10 @@ let echo_handler response reqd = Body.Writer.write_string response_body (Bigstringaf.substring ~off ~len buffer); Body.Writer.flush response_body (fun () -> Body.Reader.schedule_read request_body ~on_eof ~on_read) - and on_eof () = print_endline "got eof"; Body.Writer.close response_body in + and on_eof () = + print_endline "echo handler eof"; + Body.Writer.close response_body + in Body.Reader.schedule_read request_body ~on_eof ~on_read; ;; @@ -888,6 +910,33 @@ let test_parse_failure_after_checkpoint () = | Some error -> Alcotest.(check request_error) "Error" error `Bad_request ;; +let test_parse_failure_at_eof () = + let error_queue = ref None in + let continue = ref (fun () -> ()) in + let error_handler ?request error start_response = + Alcotest.(check (option reject)) "Error queue is empty" !error_queue None; + Alcotest.(check (option reject)) "Request was not parsed" request None; + error_queue := Some error; + continue := (fun () -> + let resp_body = start_response Headers.empty in + Body.Writer.write_string resp_body "got an error"; + Body.Writer.close resp_body); + in + let request_handler _reqd = assert false in + let t = create ~error_handler request_handler in + reader_ready t; + read_string t "GET index.html HTTP/1.1\r\n"; + let result = feed_string ~eof:true t " index.html HTTP/1.1\r\n\r\n" in + Alcotest.(check int) "Bad header not consumed" result 0; + reader_closed t; + (match !error_queue with + | None -> Alcotest.fail "Expected error" + | Some error -> Alcotest.(check request_error) "Error" error `Bad_request); + !continue (); + write_response t (Response.create `Bad_request) ~body:"got an error"; + writer_closed t; +;; + let test_response_finished_before_body_read () = let response = Response.create `OK ~headers:(Headers.encoding_fixed 4) in let rev_body_chunks = ref [] in @@ -935,10 +984,7 @@ let test_shutdown_during_asynchronous_request () = in read_request t request; shutdown t; - (* This is raised from Faraday *) - Alcotest.check_raises "[continue] raises because writer is closed" - (Failure "cannot write to closed writer") - !continue; + raises_writer_closed !continue; reader_closed t; writer_closed t ;; @@ -1009,6 +1055,7 @@ let tests = ; "multiple requests with connection close", `Quick, test_multiple_requests_in_single_read_with_close ; "multiple requests with eof", `Quick, test_multiple_requests_in_single_read_with_eof ; "parse failure after checkpoint", `Quick, test_parse_failure_after_checkpoint + ; "parse failure at eof", `Quick, test_parse_failure_at_eof ; "response finished before body read", `Quick, test_response_finished_before_body_read ; "shutdown in request handler", `Quick, test_shutdown_in_request_handler ; "shutdown during asynchronous request", `Quick, test_shutdown_during_asynchronous_request