diff --git a/Cargo.lock b/Cargo.lock index efef144..3c37b7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1197,7 +1197,7 @@ dependencies = [ [[package]] name = "spawned-concurrency" -version = "0.4.2" +version = "0.4.3" dependencies = [ "futures", "pin-project-lite", @@ -1210,7 +1210,7 @@ dependencies = [ [[package]] name = "spawned-rt" -version = "0.4.2" +version = "0.4.3" dependencies = [ "crossbeam", "tokio", diff --git a/concurrency/src/threads/gen_server.rs b/concurrency/src/threads/gen_server.rs index ee09b17..0237b85 100644 --- a/concurrency/src/threads/gen_server.rs +++ b/concurrency/src/threads/gen_server.rs @@ -1,6 +1,6 @@ //! GenServer trait and structs to create an abstraction similar to Erlang gen_server. //! See examples/name_server for a usage example. -use spawned_rt::threads::{self as rt, mpsc, oneshot}; +use spawned_rt::threads::{self as rt, mpsc, oneshot, CancellationToken}; use std::{ fmt::Debug, panic::{catch_unwind, AssertUnwindSafe}, @@ -11,12 +11,14 @@ use crate::error::GenServerError; #[derive(Debug)] pub struct GenServerHandle { pub tx: mpsc::Sender>, + cancellation_token: CancellationToken, } impl Clone for GenServerHandle { fn clone(&self) -> Self { Self { tx: self.tx.clone(), + cancellation_token: self.cancellation_token.clone(), } } } @@ -24,7 +26,11 @@ impl Clone for GenServerHandle { impl GenServerHandle { pub(crate) fn new(gen_server: G) -> Self { let (tx, mut rx) = mpsc::channel::>(); - let handle = GenServerHandle { tx }; + let cancellation_token = CancellationToken::new(); + let handle = GenServerHandle { + tx, + cancellation_token, + }; let handle_clone = handle.clone(); // Ignore the JoinHandle for now. Maybe we'll use it in the future let _join_handle = rt::spawn(move || { @@ -56,6 +62,10 @@ impl GenServerHandle { .send(GenServerInMsg::Cast { message }) .map_err(|_error| GenServerError::Server) } + + pub fn cancellation_token(&self) -> CancellationToken { + self.cancellation_token.clone() + } } pub enum GenServerInMsg { @@ -69,18 +79,18 @@ pub enum GenServerInMsg { } pub enum CallResponse { - Reply(G, G::OutMsg), + Reply(G::OutMsg), Unused, Stop(G::OutMsg), } -pub enum CastResponse { - NoReply(G), +pub enum CastResponse { + NoReply, Unused, Stop, } -pub trait GenServer: Send + Sized + Clone { +pub trait GenServer: Send + Sized { type CallMsg: Clone + Send + Sized; type CastMsg: Clone + Send + Sized; type OutMsg: Send + Sized; @@ -101,16 +111,16 @@ pub trait GenServer: Send + Sized + Clone { handle: &GenServerHandle, rx: &mut mpsc::Receiver>, ) -> Result<(), GenServerError> { - match self.init(handle) { - Ok(new_state) => { - new_state.main_loop(handle, rx)?; - Ok(()) - } + let mut cancellation_token = handle.cancellation_token.clone(); + let res = match self.init(handle) { + Ok(new_state) => Ok(new_state.main_loop(handle, rx)?), Err(err) => { tracing::error!("Initialization failed: {err:?}"); Err(GenServerError::Initialization) } - } + }; + cancellation_token.cancel(); + res } /// Initialization function. It's called before main loop. It @@ -126,79 +136,71 @@ pub trait GenServer: Send + Sized + Clone { rx: &mut mpsc::Receiver>, ) -> Result<(), GenServerError> { loop { - let (new_state, cont) = self.receive(handle, rx)?; - if !cont { + if !self.receive(handle, rx)? { break; } - self = new_state; } tracing::trace!("Stopping GenServer"); Ok(()) } fn receive( - self, + &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - ) -> Result<(Self, bool), GenServerError> { + ) -> Result { let message = rx.recv().ok(); - // Save current state in case of a rollback - let state_clone = self.clone(); - - let (keep_running, new_state) = match message { + let keep_running = match message { Some(GenServerInMsg::Call { sender, message }) => { - let (keep_running, new_state, response) = - match catch_unwind(AssertUnwindSafe(|| self.handle_call(message, handle))) { - Ok(response) => match response { - CallResponse::Reply(new_state, response) => { - (true, new_state, Ok(response)) - } - CallResponse::Stop(response) => (false, state_clone, Ok(response)), - CallResponse::Unused => { - tracing::error!("GenServer received unexpected CallMessage"); - (false, state_clone, Err(GenServerError::CallMsgUnused)) - } - }, - Err(error) => { - tracing::trace!( - "Error in callback, reverting state - Error: '{error:?}'" - ); - (true, state_clone, Err(GenServerError::Callback)) + let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| { + self.handle_call(message, handle) + })) { + Ok(response) => match response { + CallResponse::Reply(response) => (true, Ok(response)), + CallResponse::Stop(response) => (false, Ok(response)), + CallResponse::Unused => { + tracing::error!("GenServer received unexpected CallMessage"); + (false, Err(GenServerError::CallMsgUnused)) } - }; + }, + Err(error) => { + tracing::trace!("Error in callback, reverting state - Error: '{error:?}'"); + (true, Err(GenServerError::Callback)) + } + }; // Send response back if sender.send(response).is_err() { tracing::trace!("GenServer failed to send response back, client must have died") }; - (keep_running, new_state) + keep_running } Some(GenServerInMsg::Cast { message }) => { match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) { Ok(response) => match response { - CastResponse::NoReply(new_state) => (true, new_state), - CastResponse::Stop => (false, state_clone), + CastResponse::NoReply => true, + CastResponse::Stop => false, CastResponse::Unused => { tracing::error!("GenServer received unexpected CastMessage"); - (false, state_clone) + false } }, Err(error) => { tracing::trace!("Error in callback, reverting state - Error: '{error:?}'"); - (true, state_clone) + true } } } None => { // Channel has been closed; won't receive further messages. Stop the server. - (false, self) + false } }; - Ok((new_state, keep_running)) + Ok(keep_running) } fn handle_call( - self, + &mut self, _message: Self::CallMsg, _handle: &GenServerHandle, ) -> CallResponse { @@ -206,10 +208,10 @@ pub trait GenServer: Send + Sized + Clone { } fn handle_cast( - self, + &mut self, _message: Self::CastMsg, _handle: &GenServerHandle, - ) -> CastResponse { + ) -> CastResponse { CastResponse::Unused } } diff --git a/concurrency/src/threads/timer_tests.rs b/concurrency/src/threads/timer_tests.rs index 7c144d8..446b147 100644 --- a/concurrency/src/threads/timer_tests.rs +++ b/concurrency/src/threads/timer_tests.rs @@ -63,16 +63,20 @@ impl GenServer for Repeater { Ok(self) } - fn handle_call(self, _message: Self::CallMsg, _handle: &RepeaterHandle) -> CallResponse { + fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &RepeaterHandle, + ) -> CallResponse { let count = self.count; - CallResponse::Reply(self, RepeaterOutMessage::Count(count)) + CallResponse::Reply(RepeaterOutMessage::Count(count)) } fn handle_cast( - mut self, + &mut self, message: Self::CastMsg, _handle: &GenServerHandle, - ) -> CastResponse { + ) -> CastResponse { match message { RepeaterCastMessage::Inc => { self.count += 1; @@ -83,7 +87,7 @@ impl GenServer for Repeater { }; } }; - CastResponse::NoReply(self) + CastResponse::NoReply } } @@ -156,22 +160,22 @@ impl GenServer for Delayed { type OutMsg = DelayedOutMessage; type Error = (); - fn handle_call(self, _message: Self::CallMsg, _handle: &DelayedHandle) -> CallResponse { + fn handle_call( + &mut self, + _message: Self::CallMsg, + _handle: &DelayedHandle, + ) -> CallResponse { let count = self.count; - CallResponse::Reply(self, DelayedOutMessage::Count(count)) + CallResponse::Reply(DelayedOutMessage::Count(count)) } - fn handle_cast( - mut self, - message: Self::CastMsg, - _handle: &DelayedHandle, - ) -> CastResponse { + fn handle_cast(&mut self, message: Self::CastMsg, _handle: &DelayedHandle) -> CastResponse { match message { DelayedCastMessage::Inc => { self.count += 1; } }; - CastResponse::NoReply(self) + CastResponse::NoReply } } diff --git a/examples/bank_threads/src/server.rs b/examples/bank_threads/src/server.rs index 5419708..baeb71a 100644 --- a/examples/bank_threads/src/server.rs +++ b/examples/bank_threads/src/server.rs @@ -61,53 +61,42 @@ impl GenServer for Bank { Ok(self) } - fn handle_call(mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse { + fn handle_call(&mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse { match message.clone() { Self::CallMsg::New { who } => match self.accounts.get(&who) { - Some(_amount) => { - CallResponse::Reply(self, Err(BankError::AlreadyACustomer { who })) - } + Some(_amount) => CallResponse::Reply(Err(BankError::AlreadyACustomer { who })), None => { self.accounts.insert(who.clone(), 0); - CallResponse::Reply(self, Ok(OutMessage::Welcome { who })) + CallResponse::Reply(Ok(OutMessage::Welcome { who })) } }, Self::CallMsg::Add { who, amount } => match self.accounts.get(&who) { Some(current) => { let new_amount = current + amount; self.accounts.insert(who.clone(), new_amount); - CallResponse::Reply( - self, - Ok(OutMessage::Balance { - who, - amount: new_amount, - }), - ) + CallResponse::Reply(Ok(OutMessage::Balance { + who, + amount: new_amount, + })) } - None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(Err(BankError::NotACustomer { who })), }, Self::CallMsg::Remove { who, amount } => match self.accounts.get(&who) { Some(¤t) => match current < amount { - true => CallResponse::Reply( - self, - Err(BankError::InsufficientBalance { - who, - amount: current, - }), - ), + true => CallResponse::Reply(Err(BankError::InsufficientBalance { + who, + amount: current, + })), false => { let new_amount = current - amount; self.accounts.insert(who.clone(), new_amount); - CallResponse::Reply( - self, - Ok(OutMessage::WidrawOk { - who, - amount: new_amount, - }), - ) + CallResponse::Reply(Ok(OutMessage::WidrawOk { + who, + amount: new_amount, + })) } }, - None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(Err(BankError::NotACustomer { who })), }, Self::CallMsg::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)), } diff --git a/examples/updater_threads/src/server.rs b/examples/updater_threads/src/server.rs index 0ed9173..23eafc1 100644 --- a/examples/updater_threads/src/server.rs +++ b/examples/updater_threads/src/server.rs @@ -28,11 +28,7 @@ impl GenServer for UpdaterServer { Ok(self) } - fn handle_cast( - self, - message: Self::CastMsg, - handle: &UpdateServerHandle, - ) -> CastResponse { + fn handle_cast(&mut self, message: Self::CastMsg, handle: &UpdateServerHandle) -> CastResponse { match message { Self::CastMsg::Check => { send_after(self.periodicity, handle.clone(), InMessage::Check); @@ -42,7 +38,7 @@ impl GenServer for UpdaterServer { tracing::info!("Response: {resp:?}"); - CastResponse::NoReply(self) + CastResponse::NoReply } } } diff --git a/rt/src/threads/mod.rs b/rt/src/threads/mod.rs index adcea5f..3c71067 100644 --- a/rt/src/threads/mod.rs +++ b/rt/src/threads/mod.rs @@ -34,7 +34,7 @@ where spawn(f) } -#[derive(Clone, Default)] +#[derive(Clone, Debug, Default)] pub struct CancellationToken { is_cancelled: Arc, }