From da96da4a19056a96fc9174d0bca0beb3e5d57dbf Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Mon, 10 Feb 2025 11:35:26 +0100 Subject: [PATCH 1/6] chore: reformat code using rustfmt Make consistent formatting using rustfmt. --- src/handler.rs | 49 ++++++++++++------ src/lib.rs | 109 ++++++++++++++++++++++++++--------------- tests/test_parallel.rs | 16 ++++-- tests/test_serial.rs | 39 ++++++++++----- 4 files changed, 145 insertions(+), 68 deletions(-) diff --git a/src/handler.rs b/src/handler.rs index 36035e8..b37434b 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,18 +1,20 @@ -use std::convert::Infallible; -use std::sync::Arc; +use crate::StatusCode; use futures::future::BoxFuture; use hyper::{Body, HeaderMap, Method, Request}; -use crate::StatusCode; +use std::convert::Infallible; +use std::sync::Arc; pub type Response = hyper::Response; -pub type HandlerCallback = Arc) -> BoxFuture<'static, Result> + Send + Sync>; +pub type HandlerCallback = Arc< + dyn Fn(Request) -> BoxFuture<'static, Result> + Send + Sync, +>; #[derive(Default, Clone)] pub struct HandlerBuilder { path: String, method: Method, headers: HeaderMap, - status_code: StatusCode + status_code: StatusCode, } #[allow(dead_code)] @@ -22,7 +24,7 @@ impl HandlerBuilder { path: String::from(path), method: Method::GET, headers: HeaderMap::new(), - status_code: StatusCode::INTERNAL_SERVER_ERROR + status_code: StatusCode::INTERNAL_SERVER_ERROR, } } @@ -42,23 +44,39 @@ impl HandlerBuilder { } pub fn build(self) -> HandlerCallback { - let Self { path, method, status_code, headers } = self; + let Self { + path, + method, + status_code, + headers, + } = self; Arc::new(move |req: Request| { let cloned_path = path.clone(); let cloned_method = method.clone(); let cloned_headers = headers.clone(); Box::pin(async move { - if req.uri().path().eq(cloned_path.as_str()) && req.method().eq(&cloned_method) - && Self::contains_headers(req.headers(), &cloned_headers) { - Ok(hyper::Response::builder().status(status_code).body(Body::empty()).unwrap()) + if req.uri().path().eq(cloned_path.as_str()) + && req.method().eq(&cloned_method) + && Self::contains_headers(req.headers(), &cloned_headers) + { + Ok(hyper::Response::builder() + .status(status_code) + .body(Body::empty()) + .unwrap()) } else { - Ok(hyper::Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap()) + Ok(hyper::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()) } }) }) } - fn contains_headers(headers_reference: &HeaderMap, headers_to_be_contained: &HeaderMap) -> bool { + fn contains_headers( + headers_reference: &HeaderMap, + headers_to_be_contained: &HeaderMap, + ) -> bool { for (header, value) in headers_to_be_contained { if !headers_reference.get(header).eq(&Some(value)) { return false; @@ -68,6 +86,9 @@ impl HandlerBuilder { } } -pub async fn default_handle(_req: Request) -> Result { - Ok(hyper::Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap()) +pub async fn default_handle(_req: Request) -> Result { + Ok(hyper::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()) } diff --git a/src/lib.rs b/src/lib.rs index 3023f27..2429b4c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,27 @@ #![doc = include_str!("../README.md")] pub mod handler; +use crate::handler::{default_handle, HandlerCallback}; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Server, StatusCode, Uri}; +use lazy_static::lazy_static; +use queues::{queue, IsQueue, Queue}; use std::collections::BinaryHeap; +use std::env; use std::future::Future; -use std::net::{SocketAddr}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::Mutex; use test_context::AsyncTestContext; use tokio::sync::oneshot::{Receiver, Sender}; use tokio::task::JoinHandle; -use hyper::{Server, StatusCode, Uri}; -use hyper::service::{make_service_fn, service_fn}; -use lazy_static::lazy_static; -use std::sync::Mutex; -use std::sync::Arc; -use queues::{Queue, IsQueue, queue}; -use crate::handler::{default_handle, HandlerCallback}; -use std::env; pub type Error = Box; pub static TOKIOTEST_HTTP_PORT_ENV: &str = "TOKIOTEST_HTTP_PORT"; lazy_static! { - static ref PORTS: Mutex> = Mutex::new(BinaryHeap::from((12300u16..12400u16).collect::>())); + static ref PORTS: Mutex> = + Mutex::new(BinaryHeap::from((12300u16..12400u16).collect::>())); } /// function that can be called to avoid port collision when tests have to open a listen port @@ -49,30 +50,34 @@ impl HttpTestContext { } pub fn uri(&self, path: &str) -> Uri { - format!("http://{}:{}{}", "localhost", self.port, path).parse::().unwrap() + format!("http://{}:{}{}", "localhost", self.port, path) + .parse::() + .unwrap() } } -pub async fn run_service(addr: SocketAddr, rx: Receiver<()>, - handlers: Arc>>) -> impl Future> { - +pub async fn run_service( + addr: SocketAddr, + rx: Receiver<()>, + handlers: Arc>>, +) -> impl Future> { let new_service = make_service_fn(move |_| { let cloned_handlers = handlers.clone(); async { - Ok::<_, Error>(service_fn(move |req| { - match cloned_handlers.lock() { - Ok(mut handlers_rw) => { - match handlers_rw.remove() { - Ok(handler) => { handler(req) } - Err(_err) => { Box::pin(default_handle(req)) } - } - } - Err(_err_lock) => Box::pin(default_handle(req)) - } + Ok::<_, Error>(service_fn(move |req| match cloned_handlers.lock() { + Ok(mut handlers_rw) => match handlers_rw.remove() { + Ok(handler) => handler(req), + Err(_err) => Box::pin(default_handle(req)), + }, + Err(_err_lock) => Box::pin(default_handle(req)), })) } }); - Server::bind(&addr).serve(new_service).with_graceful_shutdown(async { rx.await.ok(); }) + Server::bind(&addr) + .serve(new_service) + .with_graceful_shutdown(async { + rx.await.ok(); + }) } #[async_trait::async_trait] @@ -80,7 +85,7 @@ impl AsyncTestContext for HttpTestContext { async fn setup() -> HttpTestContext { let port: u16 = match env::var(TOKIOTEST_HTTP_PORT_ENV) { Ok(port_str) => port_str.parse::().unwrap(), - Err(_e) => take_port() + Err(_e) => take_port(), }; let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); let (sender, receiver) = tokio::sync::oneshot::channel::<()>(); @@ -90,7 +95,7 @@ impl AsyncTestContext for HttpTestContext { server_handler, sender, port, - handlers + handlers, } } @@ -103,10 +108,10 @@ impl AsyncTestContext for HttpTestContext { #[cfg(test)] mod test { - use hyper::{StatusCode, Method, Request, Body, HeaderMap, Client}; - use crate::{HttpTestContext}; - use test_context::test_context; use crate::handler::HandlerBuilder; + use crate::HttpTestContext; + use hyper::{Body, Client, HeaderMap, Method, Request, StatusCode}; + use test_context::test_context; #[test_context(HttpTestContext)] #[tokio::test] @@ -118,7 +123,11 @@ mod test { #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_404(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/unknown").status_code(StatusCode::NOT_FOUND).build()); + ctx.add( + HandlerBuilder::new("/unknown") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); let resp = Client::new().get(ctx.uri("/unknown")).await.unwrap(); @@ -128,7 +137,11 @@ mod test { #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_endpoint(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build()); + ctx.add( + HandlerBuilder::new("/foo") + .status_code(StatusCode::OK) + .build(), + ); let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(200, resp.status()); @@ -142,13 +155,28 @@ mod test { async fn test_get_with_headers(ctx: &mut HttpTestContext) { let mut headers = HeaderMap::new(); headers.append("foo", "bar".parse().unwrap()); - ctx.add(HandlerBuilder::new("/headers").status_code(StatusCode::OK).headers(headers.clone()).build()); - ctx.add(HandlerBuilder::new("/headers").status_code(StatusCode::OK).headers(headers).build()); + ctx.add( + HandlerBuilder::new("/headers") + .status_code(StatusCode::OK) + .headers(headers.clone()) + .build(), + ); + ctx.add( + HandlerBuilder::new("/headers") + .status_code(StatusCode::OK) + .headers(headers) + .build(), + ); let resp = Client::new().get(ctx.uri("/headers")).await.unwrap(); assert_eq!(500, resp.status()); - let req = Request::builder().method(Method::GET).uri(ctx.uri("/headers")).header("foo", "bar").body(Body::empty()).unwrap(); + let req = Request::builder() + .method(Method::GET) + .uri(ctx.uri("/headers")) + .header("foo", "bar") + .body(Body::empty()) + .unwrap(); let resp = Client::new().request(req).await.unwrap(); assert_eq!(200, resp.status()); } @@ -156,9 +184,12 @@ mod test { #[test_context(HttpTestContext)] #[tokio::test] async fn test_post_endpoint(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/bar") - .status_code(StatusCode::OK) - .method(Method::POST).build()); + ctx.add( + HandlerBuilder::new("/bar") + .status_code(StatusCode::OK) + .method(Method::POST) + .build(), + ); let req = Request::builder() .method(Method::POST) @@ -170,4 +201,4 @@ mod test { assert_eq!(200, resp.status()); } -} \ No newline at end of file +} diff --git a/tests/test_parallel.rs b/tests/test_parallel.rs index 0ed397e..d2038c6 100644 --- a/tests/test_parallel.rs +++ b/tests/test_parallel.rs @@ -1,12 +1,16 @@ -use hyper::{StatusCode, Client}; +use hyper::{Client, StatusCode}; +use test_context::test_context; use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::HttpTestContext; -use test_context::test_context; #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_200(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); + ctx.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); let resp = Client::new().get(ctx.uri("/ok")).await.unwrap(); @@ -16,7 +20,11 @@ async fn test_get_respond_200(ctx: &mut HttpTestContext) { #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_404(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/notfound").status_code(StatusCode::NOT_FOUND).build()); + ctx.add( + HandlerBuilder::new("/notfound") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); let resp = Client::new().get(ctx.uri("/notfound")).await.unwrap(); diff --git a/tests/test_serial.rs b/tests/test_serial.rs index 7fcdf28..bfbdcd7 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -1,17 +1,24 @@ -use tokiotest_httpserver::HttpTestContext; -use hyper::{Uri, StatusCode, Client}; -use tokiotest_httpserver::handler::{HandlerBuilder}; +use hyper::{Client, StatusCode, Uri}; use serial_test::serial; -use test_context::{test_context, AsyncTestContext}; use std::env; +use test_context::{test_context, AsyncTestContext}; +use tokiotest_httpserver::handler::HandlerBuilder; +use tokiotest_httpserver::HttpTestContext; #[test_context(PortContext)] #[tokio::test] #[serial] async fn test_get_respond_200(&mut ctx: PortContext) { - ctx.http_context.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); + ctx.http_context.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); - let resp = Client::new().get(Uri::from_static("http://localhost:54321/ok")).await.unwrap(); + let resp = Client::new() + .get(Uri::from_static("http://localhost:54321/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -20,9 +27,16 @@ async fn test_get_respond_200(&mut ctx: PortContext) { #[tokio::test] #[serial] async fn test_get_respond_404(&mut ctx: PortContext) { - ctx.http_context.add(HandlerBuilder::new("/notfound").status_code(StatusCode::NOT_FOUND).build()); + ctx.http_context.add( + HandlerBuilder::new("/notfound") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); - let resp = Client::new().get(Uri::from_static("http://localhost:54321/notfound")).await.unwrap(); + let resp = Client::new() + .get(Uri::from_static("http://localhost:54321/notfound")) + .await + .unwrap(); assert_eq!(404, resp.status()); } @@ -30,7 +44,7 @@ async fn test_get_respond_404(&mut ctx: PortContext) { #[allow(dead_code)] struct PortContext { port_string: String, - http_context: HttpTestContext + http_context: HttpTestContext, } #[async_trait::async_trait] @@ -38,11 +52,14 @@ impl AsyncTestContext for PortContext { async fn setup() -> PortContext { let port_string = "54321".to_string(); env::set_var("TOKIOTEST_HTTP_PORT", port_string.clone()); - PortContext { port_string, http_context: HttpTestContext::setup().await } + PortContext { + port_string, + http_context: HttpTestContext::setup().await, + } } async fn teardown(self) { let _ = self.http_context.teardown(); env::remove_var("TOKIOTEST_HTTP_PORT"); } -} \ No newline at end of file +} From 09ae54dd52ff176574e30a6cc1f6825c4e0a9f4f Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Tue, 11 Feb 2025 09:22:23 +0100 Subject: [PATCH 2/6] fix: test_serial: fix incorrect function signatures Move the &mut to the type. error[E0308]: mismatched types --> tests/test_serial.rs:11:31 | 11 | async fn test_get_respond_200(&mut ctx: PortContext) { | ^^^^^^^^ | | | expected `PortContext`, found `&mut _` | this expression has type `PortContext` | help: to declare a mutable variable use: `mut ctx` | = note: expected struct `PortContext` found mutable reference `&mut _` --- tests/test_serial.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_serial.rs b/tests/test_serial.rs index bfbdcd7..d0a9a87 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -8,7 +8,7 @@ use tokiotest_httpserver::HttpTestContext; #[test_context(PortContext)] #[tokio::test] #[serial] -async fn test_get_respond_200(&mut ctx: PortContext) { +async fn test_get_respond_200(ctx: &mut PortContext) { ctx.http_context.add( HandlerBuilder::new("/ok") .status_code(StatusCode::OK) @@ -26,7 +26,7 @@ async fn test_get_respond_200(&mut ctx: PortContext) { #[test_context(PortContext)] #[tokio::test] #[serial] -async fn test_get_respond_404(&mut ctx: PortContext) { +async fn test_get_respond_404(ctx: &mut PortContext) { ctx.http_context.add( HandlerBuilder::new("/notfound") .status_code(StatusCode::NOT_FOUND) From f9051258669b8172f7fa3a46a05040eec42654f4 Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Tue, 11 Feb 2025 09:28:25 +0100 Subject: [PATCH 3/6] fix: test_serial: await the shutdown future We're calling an async function, but aren't awaiting it. Let's do that. warning: non-binding `let` on a future --> tests/test_serial.rs:62:9 | 62 | let _ = self.http_context.teardown(); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: consider awaiting the future or dropping explicitly with `std::mem::drop` = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#let_underscore_future = note: `#[warn(clippy::let_underscore_future)]` on by default --- tests/test_serial.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_serial.rs b/tests/test_serial.rs index d0a9a87..6c6d752 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -59,7 +59,7 @@ impl AsyncTestContext for PortContext { } async fn teardown(self) { - let _ = self.http_context.teardown(); + self.http_context.teardown().await; env::remove_var("TOKIOTEST_HTTP_PORT"); } } From 25c2b4b21a557d2468e6127db03563b33f37357f Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Mon, 10 Feb 2025 11:40:49 +0100 Subject: [PATCH 4/6] chore: update test-context This removes the need for async_trait. --- Cargo.toml | 3 +-- src/lib.rs | 1 - tests/test_serial.rs | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0da395b..e1205aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,9 @@ serde_json = "1.0" tokio = { version = "1", features = ["full"] } hyper = { version = "0.14", features = ["full"] } futures = "0.3" -async-trait = "0.1" tower-http = "0.2" tokio-test = "0.4.2" -test-context = "0.1.3" +test-context = "0.4.1" lazy_static = "1.4" queues = "1.1" diff --git a/src/lib.rs b/src/lib.rs index 2429b4c..247e471 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,7 +80,6 @@ pub async fn run_service( }) } -#[async_trait::async_trait] impl AsyncTestContext for HttpTestContext { async fn setup() -> HttpTestContext { let port: u16 = match env::var(TOKIOTEST_HTTP_PORT_ENV) { diff --git a/tests/test_serial.rs b/tests/test_serial.rs index 6c6d752..823770f 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -47,7 +47,6 @@ struct PortContext { http_context: HttpTestContext, } -#[async_trait::async_trait] impl AsyncTestContext for PortContext { async fn setup() -> PortContext { let port_string = "54321".to_string(); From 25f3e193809ed3fd9cbcc7a3cfb3ac7d223d5369 Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Tue, 11 Feb 2025 09:34:23 +0100 Subject: [PATCH 5/6] fix: remove pointless let binding --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 247e471..4ae5eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,7 +99,7 @@ impl AsyncTestContext for HttpTestContext { } async fn teardown(self) { - let _ = self.sender.send(()).unwrap(); + self.sender.send(()).unwrap(); let _ = tokio::join!(self.server_handler); release_port(self.port); } From 3c2368b7133023fce6ddb567c0b44268386b12ee Mon Sep 17 00:00:00 2001 From: Frans Klaver Date: Mon, 10 Feb 2025 11:32:06 +0100 Subject: [PATCH 6/6] chore: upgrade hyper to 1.x hyper 0.14 is pretty old by now. Update to version 1. --- Cargo.toml | 4 +- README.md | 29 +++++++++--- src/handler.rs | 18 +++---- src/lib.rs | 105 ++++++++++++++++++++++++++++------------- tests/test_parallel.rs | 21 +++++++-- tests/test_serial.rs | 15 ++++-- 6 files changed, 137 insertions(+), 55 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e1205aa..5276e95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,15 @@ edition = "2018" [dependencies] serde_json = "1.0" tokio = { version = "1", features = ["full"] } -hyper = { version = "0.14", features = ["full"] } +hyper = { version = "1", features = ["full"] } futures = "0.3" tower-http = "0.2" tokio-test = "0.4.2" test-context = "0.4.1" lazy_static = "1.4" queues = "1.1" +hyper-util = { version = "0.1.10", features = ["full"] } +http-body-util = "0.1.2" [dev-dependencies] serial_test = "0.6.0" diff --git a/README.md b/README.md index d052821..7daa348 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,32 @@ A small test server utility to run http request against. The test context instantiates a new server with a random port between 12300 and 12400. The test will use this port : ```rust,no_run -use test_context::{AsyncTestContext, test_context}; -use hyper::{Uri, StatusCode, Client}; -use tokiotest_httpserver::handler::{HandlerBuilder}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::{StatusCode, Uri}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use std::convert::Infallible; +use test_context::{test_context, AsyncTestContext}; +use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::HttpTestContext; #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_200(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); - - let resp = Client::new().get(ctx.uri("/ok")).await.unwrap(); + ctx.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); + + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -31,4 +46,4 @@ It is also possible to use it with a sequential workflow. You just have to inclu With serial workflow you can choose to use a fixed port for the http test server by setting the environment variable `TOKIOTEST_HTTP_PORT` to the desired port. -See for example [test_serial](tests/test_serial.rs). \ No newline at end of file +See for example [test_serial](tests/test_serial.rs). diff --git a/src/handler.rs b/src/handler.rs index b37434b..79935f1 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,12 +1,14 @@ use crate::StatusCode; use futures::future::BoxFuture; -use hyper::{Body, HeaderMap, Method, Request}; +use http_body_util::combinators::BoxBody; +use hyper::body::{Bytes, Incoming}; +use hyper::{HeaderMap, Method, Request}; use std::convert::Infallible; use std::sync::Arc; -pub type Response = hyper::Response; +pub type Response = hyper::Response>; pub type HandlerCallback = Arc< - dyn Fn(Request) -> BoxFuture<'static, Result> + Send + Sync, + dyn Fn(Request) -> BoxFuture<'static, Result> + Send + Sync, >; #[derive(Default, Clone)] @@ -50,7 +52,7 @@ impl HandlerBuilder { status_code, headers, } = self; - Arc::new(move |req: Request| { + Arc::new(move |req: Request| { let cloned_path = path.clone(); let cloned_method = method.clone(); let cloned_headers = headers.clone(); @@ -61,12 +63,12 @@ impl HandlerBuilder { { Ok(hyper::Response::builder() .status(status_code) - .body(Body::empty()) + .body(BoxBody::default()) .unwrap()) } else { Ok(hyper::Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) + .body(BoxBody::default()) .unwrap()) } }) @@ -86,9 +88,9 @@ impl HandlerBuilder { } } -pub async fn default_handle(_req: Request) -> Result { +pub async fn default_handle(_req: Request) -> Result { Ok(hyper::Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) + .body(BoxBody::default()) .unwrap()) } diff --git a/src/lib.rs b/src/lib.rs index 4ae5eda..fda5b96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,19 +2,23 @@ pub mod handler; use crate::handler::{default_handle, HandlerCallback}; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Server, StatusCode, Uri}; +use hyper::service::service_fn; +use hyper::{StatusCode, Uri}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::graceful::GracefulShutdown; use lazy_static::lazy_static; use queues::{queue, IsQueue, Queue}; use std::collections::BinaryHeap; use std::env; -use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; use std::sync::Mutex; use test_context::AsyncTestContext; +use tokio::net::TcpListener; +use tokio::select; use tokio::sync::oneshot::{Receiver, Sender}; use tokio::task::JoinHandle; +use tokio::time::Duration; pub type Error = Box; pub static TOKIOTEST_HTTP_PORT_ENV: &str = "TOKIOTEST_HTTP_PORT"; @@ -40,7 +44,7 @@ pub fn release_port(port: u16) { pub struct HttpTestContext { pub port: u16, pub handlers: Arc>>, - server_handler: JoinHandle>, + server_handler: JoinHandle>, sender: Sender<()>, } @@ -58,26 +62,47 @@ impl HttpTestContext { pub async fn run_service( addr: SocketAddr, - rx: Receiver<()>, + mut rx: Receiver<()>, handlers: Arc>>, -) -> impl Future> { - let new_service = make_service_fn(move |_| { - let cloned_handlers = handlers.clone(); - async { - Ok::<_, Error>(service_fn(move |req| match cloned_handlers.lock() { - Ok(mut handlers_rw) => match handlers_rw.remove() { - Ok(handler) => handler(req), - Err(_err) => Box::pin(default_handle(req)), - }, - Err(_err_lock) => Box::pin(default_handle(req)), - })) +) -> Result<(), Error> { + let graceful = GracefulShutdown::new(); + let listener = TcpListener::bind(addr).await?; + loop { + select! { + conn = listener.accept() => { + let (stream, _) = conn?; + let io = TokioIo::new(stream); + let cloned_handlers = handlers.clone(); + let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + io, + service_fn(move |req| match cloned_handlers.lock() { + Ok(mut handlers_rw) => match handlers_rw.remove() { + Ok(handler) => handler(req), + Err(_err) => Box::pin(default_handle(req)), + }, + Err(_err_lock) => Box::pin(default_handle(req)), + }), + ); + tokio::spawn(graceful.watch(conn.into_owned())); + }, + _ = &mut rx => { + drop(listener); + break; + } } - }); - Server::bind(&addr) - .serve(new_service) - .with_graceful_shutdown(async { - rx.await.ok(); - }) + } + select! { + _ = graceful.shutdown() => { + Ok(()) + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + Err(Box::new(tokio::io::Error::new( + tokio::io::ErrorKind::TimedOut, + "Server graceful shutdown timed out", + ))) + }, + } } impl AsyncTestContext for HttpTestContext { @@ -89,7 +114,7 @@ impl AsyncTestContext for HttpTestContext { let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); let (sender, receiver) = tokio::sync::oneshot::channel::<()>(); let handlers: Arc>> = Arc::new(Mutex::new(queue![])); - let server_handler = tokio::spawn(run_service(addr, receiver, handlers.clone()).await); + let server_handler = tokio::spawn(run_service(addr, receiver, handlers.clone())); HttpTestContext { server_handler, sender, @@ -107,15 +132,29 @@ impl AsyncTestContext for HttpTestContext { #[cfg(test)] mod test { + use std::convert::Infallible; + use crate::handler::HandlerBuilder; use crate::HttpTestContext; - use hyper::{Body, Client, HeaderMap, Method, Request, StatusCode}; + use http_body_util::{combinators::BoxBody, Full}; + use hyper::{body::Bytes, HeaderMap, Method, Request, StatusCode}; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; use test_context::test_context; + macro_rules! make_client { + () => { + Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + }; + } + #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_without_expect_should_send_500(ctx: &mut HttpTestContext) { - let resp = Client::new().get(ctx.uri("/whatever")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/whatever")).await.unwrap(); assert_eq!(500, resp.status()); } @@ -128,7 +167,7 @@ mod test { .build(), ); - let resp = Client::new().get(ctx.uri("/unknown")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/unknown")).await.unwrap(); assert_eq!(404, resp.status()); } @@ -142,10 +181,10 @@ mod test { .build(), ); - let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(200, resp.status()); - let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(500, resp.status()); } @@ -167,16 +206,16 @@ mod test { .build(), ); - let resp = Client::new().get(ctx.uri("/headers")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/headers")).await.unwrap(); assert_eq!(500, resp.status()); let req = Request::builder() .method(Method::GET) .uri(ctx.uri("/headers")) .header("foo", "bar") - .body(Body::empty()) + .body(BoxBody::default()) .unwrap(); - let resp = Client::new().request(req).await.unwrap(); + let resp = make_client!().request(req).await.unwrap(); assert_eq!(200, resp.status()); } @@ -193,10 +232,10 @@ mod test { let req = Request::builder() .method(Method::POST) .uri(ctx.uri("/bar")) - .body(Body::from("foo=bar")) + .body(BoxBody::new(Full::new(Bytes::from("foo=bar")))) .expect("request builder"); - let resp = Client::new().request(req).await.unwrap(); + let resp = make_client!().request(req).await.unwrap(); assert_eq!(200, resp.status()); } diff --git a/tests/test_parallel.rs b/tests/test_parallel.rs index d2038c6..299e446 100644 --- a/tests/test_parallel.rs +++ b/tests/test_parallel.rs @@ -1,4 +1,11 @@ -use hyper::{Client, StatusCode}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::StatusCode; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use std::convert::Infallible; use test_context::test_context; use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::HttpTestContext; @@ -12,7 +19,11 @@ async fn test_get_respond_200(ctx: &mut HttpTestContext) { .build(), ); - let resp = Client::new().get(ctx.uri("/ok")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -26,7 +37,11 @@ async fn test_get_respond_404(ctx: &mut HttpTestContext) { .build(), ); - let resp = Client::new().get(ctx.uri("/notfound")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/notfound")) + .await + .unwrap(); assert_eq!(404, resp.status()); } diff --git a/tests/test_serial.rs b/tests/test_serial.rs index 823770f..f2e9fb6 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -1,5 +1,12 @@ -use hyper::{Client, StatusCode, Uri}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::{StatusCode, Uri}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; use serial_test::serial; +use std::convert::Infallible; use std::env; use test_context::{test_context, AsyncTestContext}; use tokiotest_httpserver::handler::HandlerBuilder; @@ -15,7 +22,8 @@ async fn test_get_respond_200(ctx: &mut PortContext) { .build(), ); - let resp = Client::new() + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) .get(Uri::from_static("http://localhost:54321/ok")) .await .unwrap(); @@ -33,7 +41,8 @@ async fn test_get_respond_404(ctx: &mut PortContext) { .build(), ); - let resp = Client::new() + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) .get(Uri::from_static("http://localhost:54321/notfound")) .await .unwrap();