Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

100 changes: 51 additions & 49 deletions concurrency/src/threads/gen_server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
//! See examples/name_server for a usage example.
use spawned_rt::threads::{self as rt, mpsc, oneshot};
use spawned_rt::threads::{self as rt, mpsc, oneshot, CancellationToken};
use std::{
fmt::Debug,
panic::{catch_unwind, AssertUnwindSafe},
Expand All @@ -11,20 +11,26 @@ use crate::error::GenServerError;
#[derive(Debug)]
pub struct GenServerHandle<G: GenServer + 'static> {
pub tx: mpsc::Sender<GenServerInMsg<G>>,
cancellation_token: CancellationToken,
}

impl<G: GenServer> Clone for GenServerHandle<G> {
fn clone(&self) -> Self {
Self {
tx: self.tx.clone(),
cancellation_token: self.cancellation_token.clone(),
}
}
}

impl<G: GenServer> GenServerHandle<G> {
pub(crate) fn new(gen_server: G) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let handle = GenServerHandle { tx };
let cancellation_token = CancellationToken::new();
let handle = GenServerHandle {
tx,
cancellation_token,
};
let handle_clone = handle.clone();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn(move || {
Expand Down Expand Up @@ -56,6 +62,10 @@ impl<G: GenServer> GenServerHandle<G> {
.send(GenServerInMsg::Cast { message })
.map_err(|_error| GenServerError::Server)
}

pub fn cancellation_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
}

pub enum GenServerInMsg<G: GenServer> {
Expand All @@ -69,18 +79,18 @@ pub enum GenServerInMsg<G: GenServer> {
}

pub enum CallResponse<G: GenServer> {
Reply(G, G::OutMsg),
Reply(G::OutMsg),
Unused,
Stop(G::OutMsg),
}

pub enum CastResponse<G: GenServer> {
NoReply(G),
pub enum CastResponse {
NoReply,
Unused,
Stop,
}

pub trait GenServer: Send + Sized + Clone {
pub trait GenServer: Send + Sized {
type CallMsg: Clone + Send + Sized;
type CastMsg: Clone + Send + Sized;
type OutMsg: Send + Sized;
Expand All @@ -101,16 +111,16 @@ pub trait GenServer: Send + Sized + Clone {
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
) -> Result<(), GenServerError> {
match self.init(handle) {
Ok(new_state) => {
new_state.main_loop(handle, rx)?;
Ok(())
}
let mut cancellation_token = handle.cancellation_token.clone();
let res = match self.init(handle) {
Ok(new_state) => Ok(new_state.main_loop(handle, rx)?),
Err(err) => {
tracing::error!("Initialization failed: {err:?}");
Err(GenServerError::Initialization)
}
}
};
cancellation_token.cancel();
res
}

/// Initialization function. It's called before main loop. It
Expand All @@ -126,90 +136,82 @@ pub trait GenServer: Send + Sized + Clone {
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
) -> Result<(), GenServerError> {
loop {
let (new_state, cont) = self.receive(handle, rx)?;
if !cont {
if !self.receive(handle, rx)? {
break;
}
self = new_state;
}
tracing::trace!("Stopping GenServer");
Ok(())
}

fn receive(
self,
&mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
) -> Result<(Self, bool), GenServerError> {
) -> Result<bool, GenServerError> {
let message = rx.recv().ok();

// Save current state in case of a rollback
let state_clone = self.clone();

let (keep_running, new_state) = match message {
let keep_running = match message {
Some(GenServerInMsg::Call { sender, message }) => {
let (keep_running, new_state, response) =
match catch_unwind(AssertUnwindSafe(|| self.handle_call(message, handle))) {
Ok(response) => match response {
CallResponse::Reply(new_state, response) => {
(true, new_state, Ok(response))
}
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
CallResponse::Unused => {
tracing::error!("GenServer received unexpected CallMessage");
(false, state_clone, Err(GenServerError::CallMsgUnused))
}
},
Err(error) => {
tracing::trace!(
"Error in callback, reverting state - Error: '{error:?}'"
);
(true, state_clone, Err(GenServerError::Callback))
let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| {
self.handle_call(message, handle)
})) {
Ok(response) => match response {
CallResponse::Reply(response) => (true, Ok(response)),
CallResponse::Stop(response) => (false, Ok(response)),
CallResponse::Unused => {
tracing::error!("GenServer received unexpected CallMessage");
(false, Err(GenServerError::CallMsgUnused))
}
};
},
Err(error) => {
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
(true, Err(GenServerError::Callback))
}
};
// Send response back
if sender.send(response).is_err() {
tracing::trace!("GenServer failed to send response back, client must have died")
};
(keep_running, new_state)
keep_running
}
Some(GenServerInMsg::Cast { message }) => {
match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) {
Ok(response) => match response {
CastResponse::NoReply(new_state) => (true, new_state),
CastResponse::Stop => (false, state_clone),
CastResponse::NoReply => true,
CastResponse::Stop => false,
CastResponse::Unused => {
tracing::error!("GenServer received unexpected CastMessage");
(false, state_clone)
false
}
},
Err(error) => {
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
(true, state_clone)
true
}
}
}
None => {
// Channel has been closed; won't receive further messages. Stop the server.
(false, self)
false
}
};
Ok((new_state, keep_running))
Ok(keep_running)
}

fn handle_call(
self,
&mut self,
_message: Self::CallMsg,
_handle: &GenServerHandle<Self>,
) -> CallResponse<Self> {
CallResponse::Unused
}

fn handle_cast(
self,
&mut self,
_message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
) -> CastResponse<Self> {
) -> CastResponse {
CastResponse::Unused
}
}
30 changes: 17 additions & 13 deletions concurrency/src/threads/timer_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,20 @@ impl GenServer for Repeater {
Ok(self)
}

fn handle_call(self, _message: Self::CallMsg, _handle: &RepeaterHandle) -> CallResponse<Self> {
fn handle_call(
&mut self,
_message: Self::CallMsg,
_handle: &RepeaterHandle,
) -> CallResponse<Self> {
let count = self.count;
CallResponse::Reply(self, RepeaterOutMessage::Count(count))
CallResponse::Reply(RepeaterOutMessage::Count(count))
}

fn handle_cast(
mut self,
&mut self,
message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
) -> CastResponse<Self> {
) -> CastResponse {
match message {
RepeaterCastMessage::Inc => {
self.count += 1;
Expand All @@ -83,7 +87,7 @@ impl GenServer for Repeater {
};
}
};
CastResponse::NoReply(self)
CastResponse::NoReply
}
}

Expand Down Expand Up @@ -156,22 +160,22 @@ impl GenServer for Delayed {
type OutMsg = DelayedOutMessage;
type Error = ();

fn handle_call(self, _message: Self::CallMsg, _handle: &DelayedHandle) -> CallResponse<Self> {
fn handle_call(
&mut self,
_message: Self::CallMsg,
_handle: &DelayedHandle,
) -> CallResponse<Self> {
let count = self.count;
CallResponse::Reply(self, DelayedOutMessage::Count(count))
CallResponse::Reply(DelayedOutMessage::Count(count))
}

fn handle_cast(
mut self,
message: Self::CastMsg,
_handle: &DelayedHandle,
) -> CastResponse<Self> {
fn handle_cast(&mut self, message: Self::CastMsg, _handle: &DelayedHandle) -> CastResponse {
match message {
DelayedCastMessage::Inc => {
self.count += 1;
}
};
CastResponse::NoReply(self)
CastResponse::NoReply
}
}

Expand Down
45 changes: 17 additions & 28 deletions examples/bank_threads/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,53 +61,42 @@ impl GenServer for Bank {
Ok(self)
}

fn handle_call(mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse<Self> {
fn handle_call(&mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse<Self> {
match message.clone() {
Self::CallMsg::New { who } => match self.accounts.get(&who) {
Some(_amount) => {
CallResponse::Reply(self, Err(BankError::AlreadyACustomer { who }))
}
Some(_amount) => CallResponse::Reply(Err(BankError::AlreadyACustomer { who })),
None => {
self.accounts.insert(who.clone(), 0);
CallResponse::Reply(self, Ok(OutMessage::Welcome { who }))
CallResponse::Reply(Ok(OutMessage::Welcome { who }))
}
},
Self::CallMsg::Add { who, amount } => match self.accounts.get(&who) {
Some(current) => {
let new_amount = current + amount;
self.accounts.insert(who.clone(), new_amount);
CallResponse::Reply(
self,
Ok(OutMessage::Balance {
who,
amount: new_amount,
}),
)
CallResponse::Reply(Ok(OutMessage::Balance {
who,
amount: new_amount,
}))
}
None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })),
None => CallResponse::Reply(Err(BankError::NotACustomer { who })),
},
Self::CallMsg::Remove { who, amount } => match self.accounts.get(&who) {
Some(&current) => match current < amount {
true => CallResponse::Reply(
self,
Err(BankError::InsufficientBalance {
who,
amount: current,
}),
),
true => CallResponse::Reply(Err(BankError::InsufficientBalance {
who,
amount: current,
})),
false => {
let new_amount = current - amount;
self.accounts.insert(who.clone(), new_amount);
CallResponse::Reply(
self,
Ok(OutMessage::WidrawOk {
who,
amount: new_amount,
}),
)
CallResponse::Reply(Ok(OutMessage::WidrawOk {
who,
amount: new_amount,
}))
}
},
None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })),
None => CallResponse::Reply(Err(BankError::NotACustomer { who })),
},
Self::CallMsg::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)),
}
Expand Down
8 changes: 2 additions & 6 deletions examples/updater_threads/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ impl GenServer for UpdaterServer {
Ok(self)
}

fn handle_cast(
self,
message: Self::CastMsg,
handle: &UpdateServerHandle,
) -> CastResponse<Self> {
fn handle_cast(&mut self, message: Self::CastMsg, handle: &UpdateServerHandle) -> CastResponse {
match message {
Self::CastMsg::Check => {
send_after(self.periodicity, handle.clone(), InMessage::Check);
Expand All @@ -42,7 +38,7 @@ impl GenServer for UpdaterServer {

tracing::info!("Response: {resp:?}");

CastResponse::NoReply(self)
CastResponse::NoReply
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rt/src/threads/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ where
spawn(f)
}

#[derive(Clone, Default)]
#[derive(Clone, Debug, Default)]
pub struct CancellationToken {
is_cancelled: Arc<AtomicBool>,
}
Expand Down