From 1605d040d4f54ba3a46e3abe14f43cf94642fd18 Mon Sep 17 00:00:00 2001 From: Dylan Abraham Date: Wed, 17 Sep 2025 10:54:48 -0700 Subject: [PATCH] Allow configuring tcp-keepalive duration for rpc clients and servers --- example-messagepack/src/server.rs | 1 + example-proto-tls/src/server.rs | 1 + example-proto/src/server.rs | 1 + protosocket-rpc/src/client/configuration.rs | 34 ++++++++++++++++++--- protosocket-rpc/src/server/socket_server.rs | 13 ++++++-- 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/example-messagepack/src/server.rs b/example-messagepack/src/server.rs index 7469b6f..8b186c5 100644 --- a/example-messagepack/src/server.rs +++ b/example-messagepack/src/server.rs @@ -38,6 +38,7 @@ async fn run_main() -> Result<(), Box> { 1 << 20, 128, 64 << 10, + None, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/example-proto-tls/src/server.rs b/example-proto-tls/src/server.rs index 0b93e4b..7e81fb2 100644 --- a/example-proto-tls/src/server.rs +++ b/example-proto-tls/src/server.rs @@ -73,6 +73,7 @@ async fn run_main() -> Result<(), Box> { 1 << 20, 128, 64 << 10, + None, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/example-proto/src/server.rs b/example-proto/src/server.rs index f4a4201..41f882a 100644 --- a/example-proto/src/server.rs +++ b/example-proto/src/server.rs @@ -39,6 +39,7 @@ async fn run_main() -> Result<(), Box> { 1 << 20, 128, 64 << 10, + None, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/protosocket-rpc/src/client/configuration.rs b/protosocket-rpc/src/client/configuration.rs index 206fffd..bb4763a 100644 --- a/protosocket-rpc/src/client/configuration.rs +++ b/protosocket-rpc/src/client/configuration.rs @@ -1,6 +1,6 @@ -use std::{future::Future, net::SocketAddr, sync::Arc}; - use protosocket::Connection; +use socket2::TcpKeepalive; +use std::{future::Future, net::SocketAddr, sync::Arc}; use tokio::{net::TcpStream, sync::mpsc}; use tokio_rustls::rustls::pki_types::ServerName; @@ -172,6 +172,7 @@ pub struct Configuration { max_buffer_length: usize, buffer_allocation_increment: usize, max_queued_outbound_messages: usize, + tcp_keepalive_duration: Option, stream_connector: TStreamConnector, } @@ -185,6 +186,7 @@ where max_buffer_length: 4 * (1 << 20), // 4 MiB buffer_allocation_increment: 1 << 20, max_queued_outbound_messages: 256, + tcp_keepalive_duration: None, stream_connector, } } @@ -209,6 +211,13 @@ where pub fn buffer_allocation_increment(&mut self, buffer_allocation_increment: usize) { self.buffer_allocation_increment = buffer_allocation_increment; } + + /// The duration to set for tcp_keepalive on the underlying socket. + /// + /// Default: None + pub fn tcp_keepalive_duration(&mut self, tcp_keepalive_duration: Option) { + self.tcp_keepalive_duration = tcp_keepalive_duration; + } } /// Connect a new protosocket rpc client to a server @@ -233,8 +242,25 @@ where { log::trace!("new client {address}, {configuration:?}"); - let stream = tokio::net::TcpStream::connect(address).await?; - stream.set_nodelay(true)?; + let socket = socket2::Socket::new( + match address { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::STREAM, + None, + )?; + + let mut tcp_keepalive = TcpKeepalive::new(); + if let Some(duration) = configuration.tcp_keepalive_duration { + tcp_keepalive = tcp_keepalive.with_time(duration); + } + + socket.set_nonblocking(true)?; + socket.set_tcp_nodelay(true)?; + socket.set_tcp_keepalive(&tcp_keepalive)?; + + let stream = TcpStream::from_std(socket.into())?; let message_reactor: RpcCompletionReactor< Deserializer::Message, diff --git a/protosocket-rpc/src/server/socket_server.rs b/protosocket-rpc/src/server/socket_server.rs index fb43766..a30a8fe 100644 --- a/protosocket-rpc/src/server/socket_server.rs +++ b/protosocket-rpc/src/server/socket_server.rs @@ -1,11 +1,12 @@ +use protosocket::Connection; +use socket2::TcpKeepalive; use std::ffi::c_int; use std::future::Future; use std::io::Error; use std::pin::Pin; use std::task::Context; use std::task::Poll; - -use protosocket::Connection; +use std::time::Duration; use tokio::sync::mpsc; use super::connection_server::RpcConnectionServer; @@ -44,6 +45,7 @@ where buffer_allocation_increment: usize, max_queued_outbound_messages: usize, listen_backlog: u32, + tcp_keepalive_duration: Option, ) -> crate::Result { let socket = socket2::Socket::new( match address { @@ -54,9 +56,14 @@ where None, )?; + let mut tcp_keepalive = TcpKeepalive::new(); + if let Some(duration) = tcp_keepalive_duration { + tcp_keepalive = tcp_keepalive.with_time(duration); + } + socket.set_nonblocking(true)?; socket.set_tcp_nodelay(true)?; - socket.set_keepalive(true)?; + socket.set_tcp_keepalive(&tcp_keepalive)?; socket.set_reuse_port(true)?; socket.set_reuse_address(true)?;