diff --git a/Cargo.lock b/Cargo.lock index ae96ca1..1248eb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -560,9 +560,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" [[package]] name = "libloading" @@ -815,6 +815,7 @@ dependencies = [ "prost", "protosocket", "rustls-pki-types", + "socket2 0.6.0", "thiserror", "tokio", "tokio-rustls", @@ -1072,6 +1073,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "subtle" version = "2.6.1" @@ -1141,7 +1152,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.7", "tokio-macros", "windows-sys 0.52.0", ] diff --git a/Cargo.toml b/Cargo.toml index 52d4052..2c87494 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ rmp-serde = { version = "1.3" } rustls-pemfile = { version = "2.2" } rustls-pki-types = { version = "1.12" } serde = { version = "1.0" } +socket2 = { version = "0.6" } thiserror = { version = "1.0" } tokio = { version = "1.39", features = ["net", "rt"] } tokio-rustls = { version = "0.26" } diff --git a/example-messagepack/src/server.rs b/example-messagepack/src/server.rs index 78c84fd..7469b6f 100644 --- a/example-messagepack/src/server.rs +++ b/example-messagepack/src/server.rs @@ -34,6 +34,10 @@ async fn run_main() -> Result<(), Box> { .unwrap_or_else(|_| "0.0.0.0:9000".to_string()) .parse()?, DemoRpcSocketService, + 4 << 20, + 1 << 20, + 128, + 64 << 10, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/example-proto-tls/src/server.rs b/example-proto-tls/src/server.rs index 33ba57c..0b93e4b 100644 --- a/example-proto-tls/src/server.rs +++ b/example-proto-tls/src/server.rs @@ -69,6 +69,10 @@ async fn run_main() -> Result<(), Box> { DemoRpcSocketService { tls_acceptor: Arc::new(server_config).into(), }, + 4 << 20, + 1 << 20, + 128, + 64 << 10, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/example-proto/src/server.rs b/example-proto/src/server.rs index 0dec0d2..f4a4201 100644 --- a/example-proto/src/server.rs +++ b/example-proto/src/server.rs @@ -35,6 +35,10 @@ async fn run_main() -> Result<(), Box> { .unwrap_or_else(|_| "0.0.0.0:9000".to_string()) .parse()?, DemoRpcSocketService, + 4 << 20, + 1 << 20, + 128, + 64 << 10, ) .await?; server.set_max_queued_outbound_messages(512); diff --git a/protosocket-connection/src/connection.rs b/protosocket-connection/src/connection.rs index 98663ac..03aa225 100644 --- a/protosocket-connection/src/connection.rs +++ b/protosocket-connection/src/connection.rs @@ -38,6 +38,7 @@ pub struct Connection { receive_buffer_unread_index: usize, receive_buffer: Vec, max_buffer_length: usize, + buffer_allocation_increment: usize, deserializer: Bindings::Deserializer, serializer: Bindings::Serializer, reactor: Bindings::Reactor, @@ -125,6 +126,7 @@ where deserializer: Bindings::Deserializer, serializer: Bindings::Serializer, max_buffer_length: usize, + buffer_allocation_increment: usize, max_queued_send_messages: usize, outbound_messages: mpsc::Receiver<::Message>, reactor: Bindings::Reactor, @@ -140,6 +142,7 @@ where receive_buffer: Vec::new(), max_buffer_length, receive_buffer_unread_index: 0, + buffer_allocation_increment, deserializer, serializer, reactor, @@ -148,12 +151,14 @@ where /// ensure buffer state and read from the inbound stream fn poll_read_inbound(&mut self, context: &mut Context<'_>) -> ReadBufferState { - const BUFFER_INCREMENT: usize = 1 << 20; if self.receive_buffer.len() < self.max_buffer_length - && self.receive_buffer.len() - self.receive_buffer_unread_index < BUFFER_INCREMENT + && self.receive_buffer.len() - self.receive_buffer_unread_index + < self.buffer_allocation_increment { - self.receive_buffer - .resize(self.receive_buffer.len() + BUFFER_INCREMENT, 0); + self.receive_buffer.resize( + self.receive_buffer.len() + self.buffer_allocation_increment, + 0, + ); } if 0 < self.receive_buffer.len() - self.receive_buffer_unread_index { diff --git a/protosocket-prost/src/prost_client_registry.rs b/protosocket-prost/src/prost_client_registry.rs index 96e0518..96e340d 100644 --- a/protosocket-prost/src/prost_client_registry.rs +++ b/protosocket-prost/src/prost_client_registry.rs @@ -9,6 +9,7 @@ use crate::{ProstClientConnectionBindings, ProstSerializer}; #[derive(Debug, Clone)] pub struct ClientRegistry { max_buffer_length: usize, + buffer_allocation_increment: usize, max_queued_outbound_messages: usize, runtime: tokio::runtime::Handle, stream_connector: TConnector, @@ -44,6 +45,7 @@ where Self { max_buffer_length: 4 * (1 << 20), max_queued_outbound_messages: 256, + buffer_allocation_increment: 1 << 20, runtime, stream_connector: connector, } @@ -91,6 +93,7 @@ where ProstSerializer::default(), ProstSerializer::default(), self.max_buffer_length, + self.buffer_allocation_increment, self.max_queued_outbound_messages, outbound_messages, message_reactor, diff --git a/protosocket-rpc/Cargo.toml b/protosocket-rpc/Cargo.toml index 1c58e50..8547ea2 100644 --- a/protosocket-rpc/Cargo.toml +++ b/protosocket-rpc/Cargo.toml @@ -17,6 +17,7 @@ futures = { workspace = true } k-lock = { workspace = true } log = { workspace = true } rustls-pki-types = { workspace = true } +socket2 = { workspace = true, features = ["all"] } tokio = { workspace = true } tokio-rustls = { workspace = true } tokio-util = { workspace = true } diff --git a/protosocket-rpc/src/client/configuration.rs b/protosocket-rpc/src/client/configuration.rs index 014b643..206fffd 100644 --- a/protosocket-rpc/src/client/configuration.rs +++ b/protosocket-rpc/src/client/configuration.rs @@ -170,6 +170,7 @@ impl tokio_rustls::rustls::client::danger::ServerCertVerifier for DoNothingVerif #[derive(Debug, Clone)] pub struct Configuration { max_buffer_length: usize, + buffer_allocation_increment: usize, max_queued_outbound_messages: usize, stream_connector: TStreamConnector, } @@ -182,6 +183,7 @@ where log::trace!("new client configuration"); Self { max_buffer_length: 4 * (1 << 20), // 4 MiB + buffer_allocation_increment: 1 << 20, max_queued_outbound_messages: 256, stream_connector, } @@ -200,6 +202,13 @@ where pub fn max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) { self.max_queued_outbound_messages = max_queued_outbound_messages; } + + /// Amount of buffer to allocate at one time when buffer needs extension. + /// + /// Default: 1MiB + pub fn buffer_allocation_increment(&mut self, buffer_allocation_increment: usize) { + self.buffer_allocation_increment = buffer_allocation_increment; + } } /// Connect a new protosocket rpc client to a server @@ -247,6 +256,7 @@ where Deserializer::default(), Serializer::default(), configuration.max_buffer_length, + configuration.buffer_allocation_increment, configuration.max_queued_outbound_messages, outbound_messages, message_reactor, diff --git a/protosocket-rpc/src/server/socket_server.rs b/protosocket-rpc/src/server/socket_server.rs index 61fa6b4..fb43766 100644 --- a/protosocket-rpc/src/server/socket_server.rs +++ b/protosocket-rpc/src/server/socket_server.rs @@ -1,3 +1,4 @@ +use std::ffi::c_int; use std::future::Future; use std::io::Error; use std::pin::Pin; @@ -27,6 +28,7 @@ where socket_server: TSocketService, listener: tokio::net::TcpListener, max_buffer_length: usize, + buffer_allocation_increment: usize, max_queued_outbound_messages: usize, } @@ -38,13 +40,36 @@ where pub async fn new( address: std::net::SocketAddr, socket_server: TSocketService, + max_buffer_length: usize, + buffer_allocation_increment: usize, + max_queued_outbound_messages: usize, + listen_backlog: u32, ) -> crate::Result { - let listener = tokio::net::TcpListener::bind(address).await?; + let socket = socket2::Socket::new( + match address { + std::net::SocketAddr::V4(_) => socket2::Domain::IPV4, + std::net::SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::STREAM, + None, + )?; + + socket.set_nonblocking(true)?; + socket.set_tcp_nodelay(true)?; + socket.set_keepalive(true)?; + socket.set_reuse_port(true)?; + socket.set_reuse_address(true)?; + + socket.bind(&address.into())?; + socket.listen(listen_backlog as c_int)?; + + let listener = tokio::net::TcpListener::from_std(socket.into())?; Ok(Self { socket_server, listener, - max_buffer_length: 16 * (2 << 20), - max_queued_outbound_messages: 128, + max_buffer_length, + buffer_allocation_increment, + max_queued_outbound_messages, }) } @@ -84,6 +109,7 @@ where let serializer = self.socket_server.serializer(); let max_buffer_length = self.max_buffer_length; let max_queued_outbound_messages = self.max_queued_outbound_messages; + let buffer_allocation_increment = self.buffer_allocation_increment; let stream_future = self.socket_server.accept_stream(stream); @@ -97,6 +123,7 @@ where deserializer, serializer, max_buffer_length, + buffer_allocation_increment, max_queued_outbound_messages, outbound_messages_receiver, submitter, diff --git a/protosocket-server/src/connection_server.rs b/protosocket-server/src/connection_server.rs index 97a43b7..9d7e676 100644 --- a/protosocket-server/src/connection_server.rs +++ b/protosocket-server/src/connection_server.rs @@ -57,6 +57,7 @@ pub struct ProtosocketServer { connector: Connector, listener: tokio::net::TcpListener, max_buffer_length: usize, + buffer_allocation_increment: usize, max_queued_outbound_messages: usize, runtime: tokio::runtime::Handle, } @@ -78,6 +79,7 @@ impl ProtosocketServer { listener, max_buffer_length: 16 * (2 << 20), max_queued_outbound_messages: 128, + buffer_allocation_increment: 1 << 20, runtime, }) } @@ -114,6 +116,7 @@ impl Future for ProtosocketServer { self.connector.deserializer(), self.connector.serializer(), self.max_buffer_length, + self.buffer_allocation_increment, self.max_queued_outbound_messages, outbound_messages, reactor,