Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 67 additions & 13 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use futures::FutureExt;
use futures::future::Shared;
use futures::{future, FutureExt};
use log::{debug, error, info, trace};
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::select;
use tokio::sync::{Mutex, oneshot};
use tokio::sync::{oneshot, Mutex};

use crate::rfb::{
ClientInit, ClientMessage, FramebufferUpdate, KeyEvent, PixelFormat, ProtoVersion,
Expand Down Expand Up @@ -74,7 +74,7 @@ pub struct VncServer<S: Server> {
pub server: S,

/// One-shot channel used to signal that the server should shut down.
stop_ch: Mutex<Option<oneshot::Sender<()>>>,
stop_ch: Mutex<Option<oneshot::Sender<oneshot::Sender<()>>>>,
}

#[async_trait]
Expand Down Expand Up @@ -183,7 +183,12 @@ impl<S: Server> VncServer<S> {
Ok(())
}

async fn handle_conn(&self, s: &mut TcpStream, addr: SocketAddr, mut close_ch: Shared<oneshot::Receiver<()>>) {
async fn handle_conn(
&self,
s: &mut TcpStream,
addr: SocketAddr,
mut close_ch: Shared<oneshot::Receiver<()>>,
) {
info!("[{:?}] new connection", addr);

if let Err(e) = self.rfb_handshake(s, addr).await {
Expand Down Expand Up @@ -288,36 +293,85 @@ impl<S: Server> VncServer<S> {
let listener = TcpListener::bind(self.config.addr).await?;

// Create a channel to signal the server to stop.
let (close_tx, close_rx) = oneshot::channel();
assert!(self.stop_ch.lock().await.replace(close_tx).is_none(), "server already started");
let mut close_rx = close_rx.shared();
let (close_tx, mut close_rx) = oneshot::channel();
assert!(
self.stop_ch.lock().await.replace(close_tx).is_none(),
"server already started"
);

// And a pair used to stop clients
let (client_stop_tx, client_stop_rx) = oneshot::channel();
let client_stop_rx = client_stop_rx.shared();
let mut clients = vec![];

loop {
let (mut client_sock, client_addr) = select! {
// Poll in the order written so we check for close first
biased;

_ = &mut close_rx => {
done = &mut close_rx => {
info!("server stopping");

// Stop all clients and wait for them to finish
// SAFTEY: unwrapping here is fine because we also hold a valid reference
// to the receiver a this point (the shared `client_stop_rx`).
client_stop_tx.send(()).unwrap();
future::join_all(clients).await;

// Inform the `Server` impl that we're stopping
self.server.stop().await;

match done {
Ok(done) => {
// Let .stop() know we're done
let _ = done.send(());
}
Err(_) => {
return Err(io::Error::new(io::ErrorKind::Other, "unexpected server stop"));
}
}

return Ok(());
}

conn = listener.accept() => conn?,
};

let close_rx = close_rx.clone();
// Create a new task to handle the client connection
let client_stop_rx = client_stop_rx.clone();
let server = self.clone();
tokio::spawn(async move {
server.handle_conn(&mut client_sock, client_addr, close_rx).await;
let client = tokio::spawn(async move {
server
.handle_conn(&mut client_sock, client_addr, client_stop_rx)
.await;
});
clients.push(client);
}
}

/// Stop the server (and disconnect any client) if it's running.
pub async fn stop(self: &Arc<Self>) {
///
/// Returns `Ok` if the server successfully stopped or `Err` otherwise.
pub async fn stop(self: &Arc<Self>) -> io::Result<()> {
if let Some(close_tx) = self.stop_ch.lock().await.take() {
let _ = close_tx.send(());
let (done_tx, done_rx) = oneshot::channel();

// Signal the server to stop
close_tx
.send(done_tx)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "failed to stop server"))?;

// and now wait for it to exit
done_rx
.await
.map_err(|_| io::Error::new(io::ErrorKind::Other, "ungraceful server stop"))?;

Ok(())
} else {
Err(io::Error::new(
io::ErrorKind::NotFound,
"server not running",
))
}
}
}