From 7bc5fd76b0d41cce3b54801709483dffd0b1246f Mon Sep 17 00:00:00 2001 From: Luqman Aden Date: Fri, 11 Nov 2022 16:34:26 -0800 Subject: [PATCH] Change stop to return only after server has finished stopping. --- src/server.rs | 80 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/src/server.rs b/src/server.rs index 0ce98c1..314fe91 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,14 +10,14 @@ use std::net::SocketAddr; use std::sync::Arc; use async_trait::async_trait; -use futures::FutureExt; use futures::future::Shared; +use futures::{future, FutureExt}; use log::{debug, error, info, trace}; use thiserror::Error; use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use tokio::select; -use tokio::sync::{Mutex, oneshot}; +use tokio::sync::{oneshot, Mutex}; use crate::rfb::{ ClientInit, ClientMessage, FramebufferUpdate, KeyEvent, PixelFormat, ProtoVersion, @@ -74,7 +74,7 @@ pub struct VncServer { pub server: S, /// One-shot channel used to signal that the server should shut down. - stop_ch: Mutex>>, + stop_ch: Mutex>>>, } #[async_trait] @@ -183,7 +183,12 @@ impl VncServer { Ok(()) } - async fn handle_conn(&self, s: &mut TcpStream, addr: SocketAddr, mut close_ch: Shared>) { + async fn handle_conn( + &self, + s: &mut TcpStream, + addr: SocketAddr, + mut close_ch: Shared>, + ) { info!("[{:?}] new connection", addr); if let Err(e) = self.rfb_handshake(s, addr).await { @@ -288,36 +293,85 @@ impl VncServer { let listener = TcpListener::bind(self.config.addr).await?; // Create a channel to signal the server to stop. - let (close_tx, close_rx) = oneshot::channel(); - assert!(self.stop_ch.lock().await.replace(close_tx).is_none(), "server already started"); - let mut close_rx = close_rx.shared(); + let (close_tx, mut close_rx) = oneshot::channel(); + assert!( + self.stop_ch.lock().await.replace(close_tx).is_none(), + "server already started" + ); + + // And a pair used to stop clients + let (client_stop_tx, client_stop_rx) = oneshot::channel(); + let client_stop_rx = client_stop_rx.shared(); + let mut clients = vec![]; loop { let (mut client_sock, client_addr) = select! { // Poll in the order written so we check for close first biased; - _ = &mut close_rx => { + done = &mut close_rx => { info!("server stopping"); + + // Stop all clients and wait for them to finish + // SAFTEY: unwrapping here is fine because we also hold a valid reference + // to the receiver a this point (the shared `client_stop_rx`). + client_stop_tx.send(()).unwrap(); + future::join_all(clients).await; + + // Inform the `Server` impl that we're stopping self.server.stop().await; + + match done { + Ok(done) => { + // Let .stop() know we're done + let _ = done.send(()); + } + Err(_) => { + return Err(io::Error::new(io::ErrorKind::Other, "unexpected server stop")); + } + } + return Ok(()); } conn = listener.accept() => conn?, }; - let close_rx = close_rx.clone(); + // Create a new task to handle the client connection + let client_stop_rx = client_stop_rx.clone(); let server = self.clone(); - tokio::spawn(async move { - server.handle_conn(&mut client_sock, client_addr, close_rx).await; + let client = tokio::spawn(async move { + server + .handle_conn(&mut client_sock, client_addr, client_stop_rx) + .await; }); + clients.push(client); } } /// Stop the server (and disconnect any client) if it's running. - pub async fn stop(self: &Arc) { + /// + /// Returns `Ok` if the server successfully stopped or `Err` otherwise. + pub async fn stop(self: &Arc) -> io::Result<()> { if let Some(close_tx) = self.stop_ch.lock().await.take() { - let _ = close_tx.send(()); + let (done_tx, done_rx) = oneshot::channel(); + + // Signal the server to stop + close_tx + .send(done_tx) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "failed to stop server"))?; + + // and now wait for it to exit + done_rx + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "ungraceful server stop"))?; + + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::NotFound, + "server not running", + )) } } }