diff --git a/Cargo.toml b/Cargo.toml index 0da395b..5276e95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,14 +13,15 @@ edition = "2018" [dependencies] serde_json = "1.0" tokio = { version = "1", features = ["full"] } -hyper = { version = "0.14", features = ["full"] } +hyper = { version = "1", features = ["full"] } futures = "0.3" -async-trait = "0.1" tower-http = "0.2" tokio-test = "0.4.2" -test-context = "0.1.3" +test-context = "0.4.1" lazy_static = "1.4" queues = "1.1" +hyper-util = { version = "0.1.10", features = ["full"] } +http-body-util = "0.1.2" [dev-dependencies] serial_test = "0.6.0" diff --git a/README.md b/README.md index d052821..7daa348 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,32 @@ A small test server utility to run http request against. The test context instantiates a new server with a random port between 12300 and 12400. The test will use this port : ```rust,no_run -use test_context::{AsyncTestContext, test_context}; -use hyper::{Uri, StatusCode, Client}; -use tokiotest_httpserver::handler::{HandlerBuilder}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::{StatusCode, Uri}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use std::convert::Infallible; +use test_context::{test_context, AsyncTestContext}; +use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::HttpTestContext; #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_200(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); - - let resp = Client::new().get(ctx.uri("/ok")).await.unwrap(); + ctx.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); + + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -31,4 +46,4 @@ It is also possible to use it with a sequential workflow. You just have to inclu With serial workflow you can choose to use a fixed port for the http test server by setting the environment variable `TOKIOTEST_HTTP_PORT` to the desired port. -See for example [test_serial](tests/test_serial.rs). \ No newline at end of file +See for example [test_serial](tests/test_serial.rs). diff --git a/src/handler.rs b/src/handler.rs index 36035e8..79935f1 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,18 +1,22 @@ +use crate::StatusCode; +use futures::future::BoxFuture; +use http_body_util::combinators::BoxBody; +use hyper::body::{Bytes, Incoming}; +use hyper::{HeaderMap, Method, Request}; use std::convert::Infallible; use std::sync::Arc; -use futures::future::BoxFuture; -use hyper::{Body, HeaderMap, Method, Request}; -use crate::StatusCode; -pub type Response = hyper::Response; -pub type HandlerCallback = Arc) -> BoxFuture<'static, Result> + Send + Sync>; +pub type Response = hyper::Response>; +pub type HandlerCallback = Arc< + dyn Fn(Request) -> BoxFuture<'static, Result> + Send + Sync, +>; #[derive(Default, Clone)] pub struct HandlerBuilder { path: String, method: Method, headers: HeaderMap, - status_code: StatusCode + status_code: StatusCode, } #[allow(dead_code)] @@ -22,7 +26,7 @@ impl HandlerBuilder { path: String::from(path), method: Method::GET, headers: HeaderMap::new(), - status_code: StatusCode::INTERNAL_SERVER_ERROR + status_code: StatusCode::INTERNAL_SERVER_ERROR, } } @@ -42,23 +46,39 @@ impl HandlerBuilder { } pub fn build(self) -> HandlerCallback { - let Self { path, method, status_code, headers } = self; - Arc::new(move |req: Request| { + let Self { + path, + method, + status_code, + headers, + } = self; + Arc::new(move |req: Request| { let cloned_path = path.clone(); let cloned_method = method.clone(); let cloned_headers = headers.clone(); Box::pin(async move { - if req.uri().path().eq(cloned_path.as_str()) && req.method().eq(&cloned_method) - && Self::contains_headers(req.headers(), &cloned_headers) { - Ok(hyper::Response::builder().status(status_code).body(Body::empty()).unwrap()) + if req.uri().path().eq(cloned_path.as_str()) + && req.method().eq(&cloned_method) + && Self::contains_headers(req.headers(), &cloned_headers) + { + Ok(hyper::Response::builder() + .status(status_code) + .body(BoxBody::default()) + .unwrap()) } else { - Ok(hyper::Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap()) + Ok(hyper::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(BoxBody::default()) + .unwrap()) } }) }) } - fn contains_headers(headers_reference: &HeaderMap, headers_to_be_contained: &HeaderMap) -> bool { + fn contains_headers( + headers_reference: &HeaderMap, + headers_to_be_contained: &HeaderMap, + ) -> bool { for (header, value) in headers_to_be_contained { if !headers_reference.get(header).eq(&Some(value)) { return false; @@ -68,6 +88,9 @@ impl HandlerBuilder { } } -pub async fn default_handle(_req: Request) -> Result { - Ok(hyper::Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR).body(Body::empty()).unwrap()) +pub async fn default_handle(_req: Request) -> Result { + Ok(hyper::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(BoxBody::default()) + .unwrap()) } diff --git a/src/lib.rs b/src/lib.rs index 3023f27..fda5b96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,31 @@ #![doc = include_str!("../README.md")] pub mod handler; +use crate::handler::{default_handle, HandlerCallback}; +use hyper::service::service_fn; +use hyper::{StatusCode, Uri}; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::graceful::GracefulShutdown; +use lazy_static::lazy_static; +use queues::{queue, IsQueue, Queue}; use std::collections::BinaryHeap; -use std::future::Future; -use std::net::{SocketAddr}; +use std::env; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::Mutex; use test_context::AsyncTestContext; +use tokio::net::TcpListener; +use tokio::select; use tokio::sync::oneshot::{Receiver, Sender}; use tokio::task::JoinHandle; -use hyper::{Server, StatusCode, Uri}; -use hyper::service::{make_service_fn, service_fn}; -use lazy_static::lazy_static; -use std::sync::Mutex; -use std::sync::Arc; -use queues::{Queue, IsQueue, queue}; -use crate::handler::{default_handle, HandlerCallback}; -use std::env; +use tokio::time::Duration; pub type Error = Box; pub static TOKIOTEST_HTTP_PORT_ENV: &str = "TOKIOTEST_HTTP_PORT"; lazy_static! { - static ref PORTS: Mutex> = Mutex::new(BinaryHeap::from((12300u16..12400u16).collect::>())); + static ref PORTS: Mutex> = + Mutex::new(BinaryHeap::from((12300u16..12400u16).collect::>())); } /// function that can be called to avoid port collision when tests have to open a listen port @@ -39,7 +44,7 @@ pub fn release_port(port: u16) { pub struct HttpTestContext { pub port: u16, pub handlers: Arc>>, - server_handler: JoinHandle>, + server_handler: JoinHandle>, sender: Sender<()>, } @@ -49,53 +54,77 @@ impl HttpTestContext { } pub fn uri(&self, path: &str) -> Uri { - format!("http://{}:{}{}", "localhost", self.port, path).parse::().unwrap() + format!("http://{}:{}{}", "localhost", self.port, path) + .parse::() + .unwrap() } } -pub async fn run_service(addr: SocketAddr, rx: Receiver<()>, - handlers: Arc>>) -> impl Future> { - - let new_service = make_service_fn(move |_| { - let cloned_handlers = handlers.clone(); - async { - Ok::<_, Error>(service_fn(move |req| { - match cloned_handlers.lock() { - Ok(mut handlers_rw) => { - match handlers_rw.remove() { - Ok(handler) => { handler(req) } - Err(_err) => { Box::pin(default_handle(req)) } - } - } - Err(_err_lock) => Box::pin(default_handle(req)) - } - })) +pub async fn run_service( + addr: SocketAddr, + mut rx: Receiver<()>, + handlers: Arc>>, +) -> Result<(), Error> { + let graceful = GracefulShutdown::new(); + let listener = TcpListener::bind(addr).await?; + loop { + select! { + conn = listener.accept() => { + let (stream, _) = conn?; + let io = TokioIo::new(stream); + let cloned_handlers = handlers.clone(); + let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + io, + service_fn(move |req| match cloned_handlers.lock() { + Ok(mut handlers_rw) => match handlers_rw.remove() { + Ok(handler) => handler(req), + Err(_err) => Box::pin(default_handle(req)), + }, + Err(_err_lock) => Box::pin(default_handle(req)), + }), + ); + tokio::spawn(graceful.watch(conn.into_owned())); + }, + _ = &mut rx => { + drop(listener); + break; + } } - }); - Server::bind(&addr).serve(new_service).with_graceful_shutdown(async { rx.await.ok(); }) + } + select! { + _ = graceful.shutdown() => { + Ok(()) + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + Err(Box::new(tokio::io::Error::new( + tokio::io::ErrorKind::TimedOut, + "Server graceful shutdown timed out", + ))) + }, + } } -#[async_trait::async_trait] impl AsyncTestContext for HttpTestContext { async fn setup() -> HttpTestContext { let port: u16 = match env::var(TOKIOTEST_HTTP_PORT_ENV) { Ok(port_str) => port_str.parse::().unwrap(), - Err(_e) => take_port() + Err(_e) => take_port(), }; let addr = SocketAddr::new("127.0.0.1".parse().unwrap(), port); let (sender, receiver) = tokio::sync::oneshot::channel::<()>(); let handlers: Arc>> = Arc::new(Mutex::new(queue![])); - let server_handler = tokio::spawn(run_service(addr, receiver, handlers.clone()).await); + let server_handler = tokio::spawn(run_service(addr, receiver, handlers.clone())); HttpTestContext { server_handler, sender, port, - handlers + handlers, } } async fn teardown(self) { - let _ = self.sender.send(()).unwrap(); + self.sender.send(()).unwrap(); let _ = tokio::join!(self.server_handler); release_port(self.port); } @@ -103,24 +132,42 @@ impl AsyncTestContext for HttpTestContext { #[cfg(test)] mod test { - use hyper::{StatusCode, Method, Request, Body, HeaderMap, Client}; - use crate::{HttpTestContext}; - use test_context::test_context; + use std::convert::Infallible; + use crate::handler::HandlerBuilder; + use crate::HttpTestContext; + use http_body_util::{combinators::BoxBody, Full}; + use hyper::{body::Bytes, HeaderMap, Method, Request, StatusCode}; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; + use test_context::test_context; + + macro_rules! make_client { + () => { + Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + }; + } #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_without_expect_should_send_500(ctx: &mut HttpTestContext) { - let resp = Client::new().get(ctx.uri("/whatever")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/whatever")).await.unwrap(); assert_eq!(500, resp.status()); } #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_404(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/unknown").status_code(StatusCode::NOT_FOUND).build()); + ctx.add( + HandlerBuilder::new("/unknown") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); - let resp = Client::new().get(ctx.uri("/unknown")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/unknown")).await.unwrap(); assert_eq!(404, resp.status()); } @@ -128,12 +175,16 @@ mod test { #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_endpoint(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/foo").status_code(StatusCode::OK).build()); + ctx.add( + HandlerBuilder::new("/foo") + .status_code(StatusCode::OK) + .build(), + ); - let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(200, resp.status()); - let resp = Client::new().get(ctx.uri("/foo")).await.unwrap(); + let resp = make_client!().get(ctx.uri("/foo")).await.unwrap(); assert_eq!(500, resp.status()); } @@ -142,32 +193,50 @@ mod test { async fn test_get_with_headers(ctx: &mut HttpTestContext) { let mut headers = HeaderMap::new(); headers.append("foo", "bar".parse().unwrap()); - ctx.add(HandlerBuilder::new("/headers").status_code(StatusCode::OK).headers(headers.clone()).build()); - ctx.add(HandlerBuilder::new("/headers").status_code(StatusCode::OK).headers(headers).build()); - - let resp = Client::new().get(ctx.uri("/headers")).await.unwrap(); + ctx.add( + HandlerBuilder::new("/headers") + .status_code(StatusCode::OK) + .headers(headers.clone()) + .build(), + ); + ctx.add( + HandlerBuilder::new("/headers") + .status_code(StatusCode::OK) + .headers(headers) + .build(), + ); + + let resp = make_client!().get(ctx.uri("/headers")).await.unwrap(); assert_eq!(500, resp.status()); - let req = Request::builder().method(Method::GET).uri(ctx.uri("/headers")).header("foo", "bar").body(Body::empty()).unwrap(); - let resp = Client::new().request(req).await.unwrap(); + let req = Request::builder() + .method(Method::GET) + .uri(ctx.uri("/headers")) + .header("foo", "bar") + .body(BoxBody::default()) + .unwrap(); + let resp = make_client!().request(req).await.unwrap(); assert_eq!(200, resp.status()); } #[test_context(HttpTestContext)] #[tokio::test] async fn test_post_endpoint(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/bar") - .status_code(StatusCode::OK) - .method(Method::POST).build()); + ctx.add( + HandlerBuilder::new("/bar") + .status_code(StatusCode::OK) + .method(Method::POST) + .build(), + ); let req = Request::builder() .method(Method::POST) .uri(ctx.uri("/bar")) - .body(Body::from("foo=bar")) + .body(BoxBody::new(Full::new(Bytes::from("foo=bar")))) .expect("request builder"); - let resp = Client::new().request(req).await.unwrap(); + let resp = make_client!().request(req).await.unwrap(); assert_eq!(200, resp.status()); } -} \ No newline at end of file +} diff --git a/tests/test_parallel.rs b/tests/test_parallel.rs index 0ed397e..299e446 100644 --- a/tests/test_parallel.rs +++ b/tests/test_parallel.rs @@ -1,14 +1,29 @@ -use hyper::{StatusCode, Client}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::StatusCode; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use std::convert::Infallible; +use test_context::test_context; use tokiotest_httpserver::handler::HandlerBuilder; use tokiotest_httpserver::HttpTestContext; -use test_context::test_context; #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_200(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); + ctx.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); - let resp = Client::new().get(ctx.uri("/ok")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -16,9 +31,17 @@ async fn test_get_respond_200(ctx: &mut HttpTestContext) { #[test_context(HttpTestContext)] #[tokio::test] async fn test_get_respond_404(ctx: &mut HttpTestContext) { - ctx.add(HandlerBuilder::new("/notfound").status_code(StatusCode::NOT_FOUND).build()); + ctx.add( + HandlerBuilder::new("/notfound") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); - let resp = Client::new().get(ctx.uri("/notfound")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(ctx.uri("/notfound")) + .await + .unwrap(); assert_eq!(404, resp.status()); } diff --git a/tests/test_serial.rs b/tests/test_serial.rs index 7fcdf28..f2e9fb6 100644 --- a/tests/test_serial.rs +++ b/tests/test_serial.rs @@ -1,17 +1,32 @@ -use tokiotest_httpserver::HttpTestContext; -use hyper::{Uri, StatusCode, Client}; -use tokiotest_httpserver::handler::{HandlerBuilder}; +use http_body_util::combinators::BoxBody; +use hyper::body::Bytes; +use hyper::{StatusCode, Uri}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; use serial_test::serial; -use test_context::{test_context, AsyncTestContext}; +use std::convert::Infallible; use std::env; +use test_context::{test_context, AsyncTestContext}; +use tokiotest_httpserver::handler::HandlerBuilder; +use tokiotest_httpserver::HttpTestContext; #[test_context(PortContext)] #[tokio::test] #[serial] -async fn test_get_respond_200(&mut ctx: PortContext) { - ctx.http_context.add(HandlerBuilder::new("/ok").status_code(StatusCode::OK).build()); +async fn test_get_respond_200(ctx: &mut PortContext) { + ctx.http_context.add( + HandlerBuilder::new("/ok") + .status_code(StatusCode::OK) + .build(), + ); - let resp = Client::new().get(Uri::from_static("http://localhost:54321/ok")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(Uri::from_static("http://localhost:54321/ok")) + .await + .unwrap(); assert_eq!(200, resp.status()); } @@ -19,10 +34,18 @@ async fn test_get_respond_200(&mut ctx: PortContext) { #[test_context(PortContext)] #[tokio::test] #[serial] -async fn test_get_respond_404(&mut ctx: PortContext) { - ctx.http_context.add(HandlerBuilder::new("/notfound").status_code(StatusCode::NOT_FOUND).build()); +async fn test_get_respond_404(ctx: &mut PortContext) { + ctx.http_context.add( + HandlerBuilder::new("/notfound") + .status_code(StatusCode::NOT_FOUND) + .build(), + ); - let resp = Client::new().get(Uri::from_static("http://localhost:54321/notfound")).await.unwrap(); + let resp = Client::builder(TokioExecutor::new()) + .build::<_, BoxBody>(HttpConnector::new()) + .get(Uri::from_static("http://localhost:54321/notfound")) + .await + .unwrap(); assert_eq!(404, resp.status()); } @@ -30,19 +53,21 @@ async fn test_get_respond_404(&mut ctx: PortContext) { #[allow(dead_code)] struct PortContext { port_string: String, - http_context: HttpTestContext + http_context: HttpTestContext, } -#[async_trait::async_trait] impl AsyncTestContext for PortContext { async fn setup() -> PortContext { let port_string = "54321".to_string(); env::set_var("TOKIOTEST_HTTP_PORT", port_string.clone()); - PortContext { port_string, http_context: HttpTestContext::setup().await } + PortContext { + port_string, + http_context: HttpTestContext::setup().await, + } } async fn teardown(self) { - let _ = self.http_context.teardown(); + self.http_context.teardown().await; env::remove_var("TOKIOTEST_HTTP_PORT"); } -} \ No newline at end of file +}