diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs index c9c7a57d..0b987342 100644 --- a/crates/server/src/lib.rs +++ b/crates/server/src/lib.rs @@ -12,6 +12,8 @@ use std::{ sync::Arc, time::Duration, }; +use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; use tower_http::{ trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, LatencyUnit, @@ -30,7 +32,7 @@ const DEFAULT_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(5); /// The server configuration. pub struct Config { - operator_key: PrivateKey, + operator_key: Option, addr: Option, data_store: Option>, content_dir: Option, @@ -58,7 +60,7 @@ impl Config { /// Creates a new server configuration. pub fn new(operator_key: PrivateKey) -> Self { Self { - operator_key, + operator_key: Some(operator_key), addr: None, data_store: None, content_dir: None, @@ -120,7 +122,16 @@ impl Config { /// Represents the warg registry server. pub struct Server { config: Config, - listener: Option, + endpoints: Option, + token: CancellationToken, + tasks: JoinSet>, +} + +/// The bound endpoints for the warg registry server. +#[derive(Clone)] +pub struct Endpoints { + /// The address of the API endpoint. + pub api: SocketAddr, } impl Server { @@ -128,52 +139,51 @@ impl Server { pub fn new(config: Config) -> Self { Self { config, - listener: None, + token: CancellationToken::new(), + endpoints: None, + tasks: JoinSet::new(), } } - /// Binds the server to the configured address. + /// Starts the server and binds its endpoints to the configured addresses. /// - /// Returns the address the server bound to. - pub fn bind(&mut self) -> Result { + /// Returns the endpoints the server bound to. + pub async fn start(&mut self) -> Result { + assert!( + self.endpoints.is_none(), + "cannot start server multiple times" + ); + + tracing::debug!( + "using server configuration: {config:?}", + config = self.config + ); + let addr = self .config .addr + .to_owned() .unwrap_or_else(|| DEFAULT_BIND_ADDRESS.parse().unwrap()); - tracing::debug!("binding server to address `{addr}`"); + tracing::debug!("binding api endpoint to address `{addr}`"); let listener = TcpListener::bind(addr) - .with_context(|| format!("failed to bind to address `{addr}`"))?; + .with_context(|| format!("failed to bind api endpoint to address `{addr}`"))?; let addr = listener .local_addr() - .context("failed to get local address for listen socket")?; - - tracing::debug!("server bound to address `{addr}`"); - self.config.addr = Some(addr); - self.listener = Some(listener); - Ok(addr) - } - - /// Runs the server. - pub async fn run(mut self) -> Result<()> { - if self.listener.is_none() { - self.bind()?; - } - - let listener = self.listener.unwrap(); + .context("failed to get local address for api endpoint listen socket")?; + tracing::debug!("api endpoint bound to address `{addr}`"); - tracing::debug!( - "using server configuration: {config:?}", - config = self.config - ); + let endpoints = Endpoints { api: addr }; + self.endpoints = Some(endpoints.clone()); let store = self .config .data_store + .take() .unwrap_or_else(|| Box::::default()); let (core, handle) = CoreService::spawn( - self.config.operator_key, + self.config.operator_key.take().unwrap(), store, self.config .checkpoint_interval @@ -183,32 +193,69 @@ impl Server { let server = axum::Server::from_tcp(listener)?.serve( Self::create_router( - format!("http://{addr}", addr = self.config.addr.unwrap()), - self.config.content_dir, + format!("http://{addr}", addr = endpoints.api), + self.config.content_dir.take(), core, )? .into_make_service(), ); - tracing::info!("listening on {addr}", addr = self.config.addr.unwrap()); + tracing::info!("api endpoint listening on {addr}", addr = endpoints.api); - if let Some(shutdown) = self.config.shutdown { - tracing::debug!("server is running with a shutdown signal"); + // Shut down core service when token cancelled. + let token = self.token.clone(); + self.tasks.spawn(async move { + token.cancelled().await; + tracing::info!("waiting for core service to stop"); + handle.stop().await; + Ok(()) + }); + + // Shut down server when token cancelled. + let token: CancellationToken = self.token.clone(); + self.tasks.spawn(async move { + tracing::info!("waiting for api endpoint to stop"); server - .with_graceful_shutdown(async move { shutdown.await }) + .with_graceful_shutdown(async move { token.cancelled().await }) .await?; + Ok(()) + }); + + // Cancel token if shutdown signal received. + if let Some(shutdown) = self.config.shutdown.take() { + tracing::debug!("server is running with a shutdown signal"); + let token = self.token.clone(); + tokio::spawn(async move { + tracing::info!("waiting for shutdown signal"); + shutdown.await; + tracing::info!("shutting down server"); + token.cancel(); + }); } else { tracing::debug!("server is running without a shutdown signal"); - server.await?; } - tracing::info!("waiting for core service to stop"); - handle.stop().await; + Ok(endpoints) + } + + /// Waits on a started server to shutdown. + pub async fn join(&mut self) -> Result<()> { + while (self.tasks.join_next().await).is_some() {} tracing::info!("server shutdown complete"); + Ok(()) + } + /// Starts the server and waits for completion. + pub async fn run(&mut self) -> Result<()> { + self.start().await?; + self.join().await?; Ok(()) } + pub fn stop(&mut self) { + self.token.cancel(); + } + fn create_router( base_url: String, content_dir: Option, diff --git a/tests/server.rs b/tests/server.rs index 505034c2..4cf735ad 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -4,7 +4,6 @@ use std::{fs, str::FromStr}; use warg_client::{api, Config, FileSystemClient, StorageLockResult}; use warg_crypto::{signing::PrivateKey, Encode, Signable}; use wit_component::DecodedWasm; - mod support; #[cfg(feature = "postgres")] diff --git a/tests/support/mod.rs b/tests/support/mod.rs index db397407..08192772 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -82,10 +82,10 @@ pub async fn spawn_server( } let mut server = Server::new(config); - let addr = server.bind()?; + let endpoints = server.start().await?; let task = tokio::spawn(async move { - server.run().await.unwrap(); + server.join().await.unwrap(); }); let instance = ServerInstance { @@ -94,7 +94,7 @@ pub async fn spawn_server( }; let config = warg_client::Config { - default_url: Some(format!("http://{addr}")), + default_url: Some(format!("http://{addr}", addr = endpoints.api)), registries_dir: Some(root.join("registries")), content_dir: Some(root.join("content")), };