diff --git a/crates/server/src/bin/warg-server.rs b/crates/server/src/bin/warg-server.rs
index 8c60ffe7..eac3ed02 100644
--- a/crates/server/src/bin/warg-server.rs
+++ b/crates/server/src/bin/warg-server.rs
@@ -4,7 +4,7 @@ use std::{net::SocketAddr, path::PathBuf};
 use tokio::signal;
 use tracing_subscriber::filter::LevelFilter;
 use warg_crypto::signing::PrivateKey;
-use warg_server::{args::get_opt_content, Config, Server};
+use warg_server::{args::get_opt_content, monitoring::MonitoringKind, Config, Server};
 
 #[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq, Default)]
 enum DataStoreKind {
@@ -32,6 +32,31 @@ struct Args {
     #[arg(long, env = "DATA_STORE", default_value = "memory")]
     data_store: DataStoreKind,
 
+    /// The amount of time to continue processing client requests after receiving an external
+    /// signal to terminate, e.g., `SIGTERM`.
+    ///
+    /// Use this option along with `HealthChecks` monitoring to support external draining of
+    /// clients by a load balancer.
+    #[arg(long, env = "GRACEFUL_SHUTDOWN_DURATION_SECONDS")]
+    graceful_shutdown_duration_seconds: Option<u64>,
+
+    /// The monitoring features to enable for the server.
+    #[arg(
+        long,
+        env = "MONITORING_ENABLED",
+        default_value = "",
+        use_value_delimiter = true,
+        value_delimiter = ','
+    )]
+    monitoring_enabled: Vec<MonitoringKind>,
+
+    /// Optional separate address to listen on for monitoring if monitoring is enabled.
+    ///
+    /// Use this option to bind monitoring API routes to another port to avoid setting up firewall
+    /// rules to protect them from regular API traffic.
+    #[arg(short, long, env = "MONITORING_LISTEN")]
+    monitoring_listen: Option<SocketAddr>,
+
     /// The database connection URL if data-store is set to postgres.
     ///
     /// Prefer using database-url-file, or environment variable variation,
@@ -83,6 +108,8 @@ async fn main() -> Result<()> {
 
     let mut config = Config::new(operator_key)
         .with_addr(args.listen)
+        .with_monitoring_enabled(args.monitoring_enabled)
+        .with_monitoring_addr(args.monitoring_listen)
         .with_shutdown(shutdown_signal());
 
     if let Some(content_dir) = args.content_dir {
@@ -103,7 +130,9 @@ async fn main() -> Result<()> {
         }
     }
 
-    Server::new(config).run().await
+    let mut server = Server::new(config);
+    server.start().await?;
+    server.join().await
 }
 
 /// Returns the operator key from the supplied `args` or panics.
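For reference, the new flags correspond to the `Config` builder methods added in `lib.rs` below. A minimal sketch of the equivalent programmatic setup, assuming clap's default kebab-case `ValueEnum` spelling (`--monitoring-enabled health-checks`) and illustrative addresses:

```rust
use std::net::SocketAddr;
use warg_crypto::signing::PrivateKey;
use warg_server::{monitoring::MonitoringKind, Config};

// Roughly what `main` above ends up building when invoked as, e.g.:
//   warg-server --monitoring-enabled health-checks --monitoring-listen 127.0.0.1:9090
// (flag spelling assumed from clap's default kebab-case ValueEnum names).
fn example_config(operator_key: PrivateKey) -> Config {
    let api: SocketAddr = "127.0.0.1:8090".parse().unwrap();
    let monitoring: SocketAddr = "127.0.0.1:9090".parse().unwrap();

    Config::new(operator_key)
        .with_addr(api)
        .with_monitoring_enabled(vec![MonitoringKind::HealthChecks])
        .with_monitoring_addr(monitoring)
}
```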
diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs
index c9c7a57d..a962bd92 100644
--- a/crates/server/src/lib.rs
+++ b/crates/server/src/lib.rs
@@ -3,7 +3,8 @@ use anyhow::{Context, Result};
 use axum::{body::Body, http::Request, Router};
 use datastore::DataStore;
 use futures::Future;
-use services::CoreService;
+use monitoring::{LifecycleManager, MonitoringKind};
+use services::{CoreService, StopHandle};
 use std::{
     fs,
     net::{SocketAddr, TcpListener},
@@ -22,6 +23,7 @@ use warg_crypto::signing::PrivateKey;
 pub mod api;
 pub mod args;
 pub mod datastore;
+pub mod monitoring;
 mod policy;
 pub mod services;
 
@@ -30,8 +32,10 @@ const DEFAULT_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(5);
 
 /// The server configuration.
 pub struct Config {
-    operator_key: PrivateKey,
+    operator_key: Option<PrivateKey>,
     addr: Option<SocketAddr>,
+    monitoring_enabled: Option<Vec<MonitoringKind>>,
+    monitoring_addr: Option<SocketAddr>,
     data_store: Option<Box<dyn DataStore>>,
     content_dir: Option<PathBuf>,
     shutdown: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
@@ -47,6 +51,8 @@ impl std::fmt::Debug for Config {
                 "data_store",
                 &self.shutdown.as_ref().map(|_| "dyn DataStore"),
             )
+            .field("monitoring_enabled", &self.monitoring_enabled)
+            .field("monitoring_addr", &self.monitoring_addr)
             .field("content", &self.content_dir)
             .field("shutdown", &self.shutdown.as_ref().map(|_| "dyn Future"))
             .field("checkpoint_interval", &self.checkpoint_interval)
@@ -58,8 +64,10 @@ impl Config {
     /// Creates a new server configuration.
     pub fn new(operator_key: PrivateKey) -> Self {
         Self {
-            operator_key,
+            operator_key: Some(operator_key),
             addr: None,
+            monitoring_enabled: None,
+            monitoring_addr: None,
             data_store: None,
             content_dir: None,
             shutdown: None,
@@ -73,6 +81,18 @@ impl Config {
         self
     }
 
+    /// Specify which monitoring features to enable.
+    pub fn with_monitoring_enabled(mut self, monitoring_enabled: Vec<MonitoringKind>) -> Self {
+        self.monitoring_enabled = Some(monitoring_enabled);
+        self
+    }
+
+    /// Specify the address for the monitoring endpoints to listen on.
+    pub fn with_monitoring_addr(mut self, monitoring_addr: impl Into<Option<SocketAddr>>) -> Self {
+        self.monitoring_addr = monitoring_addr.into();
+        self
+    }
+
     /// Specify the data store to use.
     ///
     /// If this is not specified, the server will use an in-memory data store.
@@ -117,10 +137,16 @@ impl Config {
     }
 }
 
+pub struct Endpoints {
+    pub api: SocketAddr,
+    pub monitoring: Option<SocketAddr>,
+}
+
 /// Represents the warg registry server.
 pub struct Server {
     config: Config,
-    listener: Option<TcpListener>,
+    lifecycle: Arc<LifecycleManager>,
+    stop_handle: Option<StopHandle>,
 }
 
 impl Server {
@@ -128,84 +154,123 @@ impl Server {
     pub fn new(config: Config) -> Self {
         Self {
             config,
-            listener: None,
+            lifecycle: Arc::new(LifecycleManager::new(monitoring::Config {
+                shutdown_grace_period: Some(Duration::from_secs(5)),
+            })),
+            stop_handle: None,
         }
     }
 
-    /// Binds the server to the configured address.
-    ///
-    /// Returns the address the server bound to.
-    pub fn bind(&mut self) -> Result<SocketAddr> {
-        let addr = self
+    /// Starts the server.
+    pub async fn start(&mut self) -> Result<Endpoints> {
+        tracing::debug!(
+            "using server configuration: {config:?}",
+            config = self.config
+        );
+
+        let api_addr = self
             .config
             .addr
-            .unwrap_or_else(|| DEFAULT_BIND_ADDRESS.parse().unwrap());
-
-        tracing::debug!("binding server to address `{addr}`");
-        let listener = TcpListener::bind(addr)
-            .with_context(|| format!("failed to bind to address `{addr}`"))?;
+            .unwrap_or_else(|| DEFAULT_BIND_ADDRESS.parse().unwrap())
+            .to_owned();
+        let api_listener = TcpListener::bind(api_addr)
+            .with_context(|| format!("failed to bind to address `{api_addr}`"))?;
+        let local_addr = api_listener.local_addr().unwrap();
 
-        let addr = listener
-            .local_addr()
-            .context("failed to get local address for listen socket")?;
-
-        tracing::debug!("server bound to address `{addr}`");
-        self.config.addr = Some(addr);
-        self.listener = Some(listener);
-        Ok(addr)
-    }
+        let health_checks_enabled = self
+            .config
+            .monitoring_enabled
+            .as_ref()
+            .map(|kinds| kinds.contains(&MonitoringKind::HealthChecks))
+            .unwrap_or(false);
 
-    /// Runs the server.
-    pub async fn run(mut self) -> Result<()> {
-        if self.listener.is_none() {
-            self.bind()?;
+        let mut local_monitoring_addr: Option<SocketAddr> = None;
+        if health_checks_enabled {
+            if let Some(monitoring_addr) = self.config.monitoring_addr.as_ref() {
+                let monitoring_listener = TcpListener::bind(monitoring_addr.to_owned())
+                    .with_context(|| {
+                        format!("failed to bind health_checks to address `{monitoring_addr}`")
+                    })
+                    .unwrap();
+                local_monitoring_addr = Some(monitoring_listener.local_addr().unwrap());
+                let monitoring_server = axum::Server::from_tcp(monitoring_listener)
+                    .unwrap()
+                    .serve(self.lifecycle.health_checks_router().into_make_service());
+                tokio::spawn(async move {
+                    tracing::info!(
+                        "monitoring server on {addr}",
+                        addr = local_monitoring_addr.unwrap()
+                    );
+                    _ = monitoring_server.await;
+                    tracing::info!("monitoring server shut down");
+                });
+            }
         }
 
-        let listener = self.listener.unwrap();
-
-        tracing::debug!(
-            "using server configuration: {config:?}",
-            config = self.config
-        );
-
         let store = self
             .config
             .data_store
+            .take()
             .unwrap_or_else(|| Box::<datastore::MemoryDataStore>::default());
 
         let (core, handle) = CoreService::spawn(
-            self.config.operator_key,
+            self.config.operator_key.take().unwrap(),
             store,
             self.config
                 .checkpoint_interval
                 .unwrap_or(DEFAULT_CHECKPOINT_INTERVAL),
         )
         .await?;
+        self.stop_handle = Some(handle);
 
-        let server = axum::Server::from_tcp(listener)?.serve(
-            Self::create_router(
-                format!("http://{addr}", addr = self.config.addr.unwrap()),
-                self.config.content_dir,
-                core,
-            )?
-            .into_make_service(),
-        );
+        let mut api_router = Router::new().merge(Self::create_router(
+            format!("http://{addr}", addr = local_addr),
+            self.config.content_dir.take(),
+            core,
+        )?);
+
+        if health_checks_enabled && local_monitoring_addr.is_none() {
+            api_router = api_router.merge(self.lifecycle.health_checks_router());
+        }
 
-        tracing::info!("listening on {addr}", addr = self.config.addr.unwrap());
+        let api_lifecycle = self.lifecycle.clone();
+        let api_server = axum::Server::from_tcp(api_listener)
+            .unwrap()
+            .serve(api_router.into_make_service())
+            .with_graceful_shutdown(async move {
+                api_lifecycle.drain_signal().await;
+            });
+        tokio::spawn(async move {
+            tracing::info!("server listening on {local_addr}");
+            _ = api_server.await;
+            tracing::info!("server shut down");
+        });
 
-        if let Some(shutdown) = self.config.shutdown {
+        // NOTE: If warmup is needed, set live first, do the warmup, and then set ready.
+        self.lifecycle.set_ready().await?;
+
+        // Set up the shutdown sequence whether or not it will be graceful.
+        if let Some(shutdown) = self.config.shutdown.take() {
             tracing::debug!("server is running with a shutdown signal");
-            server
-                .with_graceful_shutdown(async move { shutdown.await })
-                .await?;
+            let lifecycle = self.lifecycle.clone();
+            tokio::spawn(async move {
+                shutdown.await;
+                lifecycle.shutdown().await.unwrap();
+            });
         } else {
             tracing::debug!("server is running without a shutdown signal");
-            server.await?;
         }
 
+        Ok(Endpoints {
+            api: local_addr,
+            monitoring: local_monitoring_addr,
+        })
+    }
+
+    pub async fn join(&mut self) -> Result<()> {
+        self.lifecycle.terminate_signal().await;
         tracing::info!("waiting for core service to stop");
-        handle.stop().await;
+        self.stop_handle.take().unwrap().stop().await;
         tracing::info!("server shutdown complete");
-
         Ok(())
     }
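With `run` replaced by `start` and `join`, embedding the server follows a two-phase pattern: `start` binds the listeners and returns the resolved `Endpoints`, and `join` waits for the shutdown sequence (including the core service) to finish. A hedged sketch, assuming `config` is built as in the binary above:

```rust
use anyhow::Result;
use warg_server::{Config, Server};

// Sketch only: mirrors what the binary's `main` and the test support code now do.
async fn run_embedded(config: Config) -> Result<()> {
    let mut server = Server::new(config);
    let endpoints = server.start().await?;

    println!("API bound to {}", endpoints.api);
    if let Some(monitoring) = endpoints.monitoring {
        println!("monitoring bound to {monitoring}");
    }

    // Blocks until the lifecycle reaches the terminating stage and the core service stops.
    server.join().await
}
```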
diff --git a/crates/server/src/monitoring.rs b/crates/server/src/monitoring.rs
new file mode 100644
index 00000000..80766034
--- /dev/null
+++ b/crates/server/src/monitoring.rs
@@ -0,0 +1,199 @@
+//! # Monitoring
+//!
+//! The `monitoring` module collects utilities to set up and support monitoring use cases such as:
+//!
+//! * health checks, e.g., `/livez` (the server is healthy but possibly not serving) and `/readyz`
+//!   (the server is healthy and serving)
+//! * a shutdown grace period, i.e., the time a load balancer needs to recognize that it should
+//!   pull an instance from its service pool once `/readyz` returns an error
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use axum::{body::Body, extract::State, response::IntoResponse, routing::get, Router};
+use clap::ValueEnum;
+use reqwest::StatusCode;
+use tokio::sync::{broadcast::Sender, Mutex};
+
+#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq)]
+pub enum MonitoringKind {
+    HealthChecks,
+    // TODO: Support metrics via at least one of OpenTelemetry Metrics, Prometheus, etc.
+    // Metrics,
+    // TODO: Support tracing via at least one of OpenTelemetry Tracing, Jaeger, etc.
+    // Tracing,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+enum Stage {
+    NotLive,
+    Live,
+    Ready,
+    ShuttingDown,
+    Terminating,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum Event {
+    ShuttingDown,
+    Terminating,
+}
+
+struct LifecycleState {
+    stage: Stage,
+}
+
+struct Lifecycle {
+    tx: Sender<Event>,
+    state: Mutex<LifecycleState>,
+}
+
+pub struct LifecycleManager {
+    lifecycle: Arc<Lifecycle>,
+    shutdown_grace_period: Option<std::time::Duration>,
+}
+
+pub struct Config {
+    pub shutdown_grace_period: Option<std::time::Duration>,
+}
+
+impl LifecycleManager {
+    pub fn new(config: Config) -> Self {
+        let (tx, _) = tokio::sync::broadcast::channel::<Event>(1);
+        LifecycleManager {
+            lifecycle: Arc::new(Lifecycle {
+                tx,
+                state: Mutex::new(LifecycleState {
+                    stage: Stage::NotLive,
+                }),
+            }),
+            shutdown_grace_period: config.shutdown_grace_period,
+        }
+    }
+
+    pub fn has_graceful_shutdown(&self) -> bool {
+        self.shutdown_grace_period.is_some()
+    }
+
+    #[allow(dead_code)]
+    pub async fn set_live(&self) -> Result<()> {
+        let mut state = self.lifecycle.state.lock().await;
+        match state.stage {
+            Stage::ShuttingDown => Ok(()),
+            _ => {
+                state.stage = Stage::Live;
+                Ok(())
+            }
+        }
+    }
+
+    pub async fn set_ready(&self) -> Result<()> {
+        let mut state = self.lifecycle.state.lock().await;
+        match state.stage {
+            Stage::ShuttingDown => Ok(()),
+            _ => {
+                state.stage = Stage::Ready;
+                Ok(())
+            }
+        }
+    }
+
+    /// Initiates the lifecycle shutdown sequence.
+    ///
+    /// A `ShuttingDown` event will be sent, followed by `Terminating` after the optionally
+    /// configured shutdown grace period.
+    pub async fn shutdown(&self) -> Result<()> {
+        let mut state = self.lifecycle.state.lock().await;
+        match state.stage {
+            Stage::ShuttingDown => Ok(()),
+            _ => {
+                tracing::debug!("shutting down");
+                state.stage = Stage::ShuttingDown;
+                self.lifecycle.tx.send(Event::ShuttingDown).map(|_| ())?;
+
+                if let Some(shutdown_grace_period) = self.shutdown_grace_period {
+                    tracing::info!(
+                        "shutting down with grace period {:?}",
+                        shutdown_grace_period
+                    );
+                    let lifecycle = self.lifecycle.clone();
+                    tokio::spawn(async move {
+                        tokio::time::sleep(shutdown_grace_period).await;
+                        let mut state = lifecycle.state.lock().await;
+                        tracing::info!("terminating");
+                        state.stage = Stage::Terminating;
+                        lifecycle.tx.send(Event::Terminating).unwrap();
+                    });
+                } else {
+                    tracing::info!("shutting down without grace period");
+                    tracing::info!("terminating");
+                    state.stage = Stage::Terminating;
+                    self.lifecycle.tx.send(Event::Terminating).map(|_| ())?;
+                }
+
+                Ok(())
+            }
+        }
+    }
+
+    /// Completes when services should immediately drain clients.
+    pub async fn drain_signal(&self) {
+        let event = if self.has_graceful_shutdown() {
+            Event::Terminating
+        } else {
+            Event::ShuttingDown
+        };
+        self.signal(event).await;
+    }
+
+    /// Completes when the shutdown event occurs or the lifecycle broadcast channel is closed.
+    pub async fn shutdown_signal(&self) {
+        self.signal(Event::ShuttingDown).await
+    }
+
+    /// Completes when the termination event occurs or the lifecycle broadcast channel is closed.
+    pub async fn terminate_signal(&self) {
+        self.signal(Event::Terminating).await
+    }
+
+    async fn signal(&self, event: Event) {
+        let mut server_rx = self.lifecycle.tx.subscribe();
+        loop {
+            match server_rx.recv().await {
+                Ok(e) => match e {
+                    e if e == event => return,
+                    _ => continue,
+                },
+                Err(s) => match s {
+                    tokio::sync::broadcast::error::RecvError::Closed => return,
+                    tokio::sync::broadcast::error::RecvError::Lagged(_) => continue,
+                },
+            }
+        }
+    }
+
+    pub fn health_checks_router(&self) -> Router<(), Body> {
+        axum::Router::new()
+            .route("/livez", get(livez))
+            .route("/readyz", get(readyz))
+            .with_state(self.lifecycle.clone())
+    }
+}
+
+async fn livez(State(lifecycle): State<Arc<Lifecycle>>) -> impl IntoResponse {
+    // TODO: Support a verbose option for human-readable details.
+    let state = lifecycle.state.lock().await;
+    match state.stage {
+        Stage::Live | Stage::Ready => StatusCode::OK,
+        _ => StatusCode::SERVICE_UNAVAILABLE,
+    }
+}
+
+async fn readyz(State(lifecycle): State<Arc<Lifecycle>>) -> impl IntoResponse {
+    // TODO: Support a verbose option for human-readable details.
+    let state = lifecycle.state.lock().await;
+    match state.stage {
+        Stage::Ready => StatusCode::OK,
+        _ => StatusCode::SERVICE_UNAVAILABLE,
+    }
+}
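The lifecycle manager can also be driven directly, which is essentially what `Server::start` does internally. A sketch of the intended stage progression, including the warmup pattern mentioned in the `NOTE` comment in `lib.rs`; the two-second grace period is an illustrative value:

```rust
use std::time::Duration;
use warg_server::monitoring::{Config as MonitoringConfig, LifecycleManager};

// Sketch of the intended stage progression:
// NotLive -> Live -> Ready -> ShuttingDown -> Terminating.
async fn lifecycle_walkthrough() -> anyhow::Result<()> {
    let lifecycle = LifecycleManager::new(MonitoringConfig {
        shutdown_grace_period: Some(Duration::from_secs(2)),
    });

    lifecycle.set_live().await?;  // `/livez` would now return 200; `/readyz` still 503
    // ... perform warmup here (cache priming, migrations, etc.) ...
    lifecycle.set_ready().await?; // `/readyz` would now return 200

    lifecycle.shutdown().await?;        // broadcasts ShuttingDown; `/readyz` flips to 503
    lifecycle.terminate_signal().await; // resolves once the grace period has elapsed
    Ok(())
}
```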
diff --git a/crates/server/src/services/mod.rs b/crates/server/src/services/mod.rs
index 4deda2fa..668c60e5 100644
--- a/crates/server/src/services/mod.rs
+++ b/crates/server/src/services/mod.rs
@@ -2,5 +2,5 @@ mod core;
 mod data;
 mod transparency;
 
-pub use self::core::{CoreService, CoreServiceError};
+pub use self::core::{CoreService, CoreServiceError, StopHandle};
 pub use self::data::{log::LogData, map::MapData, DataServiceError};
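`StopHandle` is re-exported so that code outside the `services` module (here, `Server::join`) can keep the handle returned by `CoreService::spawn` and stop the core service once draining has finished. A hedged sketch of that pattern in isolation; the surrounding function, the five-second interval, and taking the data store as an argument are all illustrative assumptions:

```rust
use std::time::Duration;
use warg_crypto::signing::PrivateKey;
use warg_server::datastore::DataStore;
use warg_server::services::{CoreService, StopHandle};

// Sketch of the spawn/stop pattern `Server` now uses: hold the handle and stop
// the core service only after request handling has finished draining.
async fn core_lifecycle(key: PrivateKey, store: Box<dyn DataStore>) -> anyhow::Result<()> {
    let (_core, handle): (_, StopHandle) =
        CoreService::spawn(key, store, Duration::from_secs(5)).await?;

    // ... serve requests until shutdown and drain complete ...

    handle.stop().await;
    Ok(())
}
```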
diff --git a/tests/support/mod.rs b/tests/support/mod.rs
index db397407..08192772 100644
--- a/tests/support/mod.rs
+++ b/tests/support/mod.rs
@@ -82,10 +82,10 @@ pub async fn spawn_server(
     }
 
     let mut server = Server::new(config);
-    let addr = server.bind()?;
+    let endpoints = server.start().await?;
 
     let task = tokio::spawn(async move {
-        server.run().await.unwrap();
+        server.join().await.unwrap();
     });
 
     let instance = ServerInstance {
@@ -94,7 +94,7 @@ pub async fn spawn_server(
     };
 
     let config = warg_client::Config {
-        default_url: Some(format!("http://{addr}")),
+        default_url: Some(format!("http://{addr}", addr = endpoints.api)),
         registries_dir: Some(root.join("registries")),
         content_dir: Some(root.join("content")),
     };
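Finally, a hedged end-to-end sketch in the style of the test support above, showing the behavior this change targets: trigger the configured shutdown future and observe `/readyz` flip from 200 to 503 while the API listener keeps draining through the grace period. The oneshot trigger, the timings, and the helper name are illustrative assumptions, not part of the diff:

```rust
use std::time::Duration;
use anyhow::Result;
use reqwest::StatusCode;
use tokio::sync::oneshot;
use warg_crypto::signing::PrivateKey;
use warg_server::{monitoring::MonitoringKind, Config, Server};

// Hypothetical test sketch; `operator_key` would come from the existing test fixtures.
async fn readyz_flips_while_draining(operator_key: PrivateKey) -> Result<()> {
    let (trigger, shutdown) = oneshot::channel::<()>();

    let config = Config::new(operator_key)
        .with_addr("127.0.0.1:0".parse().unwrap())
        .with_monitoring_enabled(vec![MonitoringKind::HealthChecks])
        // No monitoring address, so the health check routes are merged into the API router.
        .with_shutdown(async move {
            let _ = shutdown.await;
        });

    let mut server = Server::new(config);
    let endpoints = server.start().await?;
    let readyz = format!("http://{}/readyz", endpoints.api);

    assert_eq!(reqwest::get(readyz.as_str()).await?.status(), StatusCode::OK);

    // Ask for shutdown; during the grace period the instance keeps serving but
    // reports not-ready so a load balancer can pull it from rotation.
    let _ = trigger.send(());
    tokio::time::sleep(Duration::from_millis(200)).await;
    assert_eq!(
        reqwest::get(readyz.as_str()).await?.status(),
        StatusCode::SERVICE_UNAVAILABLE
    );

    server.join().await
}
```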