Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions crates/server/src/bin/warg-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{net::SocketAddr, path::PathBuf};
use tokio::signal;
use tracing_subscriber::filter::LevelFilter;
use warg_crypto::signing::PrivateKey;
use warg_server::{args::get_opt_content, Config, Server};
use warg_server::{args::get_opt_content, monitoring::MonitoringKind, Config, Server};

#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq, Default)]
enum DataStoreKind {
Expand Down Expand Up @@ -32,6 +32,31 @@ struct Args {
#[arg(long, env = "DATA_STORE", default_value = "memory")]
data_store: DataStoreKind,

/// The amount of time to continue processing client requests after receiving an external
/// signal to terminate, e.g., `SIGTERM`.
///
/// Use this option along with `HealthChecks` monitoring to support external draining of
/// clients by a load balancer.
#[arg(long, env = "GRACEFUL_SHUTDOWN_DURATION_SECONDS")]
graceful_shutdown_duration_seconds: Option<u16>,

/// The kinds of monitoring to enable for the server, as a comma-separated list.
#[arg(
long,
env = "MONITORING_ENABLED",
default_value = "",
use_value_delimiter = true,
value_delimiter = ','
)]
monitoring_enabled: Vec<MonitoringKind>,

/// Optional separate address to listen for monitoring if monitoring is enabled.
///
/// Use this option to bind monitoring API routes to another port to avoid setting up firewall
/// rules to protect them from regular API traffic.
#[arg(short, long, env = "MONITORING_LISTEN")]
monitoring_listen: Option<SocketAddr>,

/// The database connection URL if data-store is set to postgres.
///
/// Prefer using database-url-file, or environment variable variation,
Expand Down Expand Up @@ -83,6 +108,8 @@ async fn main() -> Result<()> {

let mut config = Config::new(operator_key)
.with_addr(args.listen)
.with_monitoring_enabled(args.monitoring_enabled)
.with_monitoring_addr(args.monitoring_listen)
.with_shutdown(shutdown_signal());

if let Some(content_dir) = args.content_dir {
Expand All @@ -103,7 +130,9 @@ async fn main() -> Result<()> {
}
}

Server::new(config).run().await
let mut server = Server::new(config);
server.start().await?;
server.join().await
}

/// Returns the operator key from the supplied `args` or panics.
Expand Down
169 changes: 117 additions & 52 deletions crates/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use anyhow::{Context, Result};
use axum::{body::Body, http::Request, Router};
use datastore::DataStore;
use futures::Future;
use services::CoreService;
use monitoring::{LifecycleManager, MonitoringKind};
use services::{CoreService, StopHandle};
use std::{
fs,
net::{SocketAddr, TcpListener},
Expand All @@ -22,6 +23,7 @@ use warg_crypto::signing::PrivateKey;
pub mod api;
pub mod args;
pub mod datastore;
pub mod monitoring;
mod policy;
pub mod services;

Expand All @@ -30,8 +32,10 @@ const DEFAULT_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(5);

/// The server configuration.
pub struct Config {
operator_key: PrivateKey,
operator_key: Option<PrivateKey>,
addr: Option<SocketAddr>,
monitoring_enabled: Option<Vec<MonitoringKind>>,
monitoring_addr: Option<SocketAddr>,
data_store: Option<Box<dyn datastore::DataStore>>,
content_dir: Option<PathBuf>,
shutdown: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
Expand All @@ -47,6 +51,8 @@ impl std::fmt::Debug for Config {
"data_store",
&self.shutdown.as_ref().map(|_| "dyn DataStore"),
)
.field("monitoring_enabled", &self.monitoring_enabled)
.field("monitoring_addr", &self.monitoring_addr)
.field("content", &self.content_dir)
.field("shutdown", &self.shutdown.as_ref().map(|_| "dyn Future"))
.field("checkpoint_interval", &self.checkpoint_interval)
Expand All @@ -58,8 +64,10 @@ impl Config {
/// Creates a new server configuration.
pub fn new(operator_key: PrivateKey) -> Self {
Self {
operator_key,
operator_key: Some(operator_key),
addr: None,
monitoring_enabled: None,
monitoring_addr: None,
data_store: None,
content_dir: None,
shutdown: None,
Expand All @@ -73,6 +81,18 @@ impl Config {
self
}

/// Specify the kinds of monitoring to enable for the server.
pub fn with_monitoring_enabled(mut self, monitoring_enabled: Vec<MonitoringKind>) -> Self {
self.monitoring_enabled = Some(monitoring_enabled);
self
}

/// Specify a separate address for the monitoring routes to listen on.
///
/// Accepts either a `SocketAddr` or an `Option<SocketAddr>`; passing `None` leaves the
/// monitoring address unset.
pub fn with_monitoring_addr(mut self, monitoring_addr: impl Into<Option<SocketAddr>>) -> Self {
self.monitoring_addr = monitoring_addr.into();
self
}

/// Specify the data store to use.
///
/// If this is not specified, the server will use an in-memory data store.
Expand Down Expand Up @@ -117,95 +137,140 @@ impl Config {
}
}

/// The local addresses the server is listening on.
pub struct Endpoints {
/// The local address of the API listener.
pub api: SocketAddr,
/// The local address of the separate monitoring listener, or `None` if one was not bound.
pub monitoring: Option<SocketAddr>,
}

/// Represents the warg registry server.
pub struct Server {
/// The server configuration.
config: Config,
/// The bound listening socket, if the server has been bound.
listener: Option<TcpListener>,
/// Shared lifecycle manager that provides the health check router and the
/// drain/terminate signals.
lifecycle: Arc<LifecycleManager>,
/// Handle used to stop the core service once the server terminates.
stop_handle: Option<StopHandle>,
}

impl Server {
/// Creates a new server with the given configuration.
///
/// The lifecycle manager is created here with a fixed shutdown grace period of 5 seconds.
pub fn new(config: Config) -> Self {
Self {
config,
listener: None,
lifecycle: Arc::new(LifecycleManager::new(monitoring::Config {
// NOTE(review): hard-coded; not taken from `config` — confirm whether this should be configurable.
shutdown_grace_period: Some(Duration::from_secs(5)),
})),
stop_handle: None,
}
}

/// Binds the server to the configured address.
///
/// Returns the address the server bound to.
pub fn bind(&mut self) -> Result<SocketAddr> {
let addr = self
/// Starts the server.
pub async fn start(&mut self) -> Result<Endpoints> {
tracing::debug!(
"using server configuration: {config:?}",
config = self.config
);

let api_addr = self
.config
.addr
.unwrap_or_else(|| DEFAULT_BIND_ADDRESS.parse().unwrap());

tracing::debug!("binding server to address `{addr}`");
let listener = TcpListener::bind(addr)
.with_context(|| format!("failed to bind to address `{addr}`"))?;
.unwrap_or_else(|| DEFAULT_BIND_ADDRESS.parse().unwrap())
.to_owned();
let api_listener = TcpListener::bind(api_addr)
.with_context(|| format!("failed to bind to address `{api_addr}`"))?;
let local_addr = api_listener.local_addr().unwrap();

let addr = listener
.local_addr()
.context("failed to get local address for listen socket")?;

tracing::debug!("server bound to address `{addr}`");
self.config.addr = Some(addr);
self.listener = Some(listener);
Ok(addr)
}
let health_checks_enabled = self
.config
.monitoring_enabled
.as_ref()
.map(|kinds| kinds.contains(&MonitoringKind::HealthChecks))
.unwrap_or(false);

/// Runs the server.
pub async fn run(mut self) -> Result<()> {
if self.listener.is_none() {
self.bind()?;
let mut local_monitoring_addr: Option<SocketAddr> = None;
if health_checks_enabled {
if let Some(monitoring_addr) = self.config.monitoring_addr.as_ref() {
let monitoring_listener = TcpListener::bind(monitoring_addr.to_owned())
.with_context(|| {
format!("failed to bind health_checks to address `{monitoring_addr}`")
})
.unwrap();
local_monitoring_addr = Some(monitoring_listener.local_addr().unwrap());
let monitoring_server = axum::Server::from_tcp(monitoring_listener)
.unwrap()
.serve(self.lifecycle.health_checks_router().into_make_service());
tokio::spawn(async move {
tracing::info!(
"monitoring server on {addr}",
addr = local_monitoring_addr.unwrap()
);
_ = monitoring_server.await;
tracing::info!("monitoring server shut down");
});
}
}

let listener = self.listener.unwrap();

tracing::debug!(
"using server configuration: {config:?}",
config = self.config
);

let store = self
.config
.data_store
.take()
.unwrap_or_else(|| Box::<MemoryDataStore>::default());
let (core, handle) = CoreService::spawn(
self.config.operator_key,
self.config.operator_key.take().unwrap(),
store,
self.config
.checkpoint_interval
.unwrap_or(DEFAULT_CHECKPOINT_INTERVAL),
)
.await?;
self.stop_handle = Some(handle);

let server = axum::Server::from_tcp(listener)?.serve(
Self::create_router(
format!("http://{addr}", addr = self.config.addr.unwrap()),
self.config.content_dir,
core,
)?
.into_make_service(),
);
let mut api_router = Router::new().merge(Self::create_router(
format!("http://{addr}", addr = local_addr),
self.config.content_dir.take(),
core,
)?);

if health_checks_enabled && local_monitoring_addr.is_none() {
api_router = api_router.merge(self.lifecycle.health_checks_router());
}

tracing::info!("listening on {addr}", addr = self.config.addr.unwrap());
let api_lifecycle = self.lifecycle.clone();
let api_server = axum::Server::from_tcp(api_listener)
.unwrap()
.serve(api_router.into_make_service())
.with_graceful_shutdown(async move {
api_lifecycle.drain_signal().await;
});
tokio::spawn(async move {
tracing::info!("server listening on {local_addr}");
_ = api_server.await;
tracing::info!("server shut down");
});

if let Some(shutdown) = self.config.shutdown {
// NOTE: If warmup needed, set live first, do warmup, and then set ready.
self.lifecycle.set_ready().await?;

// Set up the shutdown sequence whether or not it will be graceful.
if let Some(shutdown) = self.config.shutdown.take() {
tracing::debug!("server is running with a shutdown signal");
server
.with_graceful_shutdown(async move { shutdown.await })
.await?;
let lifecycle = self.lifecycle.clone();
tokio::spawn(async move {
shutdown.await;
lifecycle.shutdown().await.unwrap();
});
} else {
tracing::debug!("server is running without a shutdown signal");
server.await?;
}

Ok(Endpoints {
api: local_addr,
monitoring: local_monitoring_addr,
})
}

/// Waits for the lifecycle manager's terminate signal, then stops the core service.
///
/// # Panics
///
/// Panics if no stop handle is stored in `self.stop_handle`.
pub async fn join(&mut self) -> Result<()> {
// Block until the lifecycle manager signals termination.
self.lifecycle.terminate_signal().await;
tracing::info!("waiting for core service to stop");
handle.stop().await;
// Take the handle out of `self` so it can be consumed by `stop`.
self.stop_handle.take().unwrap().stop().await;
tracing::info!("server shutdown complete");

Ok(())
}

Expand Down
Loading