diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index bae0f6a..9f802c9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -21,6 +21,7 @@ jobs: matrix: features: - "ext,http,ws,mio,openssl" + - "ext,http,ws,mio,ktls" - "ext,http,ws,mio,rustls-webpki" - "ext,http,ws,mio,rustls-native" steps: @@ -34,6 +35,7 @@ jobs: matrix: features: - "ext,http,ws,mio,openssl" + - "ext,http,ws,mio,ktls" - "ext,http,ws,mio,rustls-webpki" - "ext,http,ws,mio,rustls-native" steps: @@ -76,6 +78,7 @@ jobs: matrix: features: - "ext,http,ws,mio,openssl" + - "ext,http,ws,mio,ktls" - "ext,http,ws,mio,rustls-webpki" - "ext,http,ws,mio,rustls-native" steps: diff --git a/Cargo.toml b/Cargo.toml index ff33a01..7460eae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ mio = ["dep:mio"] rustls-native = ["rustls", "rustls-native-certs"] rustls-webpki = ["rustls", "webpki-roots"] openssl = ["dep:openssl", "dep:openssl-probe"] +ktls = ["openssl", "dep:openssl-sys", "dep:foreign-types", "dep:libc", "dep:openssl-src"] http = ["dep:http", "httparse", "memchr", "itoa"] ws = ["rand", "base64", "dep:http", "httparse"] ext = [] @@ -43,6 +44,9 @@ smallvec = "1.15.0" smallstr = "0.3.1" core_affinity = "0.8.3" log = "0.4.20" +openssl-sys = { version = "0.9", optional = true } +foreign-types = { version = "0.3.1", optional = true } +libc = { version = "0.2", optional = true } [dependencies.webpki-roots] version = "0.26.0" @@ -60,6 +64,12 @@ tungstenite = "0.28.0" criterion = "0.5.1" idle = "0.2.0" +[build-dependencies.openssl-src] +version = "300" +features = ["ktls"] +optional = true +default-features = false + [lints.clippy] uninit_assumed_init = "allow" mem_replace_with_uninit = "allow" diff --git a/examples/io_service_with_context_ktls.rs b/examples/io_service_with_context_ktls.rs new file mode 100644 index 0000000..07aa324 --- /dev/null +++ b/examples/io_service_with_context_ktls.rs @@ -0,0 +1,125 @@ +#[cfg(feature = "ktls")] +mod deps { + pub use boomnet::service::IntoIOService; + pub use boomnet::service::endpoint::{DisconnectReason, Endpoint}; + pub use boomnet::service::select::Selectable; + pub use boomnet::service::select::mio::MioSelector; + pub use boomnet::stream::ktls::{IntoKtlsStream, KtlStream}; + pub use boomnet::stream::mio::{IntoMioStream, MioStream}; + pub use boomnet::stream::tcp::TcpStream; + pub use boomnet::stream::tls::TlsConfigExt; + pub use boomnet::stream::{ConnectionInfo, ConnectionInfoProvider}; + pub use boomnet::ws::{IntoWebsocket, Websocket, WebsocketFrame}; + pub use mio::event::Source; + pub use mio::{Interest, Registry, Token}; + pub use std::net::SocketAddr; + pub use std::time::Duration; +} + +#[cfg(feature = "ktls")] +use deps::*; + +#[cfg(feature = "ktls")] +struct TradeConnectionFactory { + connection_info: ConnectionInfo, +} + +#[cfg(feature = "ktls")] +impl TradeConnectionFactory { + fn new() -> Self { + Self { + connection_info: ("fstream.binance.com", 443).into(), + } + } +} + +#[cfg(feature = "ktls")] +struct TradeConnection { + ws: Websocket>, +} + +#[cfg(feature = "ktls")] +impl TradeConnection { + fn do_work(&mut self) -> std::io::Result<()> { + for frame in self.ws.read_batch()? { + if let WebsocketFrame::Text(fin, body) = frame? { + println!("({fin}) {}", String::from_utf8_lossy(body)); + } + } + Ok(()) + } +} + +#[cfg(feature = "ktls")] +impl Selectable for TradeConnection { + fn connected(&mut self) -> std::io::Result { + self.ws.connected() + } + + fn make_writable(&mut self) -> std::io::Result<()> { + self.ws.make_writable() + } + + fn make_readable(&mut self) -> std::io::Result<()> { + self.ws.make_readable() + } +} + +#[cfg(feature = "ktls")] +impl Source for TradeConnection { + fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> { + self.ws.register(registry, token, interests) + } + + fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> std::io::Result<()> { + self.ws.reregister(registry, token, interests) + } + + fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> { + self.ws.deregister(registry) + } +} + +#[cfg(feature = "ktls")] +impl ConnectionInfoProvider for TradeConnectionFactory { + fn connection_info(&self) -> &ConnectionInfo { + &self.connection_info + } +} + +#[cfg(feature = "ktls")] +impl Endpoint for TradeConnectionFactory { + type Target = TradeConnection; + + fn create_target(&mut self, addr: SocketAddr) -> std::io::Result> { + let mut ws = TcpStream::try_from((&self.connection_info, addr))? + .into_mio_stream() + .into_ktls_stream_with_config(|cfg| cfg.with_no_cert_verification())? + .into_websocket("/ws"); + + ws.send_text(true, Some(b"{\"method\":\"SUBSCRIBE\",\"params\":[\"btcusdt@trade\"],\"id\":1}"))?; + + Ok(Some(TradeConnection { ws })) + } + + fn can_recreate(&mut self, reason: DisconnectReason) -> bool { + println!("on disconnect: reason={}", reason); + true + } +} + +#[cfg(feature = "ktls")] +fn main() -> anyhow::Result<()> { + let mut io_service = MioSelector::new()? + .into_io_service() + .with_auto_disconnect(Duration::from_secs(10)); + + io_service.register(TradeConnectionFactory::new())?; + + loop { + io_service.poll(|conn, _| conn.do_work())?; + } +} + +#[cfg(not(feature = "ktls"))] +fn main() {} diff --git a/examples/ws_client_ktls.rs b/examples/ws_client_ktls.rs new file mode 100644 index 0000000..7f49247 --- /dev/null +++ b/examples/ws_client_ktls.rs @@ -0,0 +1,31 @@ +#[cfg(feature = "ktls")] +mod deps { + pub use boomnet::stream::ktls::IntoKtlsStream; + pub use boomnet::stream::tcp::TcpStream; + pub use boomnet::stream::tls::TlsConfigExt; + pub use boomnet::ws::{IntoWebsocket, WebsocketFrame}; +} + +#[cfg(feature = "ktls")] +use deps::*; + +#[cfg(feature = "ktls")] +fn main() -> anyhow::Result<()> { + let mut ws = TcpStream::try_from(("fstream.binance.com", 443))? + .into_ktls_stream_with_config(|cfg| cfg.with_no_cert_verification())? + .into_websocket("/ws"); + + ws.send_text(true, Some(b"{\"method\":\"SUBSCRIBE\",\"params\":[\"btcusdt@trade\"],\"id\":1}"))?; + + loop { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, body) = frame? { + println!("({fin}) {}", String::from_utf8_lossy(body)); + } + } + std::thread::sleep(std::time::Duration::from_millis(1)); + } +} + +#[cfg(not(feature = "ktls"))] +fn main() {} diff --git a/src/service/endpoint.rs b/src/service/endpoint.rs index 2b0c27c..7b548f3 100644 --- a/src/service/endpoint.rs +++ b/src/service/endpoint.rs @@ -16,11 +16,9 @@ pub trait Endpoint: ConnectionInfoProvider { /// await the next connection attempt with (possibly) different `addr`. fn create_target(&mut self, addr: SocketAddr) -> io::Result>; - // /// Called by the `IOService` on each duty cycle. - // fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>; - /// Upon disconnection `IOService` will query the endpoint if the connection can be - /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. + /// recreated, passing the disconnect `reason`. If `false` is returned it will cause + /// program to panic. fn can_recreate(&mut self, _reason: DisconnectReason) -> bool { true } @@ -49,7 +47,8 @@ pub trait EndpointWithContext: ConnectionInfoProvider { fn create_target(&mut self, addr: SocketAddr, context: &mut C) -> io::Result>; /// Upon disconnection `IOService` will query the endpoint if the connection can be - /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. + /// recreated, passing the disconnect `reason`. If `false` is returned it will cause + /// program to panic. fn can_recreate(&mut self, _reason: DisconnectReason, _context: &mut C) -> bool { true } @@ -66,9 +65,8 @@ pub trait EndpointWithContext: ConnectionInfoProvider { pub enum DisconnectReason { /// This is expected disconnection due to `ttl` on the connection expiring. AutoDisconnect(Duration), - /// Some other IO error has occurred such as reaching EOF or peer disconnect. It's normally - /// ok to try and connect again. - Other(io::Error), + /// IO error has occurred such as reaching EOF or peer disconnect. + IO(io::Error), } impl Display for DisconnectReason { @@ -78,7 +76,7 @@ impl Display for DisconnectReason { write!(f, "auto-disconnect after ")?; ttl.fmt(f) } - DisconnectReason::Other(err) => { + DisconnectReason::IO(err) => { write!(f, "{err}") } } @@ -91,7 +89,7 @@ impl DisconnectReason { } pub(crate) fn other(err: io::Error) -> DisconnectReason { - DisconnectReason::Other(err) + DisconnectReason::IO(err) } } diff --git a/src/stream/ktls.rs b/src/stream/ktls.rs new file mode 100644 index 0000000..2d19323 --- /dev/null +++ b/src/stream/ktls.rs @@ -0,0 +1,509 @@ +//! Provides TLS offload to the kernel (KTLS). +//! +use crate::service::select::Selectable; +use crate::stream::ktls::error::Error; +use crate::stream::ktls::net::peer_addr; +use crate::stream::tls::TlsConfig; +use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; +use foreign_types::ForeignType; +use mio::event::Source; +use mio::{Interest, Registry, Token}; +use openssl::ssl::{ErrorCode, SslOptions}; +use smallstr::SmallString; +use std::io; +use std::io::{ErrorKind, Read, Write}; +use std::os::fd::{AsRawFd, BorrowedFd}; +use std::ptr::slice_from_raw_parts; + +/// Offloads TLS to the kernel (KTLS). Uses OpenSSL backend to configure KTLS post handshake (can change in the future). +/// The stream is designed to work with a non-blocking underlying stream. +/// +/// ## Prerequisites +/// Ensure that `tls` kernel module is installed. Otherwise, the code will panic if either KTLS +/// `send` or `recv` are not enabled. This is the minimum required to enable KTLS in the +/// software mode. +/// +/// ## Example +/// ```no_run +/// use boomnet::stream::tcp::TcpStream; +/// use crate::boomnet::stream::ktls::IntoKtlsStream; +/// +/// let ktls_stream = TcpStream::try_from(("fstream.binance.com", 443)).unwrap().into_ktls_stream().unwrap(); +/// ``` +pub struct KtlStream { + stream: S, + ssl: openssl::ssl::Ssl, + state: State, + buffer: Vec, +} + +impl KtlStream { + /// Create KTLS from underlying stream using default [`TlsConfig`]. + pub fn new(stream: S, server_name: impl AsRef) -> io::Result> + where + S: AsRawFd, + { + Self::new_with_config(stream, server_name, |_| ()) + } + + /// Create KTLS from underlying stream. This method also requires an action used + /// further configure [`TlsConfig`]. + pub fn new_with_config(stream: S, server_name: impl AsRef, configure: F) -> io::Result> + where + S: AsRawFd, + F: FnOnce(&mut TlsConfig), + { + const SSL_OP_ENABLE_KTLS: SslOptions = SslOptions::from_bits_retain(ffi::SSL_OP_ENABLE_KTLS); + + let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?; + builder.set_options(SSL_OP_ENABLE_KTLS); + + let mut tls_config = builder.into(); + configure(&mut tls_config); + + let config = tls_config.into_openssl().build().configure()?; + let ssl = config.into_ssl(server_name.as_ref())?; + + Ok(KtlStream { + stream, + ssl, + state: State::Connecting, + buffer: Vec::with_capacity(4096), + }) + } + + #[inline] + fn connected(&self) -> io::Result + where + S: AsRawFd, + { + let fd = unsafe { BorrowedFd::borrow_raw(self.stream.as_raw_fd()) }; + Ok(peer_addr(fd)?.is_some()) + } + + #[inline] + fn ssl_connect(&self) -> Result<(), Error> { + let result = unsafe { openssl_sys::SSL_connect(self.ssl.as_ptr()) }; + if result <= 0 { + Err(Error::make(result, &self.ssl)) + } else { + Ok(()) + } + } + + fn ktls_send_enabled(&self) -> bool { + unsafe { + let wbio = openssl_sys::SSL_get_wbio(self.ssl.as_ptr()); + ffi::BIO_get_ktls_send(wbio) != 0 + } + } + + fn ktls_recv_enabled(&self) -> bool { + unsafe { + let rbio = openssl_sys::SSL_get_rbio(self.ssl.as_ptr()); + ffi::BIO_get_ktls_recv(rbio) != 0 + } + } + + #[inline] + fn ssl_read(&mut self, buf: &mut [u8]) -> Result { + unsafe { + let len = + openssl_sys::SSL_read(self.ssl.as_ptr(), buf.as_mut_ptr() as *mut _, buf.len().try_into().unwrap()); + if len < 0 { + Err(error::Error::make(len, &self.ssl)) + } else { + Ok(len as usize) + } + } + } + + #[inline] + fn ssl_write(&mut self, buf: &[u8]) -> Result { + if buf.is_empty() { + return Ok(0); + } + unsafe { + let len = + openssl_sys::SSL_write(self.ssl.as_ptr(), buf.as_ptr() as *const _, buf.len().try_into().unwrap()); + if len < 0 { + Err(error::Error::make(len, &self.ssl)) + } else { + Ok(len as usize) + } + } + } +} + +#[derive(Copy, Clone)] +enum State { + Connecting, + Handshake, + Drain(usize), + Ready, +} + +impl ConnectionInfoProvider for KtlStream { + fn connection_info(&self) -> &ConnectionInfo { + self.stream.connection_info() + } +} + +impl Read for KtlStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self.state { + State::Connecting => { + if self.connected()? { + // we intentionally pass BIO_NO_CLOSE to prevent double free on the file descriptor + let sock_bio = unsafe { openssl_sys::BIO_new_socket(self.stream.as_raw_fd(), ffi::BIO_NO_CLOSE) }; + assert!(!sock_bio.is_null(), "failed to create socket BIO"); + unsafe { + openssl_sys::SSL_set_bio(self.ssl.as_ptr(), sock_bio, sock_bio); + } + self.state = State::Handshake; + } + } + State::Handshake => match self.ssl_connect() { + Ok(_) => { + assert!(self.ktls_recv_enabled(), "ktls recv not enabled, did you install 'tls' kernel module?"); + assert!(self.ktls_send_enabled(), "ktls send not enabled, did you install 'tls' kernel module?"); + self.state = State::Drain(0) + } + Err(err) if err.code() == ErrorCode::WANT_READ => {} + Err(err) if err.code() == ErrorCode::WANT_WRITE => {} + Err(err) => return Err(io::Error::other(err)), + }, + State::Drain(index) => { + let mut from = index; + let remaining = + unsafe { &*slice_from_raw_parts(self.buffer.as_ptr().add(from), self.buffer.len() - from) }; + if remaining.is_empty() { + self.state = State::Ready; + } else { + from += match self.ssl_write(remaining) { + Ok(len) => len, + Err(err) if err.code() == ErrorCode::WANT_READ => 0, + Err(err) if err.code() == ErrorCode::WANT_WRITE => 0, + Err(err) => return Err(io::Error::other(err)), + }; + self.state = State::Drain(from); + } + } + State::Ready => match self.ssl_read(buf) { + Ok(0) => return Err(ErrorKind::UnexpectedEof.into()), + Ok(len) => return Ok(len), + Err(err) if err.code() == ErrorCode::WANT_READ => {} + Err(err) if err.code() == ErrorCode::WANT_WRITE => {} + Err(err) => return Err(io::Error::other(err)), + }, + } + Err(ErrorKind::WouldBlock.into()) + } +} + +impl Write for KtlStream { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + match self.state { + State::Ready => match self.ssl_write(buf) { + Ok(len) => Ok(len), + Err(err) if err.code() == ErrorCode::WANT_READ => Err(ErrorKind::WouldBlock.into()), + Err(err) if err.code() == ErrorCode::WANT_WRITE => Err(ErrorKind::WouldBlock.into()), + Err(err) => Err(io::Error::other(err)), + }, + _ => { + // we buffer any pending write + self.buffer.extend_from_slice(buf); + Ok(buf.len()) + } + } + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + match self.state { + State::Connecting | State::Handshake | State::Drain(_) => Ok(()), + State::Ready => self.stream.flush(), + } + } +} + +impl Selectable for KtlStream { + #[inline] + fn connected(&mut self) -> io::Result { + self.stream.connected() + } + + #[inline] + fn make_writable(&mut self) -> io::Result<()> { + self.stream.make_writable() + } + + #[inline] + fn make_readable(&mut self) -> io::Result<()> { + self.stream.make_readable() + } +} + +#[cfg(feature = "mio")] +impl Source for KtlStream { + #[inline] + fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { + registry.register(&mut self.stream, token, interests) + } + + #[inline] + fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { + registry.reregister(&mut self.stream, token, interests) + } + + #[inline] + fn deregister(&mut self, registry: &Registry) -> io::Result<()> { + registry.deregister(&mut self.stream) + } +} + +/// Trait to convert underlying stream into [`KtlStream`]. +pub trait IntoKtlsStream { + /// Convert underlying stream into [`KtlStream`] with default tls config. + /// + /// ## Examples + /// ```no_run + /// use boomnet::stream::tcp::TcpStream; + /// use boomnet::stream::ktls::IntoKtlsStream; + /// + /// let ktls = TcpStream::try_from(("127.0.0.1", 4222)).unwrap().into_ktls_stream().unwrap(); + /// ``` + fn into_ktls_stream(self) -> io::Result> + where + Self: Sized, + { + self.into_ktls_stream_with_config(|_| ()) + } + + /// Convert underlying stream into [`KtlStream`] and use provided action to modify tls config. + /// + /// ## Examples + /// ```no_run + /// use boomnet::stream::tcp::TcpStream; + /// use boomnet::stream::ktls::IntoKtlsStream; + /// use boomnet::stream::tls::TlsConfigExt; + /// + /// let ktls = TcpStream::try_from(("127.0.0.1", 4222)).unwrap().into_ktls_stream_with_config(|cfg| cfg.with_no_cert_verification()).unwrap(); + /// ``` + fn into_ktls_stream_with_config(self, builder: F) -> io::Result> + where + Self: Sized, + F: FnOnce(&mut TlsConfig); +} + +impl IntoKtlsStream for T +where + T: Read + Write + AsRawFd + ConnectionInfoProvider, +{ + fn into_ktls_stream_with_config(self, builder: F) -> io::Result> + where + Self: Sized, + F: FnOnce(&mut TlsConfig), + { + let server_name = SmallString::<[u8; 1024]>::from(self.connection_info().host()); + KtlStream::new_with_config(self, server_name, builder) + } +} + +mod error { + use crate::util::NoBlock; + use foreign_types::ForeignTypeRef; + use openssl::{error::ErrorStack, ssl::ErrorCode}; + use std::{error, ffi::c_int, fmt, io}; + + #[derive(Debug)] + enum InnerError { + Io(io::Error), + Ssl(ErrorStack), + } + + /// An SSL error. + #[derive(Debug)] + pub struct Error { + code: ErrorCode, + cause: Option, + } + + impl Error { + pub fn code(&self) -> ErrorCode { + self.code + } + + pub fn io_error(&self) -> Option<&io::Error> { + match self.cause { + Some(InnerError::Io(ref e)) => Some(e), + _ => None, + } + } + + pub fn ssl_error(&self) -> Option<&ErrorStack> { + match self.cause { + Some(InnerError::Ssl(ref e)) => Some(e), + _ => None, + } + } + + pub fn make(ret: c_int, ssl: &openssl::ssl::SslRef) -> Self { + let code = unsafe { ErrorCode::from_raw(openssl_sys::SSL_get_error(ssl.as_ptr(), ret)) }; + + let cause = match code { + ErrorCode::SSL => Some(InnerError::Ssl(ErrorStack::get())), + ErrorCode::SYSCALL => { + let errs = ErrorStack::get(); + if errs.errors().is_empty() { + // get last error from io + let e = std::io::Error::last_os_error(); + Some(InnerError::Io(e)) + } else { + Some(InnerError::Ssl(errs)) + } + } + ErrorCode::ZERO_RETURN => None, + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { + // get last error from io + let e = std::io::Error::last_os_error(); + Some(InnerError::Io(e)) + } + _ => None, + }; + + Error { code, cause } + } + } + + impl From for Error { + fn from(e: ErrorStack) -> Error { + Error { + code: ErrorCode::SSL, + cause: Some(InnerError::Ssl(e)), + } + } + } + + impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.code { + ErrorCode::ZERO_RETURN => fmt.write_str("the SSL session has been shut down"), + ErrorCode::WANT_READ => match self.io_error() { + Some(_) => fmt.write_str("a nonblocking read call would have blocked"), + None => fmt.write_str("the operation should be retried"), + }, + ErrorCode::WANT_WRITE => match self.io_error() { + Some(_) => fmt.write_str("a nonblocking write call would have blocked"), + None => fmt.write_str("the operation should be retried"), + }, + ErrorCode::SYSCALL => match self.io_error() { + Some(err) => write!(fmt, "{err}"), + None => fmt.write_str("unexpected EOF"), + }, + ErrorCode::SSL => match self.ssl_error() { + Some(e) => write!(fmt, "{e}"), + None => fmt.write_str("OpenSSL error"), + }, + _ => write!(fmt, "unknown error code {}", self.code.as_raw()), + } + } + } + + impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self.cause { + Some(InnerError::Io(ref e)) => Some(e), + Some(InnerError::Ssl(ref e)) => Some(e), + None => None, + } + } + } + + impl NoBlock for Result { + type Value = usize; + + fn no_block(self) -> io::Result { + match self { + Ok(value) => Ok(value), + Err(err) if err.code() == ErrorCode::WANT_READ => Ok(0), + Err(err) if err.code() == ErrorCode::WANT_WRITE => Ok(0), + Err(err) => Err(io::Error::other(err)), + } + } + } + + impl NoBlock for Result<(), Error> { + type Value = (); + + fn no_block(self) -> io::Result { + match self { + Ok(()) => Ok(()), + Err(err) if err.code() == ErrorCode::WANT_READ => Ok(()), + Err(err) if err.code() == ErrorCode::WANT_WRITE => Ok(()), + Err(err) => Err(io::Error::other(err)), + } + } + } +} + +mod ffi { + use openssl_sys::BIO_ctrl; + use std::ffi::{c_int, c_long}; + + pub const SSL_OP_ENABLE_KTLS: u64 = 0x00000008; + pub const BIO_NO_CLOSE: c_int = 0x00; + const BIO_CTRL_GET_KTLS_SEND: c_int = 73; + const BIO_CTRL_GET_KTLS_RECV: c_int = 76; + + #[allow(non_snake_case)] + pub unsafe fn BIO_get_ktls_send(b: *mut openssl_sys::BIO) -> c_long { + unsafe { BIO_ctrl(b, BIO_CTRL_GET_KTLS_SEND, 0, std::ptr::null_mut()) } + } + #[allow(non_snake_case)] + pub unsafe fn BIO_get_ktls_recv(b: *mut openssl_sys::BIO) -> c_long { + unsafe { BIO_ctrl(b, BIO_CTRL_GET_KTLS_RECV, 0, std::ptr::null_mut()) } + } +} + +mod net { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::os::fd::{AsRawFd, BorrowedFd}; + use std::{io, mem}; + + pub fn peer_addr(fd: BorrowedFd<'_>) -> io::Result> { + let raw = fd.as_raw_fd(); + + let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let mut len = mem::size_of::() as libc::socklen_t; + + let rc = unsafe { libc::getpeername(raw, &mut storage as *mut _ as *mut libc::sockaddr, &mut len as *mut _) }; + + if rc == -1 { + let err = io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::ENOTCONN) { + return Ok(None); + } + return Err(err); + } + + unsafe { + match storage.ss_family as libc::c_int { + libc::AF_INET => { + let sa = &*(&storage as *const _ as *const libc::sockaddr_in); + let ip = Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr)); + let port = u16::from_be(sa.sin_port); + Ok(Some(SocketAddr::new(IpAddr::V4(ip), port))) + } + libc::AF_INET6 => { + let sa = &*(&storage as *const _ as *const libc::sockaddr_in6); + let ip = Ipv6Addr::from(sa.sin6_addr.s6_addr); + let port = u16::from_be(sa.sin6_port); + Ok(Some(SocketAddr::new(IpAddr::V6(ip), port))) + } + _ => Err(io::Error::new(io::ErrorKind::InvalidData, "unsupported address family")), + } + } + } +} diff --git a/src/stream/mio.rs b/src/stream/mio.rs index 6fb1856..cc5e878 100644 --- a/src/stream/mio.rs +++ b/src/stream/mio.rs @@ -1,14 +1,14 @@ //! Stream that can be used together with `MioSelector`. -use std::io::ErrorKind::{Interrupted, NotConnected, WouldBlock}; -use std::io::{Read, Write}; -use std::{io, net}; - use crate::service::select::Selectable; use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; use mio::event::Source; use mio::net::TcpStream; use mio::{Interest, Registry, Token}; +use std::io::ErrorKind::{Interrupted, NotConnected, WouldBlock}; +use std::io::{Read, Write}; +use std::os::fd::{AsRawFd, RawFd}; +use std::{io, net}; #[derive(Debug)] pub struct MioStream { @@ -33,6 +33,12 @@ impl MioStream { } } +impl AsRawFd for MioStream { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + impl Selectable for MioStream { fn connected(&mut self) -> io::Result { if self.connected { diff --git a/src/stream/mod.rs b/src/stream/mod.rs index d1e93f1..596e676 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -11,6 +11,8 @@ use url::{ParseError, Url}; pub mod buffer; pub mod file; +#[cfg(all(target_os = "linux", feature = "ktls"))] +pub mod ktls; #[cfg(feature = "mio")] pub mod mio; pub mod record; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 2e0da21..2c6abf4 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -5,6 +5,7 @@ use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; use std::io; use std::io::{Read, Write}; use std::net::SocketAddr; +use std::os::fd::{AsRawFd, RawFd}; /// Wraps `std::net::TcpStream` and provides `ConnectionInfo`. #[derive(Debug)] @@ -13,6 +14,12 @@ pub struct TcpStream { connection_info: ConnectionInfo, } +impl AsRawFd for TcpStream { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + impl From for std::net::TcpStream { fn from(stream: TcpStream) -> Self { stream.inner @@ -62,12 +69,17 @@ impl TryFrom<(ConnectionInfo, SocketAddr)> for TcpStream { } impl TcpStream { - pub fn new(stream: std::net::TcpStream, connection_info: ConnectionInfo) -> Self { + pub const fn new(stream: std::net::TcpStream, connection_info: ConnectionInfo) -> Self { Self { inner: stream, connection_info, } } + + #[inline] + pub fn connected(&mut self) -> bool { + self.inner.peer_addr().is_ok() + } } impl Read for TcpStream { diff --git a/src/stream/tls.rs b/src/stream/tls.rs index 354041f..9980e03 100644 --- a/src/stream/tls.rs +++ b/src/stream/tls.rs @@ -24,6 +24,20 @@ pub struct TlsConfig { openssl_config: SslConnectorBuilder, } +#[cfg(feature = "openssl")] +impl From for TlsConfig { + fn from(config: SslConnectorBuilder) -> Self { + Self { openssl_config: config } + } +} + +#[cfg(all(feature = "rustls", not(feature = "openssl")))] +impl From for TlsConfig { + fn from(config: ClientConfig) -> Self { + Self { rustls_config: config } + } +} + /// Extension methods for `TlsConfig`. pub trait TlsConfigExt { /// Disable certificate verification. @@ -69,6 +83,12 @@ impl TlsConfig { pub const fn as_openssl_mut(&mut self) -> &mut SslConnectorBuilder { &mut self.openssl_config } + + /// Get mutable reference to the `openssl` configuration object. + #[cfg(feature = "openssl")] + pub fn into_openssl(self) -> SslConnectorBuilder { + self.openssl_config + } } impl TlsConfigExt for TlsConfig { @@ -184,7 +204,7 @@ mod __rustls { } impl TlsStream { - pub fn wrap_with_config(stream: S, server_name: &str, builder: F) -> io::Result> + pub fn new_with_config(stream: S, server_name: &str, builder: F) -> io::Result> where F: FnOnce(&mut TlsConfig), { @@ -218,8 +238,8 @@ mod __rustls { Ok(Self { inner: stream, tls }) } - pub fn wrap(stream: S, server_name: &str) -> io::Result> { - Self::wrap_with_config(stream, server_name, |_| {}) + pub fn new(stream: S, server_name: &str) -> io::Result> { + Self::new_with_config(stream, server_name, |_| {}) } fn complete_io(&mut self) -> io::Result<(usize, usize)> { @@ -355,7 +375,7 @@ mod __openssl { } impl State { - fn get_stream_mut(&mut self) -> io::Result<&mut S> { + fn get_mut(&mut self) -> io::Result<&mut S> { match self { State::Handshake(stream_and_buf) => match stream_and_buf.as_mut() { Some((stream, _)) => Ok(stream.get_mut()), @@ -383,29 +403,29 @@ mod __openssl { #[cfg(feature = "mio")] impl Source for TlsStream { fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { - registry.register(self.state.get_stream_mut()?, token, interests) + registry.register(self.state.get_mut()?, token, interests) } fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { - registry.reregister(self.state.get_stream_mut()?, token, interests) + registry.reregister(self.state.get_mut()?, token, interests) } fn deregister(&mut self, registry: &Registry) -> io::Result<()> { - registry.deregister(self.state.get_stream_mut()?) + registry.deregister(self.state.get_mut()?) } } impl Selectable for TlsStream { fn connected(&mut self) -> io::Result { - self.state.get_stream_mut()?.connected() + self.state.get_mut()?.connected() } fn make_writable(&mut self) -> io::Result<()> { - self.state.get_stream_mut()?.make_writable() + self.state.get_mut()?.make_writable() } fn make_readable(&mut self) -> io::Result<()> { - self.state.get_stream_mut()?.make_readable() + self.state.get_mut()?.make_readable() } } @@ -485,7 +505,7 @@ mod __openssl { } impl TlsStream { - pub fn wrap_with_config(stream: S, server_name: &str, configure: F) -> io::Result> + pub fn new_with_config(stream: S, server_name: &str, configure: F) -> io::Result> where F: FnOnce(&mut TlsConfig), { @@ -510,8 +530,8 @@ mod __openssl { } } - pub fn wrap(stream: S, server_name: &str) -> io::Result> { - Self::wrap_with_config(stream, server_name, |_| {}) + pub fn new(stream: S, server_name: &str) -> io::Result> { + Self::new_with_config(stream, server_name, |_| {}) } } @@ -576,7 +596,7 @@ where F: FnOnce(&mut TlsConfig), { let server_name = self.connection_info().clone().host; - TlsStream::wrap_with_config(self, &server_name, builder) + TlsStream::new_with_config(self, &server_name, builder) } } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 0c6ba8e..b6835fc 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -468,7 +468,7 @@ where let tls_ready_stream = match url.scheme() { "ws" => Ok(TlsReadyStream::Plain(stream)), - "wss" => Ok(TlsReadyStream::Tls(TlsStream::wrap(stream, url.host_str().unwrap()).unwrap())), + "wss" => Ok(TlsReadyStream::Tls(TlsStream::new(stream, url.host_str().unwrap()).unwrap())), scheme => Err(io::Error::other(format!("unrecognised url scheme: {scheme}"))), }?;