From 10521801d512bc7572b78fe1c535ff80304bef3a Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Sun, 9 Jun 2024 14:46:22 +0100 Subject: [PATCH 01/10] Initial Commit --- .github/workflows/rust.yml | 4 +- src/async_runtime/exec.rs | 87 +++++++++++++++++ src/async_runtime/executor.rs | 141 ---------------------------- src/async_runtime/mod.rs | 5 +- src/async_runtime/task_queue.rs | 34 ------- src/discarding_spawn_group.rs | 11 +-- src/err_spawn_group.rs | 13 +-- src/lib.rs | 42 +++++++-- src/shared/initializible.rs | 3 - src/shared/mod.rs | 1 - src/shared/runtime.rs | 50 ++-------- src/spawn_group.rs | 13 +-- src/threadpool_impl/channel.rs | 87 +++++++++++++++++ src/threadpool_impl/index.rs | 31 ++++++ src/threadpool_impl/iteratorimpl.rs | 12 --- src/threadpool_impl/mod.rs | 10 +- src/threadpool_impl/queue.rs | 42 --------- src/threadpool_impl/queueops.rs | 5 - src/threadpool_impl/thread.rs | 39 +++++--- src/threadpool_impl/threadpool.rs | 120 +++++------------------ 20 files changed, 311 insertions(+), 439 deletions(-) create mode 100644 src/async_runtime/exec.rs delete mode 100755 src/async_runtime/executor.rs delete mode 100755 src/async_runtime/task_queue.rs delete mode 100755 src/shared/initializible.rs create mode 100644 src/threadpool_impl/channel.rs create mode 100644 src/threadpool_impl/index.rs delete mode 100644 src/threadpool_impl/iteratorimpl.rs delete mode 100644 src/threadpool_impl/queue.rs delete mode 100644 src/threadpool_impl/queueops.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 12c6eee..bcb8cac 100755 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,9 +2,9 @@ name: Rust on: push: - branches: [ "master" ] + branches: [ "*" ] pull_request: - branches: [ "master" ] + branches: [ "*" ] env: CARGO_TERM_COLOR: always diff --git a/src/async_runtime/exec.rs b/src/async_runtime/exec.rs new file mode 100644 index 0000000..2dad39b --- /dev/null +++ b/src/async_runtime/exec.rs @@ -0,0 +1,87 @@ +use crate::{ + pin_future, + threadpool_impl::{Channel, ThreadPool}, +}; + +use super::{notifier::Notifier, task::Task}; + +use cooked_waker::IntoWaker; + +use std::{ + future::Future, + sync::Arc, + task::{Context, Poll, Waker}, +}; + +#[derive(Clone)] +pub struct Executor { + pool: Arc, + queue: Channel, +} + +impl Executor { + pub(crate) fn new(count: usize) -> Self { + let result: Executor = Self { + pool: Arc::new(ThreadPool::new(count)), + queue: Channel::new(), + }; + let result_clone = result.clone(); + std::thread::spawn(move || { + result_clone.run(); + }); + result + } +} + +impl Executor { + pub(crate) fn submit(&self, task: Task) + where + Task: FnOnce() + Send + 'static, + { + self.pool.submit(task); + } + + pub(crate) fn spawn(&self, task: F) -> Task + where + F: Future + Send + 'static, + { + let task = Task::new(task); + self.queue.enqueue(task.clone()); + task + } + + pub(crate) fn cancel(&self) { + self.pool.clear(); + self.queue.clear(); + self.poll_all(); + self.pool.clear(); + self.queue.clear(); + } + + fn run(&self) { + while let Some(task) = self.queue.dequeue() { + let queue = self.queue.clone(); + self.submit(move || { + let waker: Waker = Arc::new(Notifier::default()).into_waker(); + pin_future!(task); + let mut cx: Context<'_> = Context::from_waker(&waker); + match task.as_mut().poll(&mut cx) { + Poll::Ready(()) => (), + Poll::Pending => { + queue.enqueue(task.clone()); + } + } + }); + } + } + + pub(crate) fn poll_all(&self) { + self.pool.wait_for_all(); + } + + pub(crate) fn end(&mut self) { + self.queue.clear(); + self.queue.close(); + //self.pool.drop_pool(); + } +} diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs deleted file mode 100755 index 2cb1689..0000000 --- a/src/async_runtime/executor.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::{pin_future, threadpool_impl::ThreadPool}; - -use super::{notifier::Notifier, task::Task, task_queue::TaskQueue}; - -use cooked_waker::IntoWaker; -use parking_lot::{lock_api::MutexGuard, Condvar, Mutex, RawMutex}; - -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, -}; - -#[derive(Clone)] -pub struct Executor { - cancel: Arc, - lock_pair: Arc<(Mutex, Condvar)>, - pool: Arc, - queue: TaskQueue, - started: Arc, -} - -impl Default for Executor { - fn default() -> Self { - let result: Executor = Self { - cancel: Arc::new(AtomicBool::new(false)), - lock_pair: Arc::new((Mutex::new(false), Condvar::new())), - pool: Arc::new(ThreadPool::default()), - queue: TaskQueue::default(), - started: Arc::new(AtomicBool::new(false)), - }; - result.start(); - result - } -} - -impl Executor { - pub(crate) fn new(count: usize) -> Self { - let result: Executor = Self { - cancel: Arc::new(AtomicBool::new(false)), - lock_pair: Arc::new((Mutex::new(false), Condvar::new())), - pool: Arc::new(ThreadPool::new(count)), - queue: TaskQueue::default(), - started: Arc::new(AtomicBool::new(false)), - }; - result.start(); - result - } -} - -impl Executor { - fn started(&self) -> bool { - self.started.load(Ordering::Acquire) - } - - fn update(&self, val: bool) { - self.started.store(val, Ordering::Release); - } -} - -impl Executor { - pub(crate) fn submit(&self, task: Task) - where - Task: FnOnce() + Send + 'static, - { - self.pool.submit(task); - } - - pub(crate) fn spawn(&self, task: Fut) -> Task - where - Fut: Future + 'static + Send, - { - let task: Task = Task::new(task); - self.queue.push(&task); - - if !self.started() { - self.notify(); - } - task - } - - fn notify(&self) { - self.update(true); - let pair2: Arc<(Mutex, Condvar)> = self.lock_pair.clone(); - std::thread::spawn(move || { - let (lock, cvar) = &*pair2; - let mut started: MutexGuard<'_, RawMutex, bool> = lock.lock(); - *started = true; - cvar.notify_one(); - }); - } - - pub(crate) fn cancel(&self) { - self.cancel.store(true, Ordering::Release); - *self.lock_pair.0.lock() = false; - self.update(false); - self.queue.drain_all(); - self.cancel.store(false, Ordering::Release); - } - - pub(crate) fn run(&self) { - while !self.cancel.load(Ordering::Acquire) { - self.queue.clone().for_each(|task| { - let queue: TaskQueue = self.queue.clone(); - self.submit(move || { - let waker: Waker = Arc::new(Notifier::default()).into_waker(); - pin_future!(task); - let mut cx: Context<'_> = Context::from_waker(&waker); - match task.as_mut().poll(&mut cx) { - Poll::Ready(()) => (), - Poll::Pending => { - queue.push(&task); - } - } - }); - }); - } - self.poll_all(); - self.queue.drain_all(); - } - - pub(crate) fn poll_all(&self) { - self.pool.wait_for_all(); - } - - pub(crate) fn start(&self) { - let lock_pair: Arc<(Mutex, Condvar)> = self.lock_pair.clone(); - let executor: Executor = self.clone(); - std::thread::spawn(move || { - let (lock, cvar) = &*lock_pair; - let mut started: MutexGuard<'_, RawMutex, bool> = lock.lock(); - while !*started { - cvar.wait(&mut started); - } - executor.run(); - }); - } -} diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs index 023ec5a..e4e8d24 100755 --- a/src/async_runtime/mod.rs +++ b/src/async_runtime/mod.rs @@ -1,5 +1,4 @@ -pub(crate) mod executor; +pub(crate) mod exec; pub(crate) mod notifier; +mod pin_macro; pub(crate) mod task; -mod task_queue; -mod pin_macro; \ No newline at end of file diff --git a/src/async_runtime/task_queue.rs b/src/async_runtime/task_queue.rs deleted file mode 100755 index f49f16d..0000000 --- a/src/async_runtime/task_queue.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::task::Task; -use parking_lot::Mutex; -use std::{collections::VecDeque, iter::Iterator, sync::Arc}; - -#[derive(Clone, Default)] -pub struct TaskQueue { - buffer: Arc>>, -} - -impl TaskQueue { - pub(crate) fn push(&self, task: &Task) { - self.buffer.lock().push_back(task.clone()); - } -} - -impl TaskQueue { - pub(crate) fn drain_all(&self) { - self.buffer.lock().clear(); - } -} - -impl Iterator for TaskQueue { - type Item = Task; - - fn next(&mut self) -> Option { - let Some(task) = self.buffer.lock().pop_front() else { - return None; - }; - if !task.is_completed() { - return Some(task); - } - None - } -} diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 4b0a7b6..3b564c7 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -1,5 +1,5 @@ use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, + priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, }; use std::future::Future; @@ -133,12 +133,3 @@ impl Shared for DiscardingSpawnGroup { } } -impl Initializible for DiscardingSpawnGroup { - fn init() -> Self { - DiscardingSpawnGroup { - is_cancelled: false, - runtime: RuntimeEngine::init(), - wait_at_drop: true, - } - } -} diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index 2c96140..c4df1f0 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -1,5 +1,5 @@ use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, + priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, }; use async_trait::async_trait; @@ -205,17 +205,6 @@ impl Drop for ErrSpawnGroup Initializible for ErrSpawnGroup { - fn init() -> Self { - ErrSpawnGroup:: { - count: Arc::new(AtomicUsize::new(0)), - is_cancelled: false, - runtime: RuntimeEngine::init(), - wait_at_drop: true, - } - } -} - impl Shared for ErrSpawnGroup { diff --git a/src/lib.rs b/src/lib.rs index a7de5e0..01f95a2 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -182,7 +182,6 @@ pub use discarding_spawn_group::DiscardingSpawnGroup; pub use err_spawn_group::ErrSpawnGroup; pub use executors::block_on; pub use meta_types::GetType; -use shared::initializible::Initializible; pub use shared::priority::Priority; pub use sleeper::sleep; pub use spawn_group::SpawnGroup; @@ -190,6 +189,7 @@ pub use yield_now::yield_now; use std::future::Future; use std::marker::PhantomData; +use std::thread::available_parallelism; /// Starts a scoped closure that takes a mutable ``SpawnGroup`` instance as an argument which can execute any number of child tasks which its result values are of the generic ``ResultType`` type. /// @@ -244,8 +244,14 @@ where Fut: Future + Send + 'static, ResultType: Send + 'static, { + let count: usize; + if let Ok(thread_count) = available_parallelism() { + count = thread_count.get(); + } else { + count = 1; + } _ = of_type; - let task_group = spawn_group::SpawnGroup::::init(); + let task_group = spawn_group::SpawnGroup::::new(count); body(task_group).await } @@ -298,7 +304,13 @@ where Fut: Future + Send + 'static, ResultType: Send + 'static, { - let task_group = spawn_group::SpawnGroup::::init(); + let count: usize; + if let Ok(thread_count) = available_parallelism() { + count = thread_count.get(); + } else { + count = 1; + } + let task_group = spawn_group::SpawnGroup::::new(count); body(task_group).await } @@ -402,8 +414,14 @@ where Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, ResultType: Send + 'static, { + let count: usize; + if let Ok(thread_count) = available_parallelism() { + count = thread_count.get(); + } else { + count = 1; + } _ = (of_type, error_type); - let task_group = err_spawn_group::ErrSpawnGroup::::init(); + let task_group = err_spawn_group::ErrSpawnGroup::::new(count); body(task_group).await } @@ -503,7 +521,13 @@ where Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, ResultType: Send + 'static, { - let task_group = err_spawn_group::ErrSpawnGroup::::init(); + let count: usize; + if let Ok(thread_count) = available_parallelism() { + count = thread_count.get(); + } else { + count = 1; + } + let task_group = err_spawn_group::ErrSpawnGroup::::new(count); body(task_group).await } @@ -549,6 +573,12 @@ where Fut: Future, Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut + Send + 'static, { - let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::init(); + let count: usize; + if let Ok(thread_count) = available_parallelism() { + count = thread_count.get(); + } else { + count = 1; + } + let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::new(count); body(discarding_tg).await } diff --git a/src/shared/initializible.rs b/src/shared/initializible.rs deleted file mode 100755 index 4dbc663..0000000 --- a/src/shared/initializible.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub trait Initializible { - fn init() -> Self; -} diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 41384ef..0cc077a 100755 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod initializible; pub(crate) mod priority; pub(crate) mod runtime; pub(crate) mod sharedfuncs; diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index 37598f6..a4d8f31 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,17 +1,12 @@ +use parking_lot::Mutex; + use crate::{ - async_runtime::{executor::Executor, task::Task}, + async_runtime::{exec::Executor, task::Task}, async_stream::AsyncStream, executors::block_task, - shared::{initializible::Initializible, priority::Priority}, -}; -use parking_lot::Mutex; -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + shared::priority::Priority, }; +use std::{future::Future, sync::Arc}; type TaskQueue = Arc>>; @@ -19,18 +14,6 @@ pub struct RuntimeEngine { tasks: TaskQueue, runtime: Executor, stream: AsyncStream, - wait_flag: Arc, -} - -impl Initializible for RuntimeEngine { - fn init() -> Self { - Self { - tasks: Arc::new(Mutex::new(vec![])), - stream: AsyncStream::new(), - runtime: Executor::default(), - wait_flag: Arc::new(AtomicBool::new(false)), - } - } } impl RuntimeEngine { @@ -39,14 +22,12 @@ impl RuntimeEngine { tasks: Arc::new(Mutex::new(vec![])), stream: AsyncStream::new(), runtime: Executor::new(count), - wait_flag: Arc::new(AtomicBool::new(false)), } } } impl RuntimeEngine { pub(crate) fn cancel(&mut self) { - self.store(true); self.runtime.cancel(); self.tasks.lock().clear(); self.stream.cancel_tasks(); @@ -62,6 +43,7 @@ impl RuntimeEngine { pub(crate) fn end(&mut self) { self.runtime.cancel(); self.tasks.lock().clear(); + self.runtime.end() } } @@ -69,9 +51,9 @@ impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { self.poll(); self.runtime.cancel(); - self.tasks.lock().sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - self.store(true); - while let Some((_, handle)) = self.tasks.lock().pop() { + let mut lock = self.tasks.lock(); + lock.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); + while let Some((_, handle)) = lock.pop() { self.runtime.submit(move || { block_task(handle); }); @@ -80,25 +62,11 @@ impl RuntimeEngine { } } -impl RuntimeEngine { - pub(crate) fn load(&self) -> bool { - self.wait_flag.load(Ordering::Acquire) - } - - pub(crate) fn store(&self, val: bool) { - self.wait_flag.store(val, Ordering::Release); - } -} - impl RuntimeEngine { pub(crate) fn write_task(&self, priority: Priority, task: F) where F: Future + Send + 'static, { - if self.load() { - self.runtime.start(); - self.store(false); - } self.stream.increment(); let mut stream: AsyncStream = self.stream(); let runtime = self.runtime.clone(); diff --git a/src/spawn_group.rs b/src/spawn_group.rs index 4c42947..9611705 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -1,5 +1,5 @@ use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, + priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, }; use async_trait::async_trait; @@ -201,17 +201,6 @@ impl Drop for SpawnGroup { } } -impl Initializible for SpawnGroup { - fn init() -> Self { - SpawnGroup { - runtime: RuntimeEngine::init(), - is_cancelled: false, - count: Arc::new(AtomicUsize::new(0)), - wait_at_drop: true, - } - } -} - impl Shared for SpawnGroup { type Result = ValueType; diff --git a/src/threadpool_impl/channel.rs b/src/threadpool_impl/channel.rs new file mode 100644 index 0000000..9a75186 --- /dev/null +++ b/src/threadpool_impl/channel.rs @@ -0,0 +1,87 @@ +use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Condvar, Mutex, + }, +}; + +#[derive(Default)] +pub struct Channel { + pair: Arc<(Mutex>, Condvar)>, + closed: Arc, +} + +impl Channel { + pub fn enqueue(&self, value: ItemType) -> bool { + if self.closed.load(Ordering::Relaxed) { + return false; + } + if let Ok(mut lock) = self.pair.0.lock() { + lock.push_back(value); + self.pair.1.notify_one(); + return true; + } + false + } +} + +impl Channel { + pub fn new() -> Self { + Self { + pair: Arc::new((Mutex::new(VecDeque::new()), Condvar::new())), + closed: Arc::new(AtomicBool::new(false)), + } + } +} + +impl Clone for Channel { + fn clone(&self) -> Self { + Self { + pair: self.pair.clone(), + closed: self.closed.clone(), + } + } +} + +impl Channel { + pub fn dequeue(&self) -> Option { + if self.closed.load(Ordering::Relaxed) { + return None; + } + let Ok(mut lock) = self.pair.0.lock() else { + return None; + }; + while lock.is_empty() { + if self.closed.load(Ordering::Relaxed) { + return None; + } + lock = self.pair.1.wait(lock).unwrap(); + } + lock.pop_front() + } +} + +impl Channel { + /// + pub fn close(&self) { + if let Ok(_lock) = self.pair.0.lock() { + self.closed.store(true, Ordering::Relaxed); + self.pair.1.notify_all(); + } + } + + pub fn clear(&self) { + if let Ok(mut lock) = self.pair.0.lock() { + lock.clear(); + } + } +} + +impl Iterator for Channel { + type Item = ItemType; + + fn next(&mut self) -> Option { + self.dequeue() + } +} diff --git a/src/threadpool_impl/index.rs b/src/threadpool_impl/index.rs new file mode 100644 index 0000000..f27d31d --- /dev/null +++ b/src/threadpool_impl/index.rs @@ -0,0 +1,31 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +#[derive(Clone)] +pub(crate) struct Indexer { + index: Arc, + last_index: usize, +} + +impl Indexer { + pub(crate) fn new(count: usize) -> Self { + Indexer { + index: Arc::new(AtomicUsize::new(0)), + last_index: count - 1, + } + } +} + +impl Indexer { + pub(crate) fn next(&self) -> usize { + if let Ok(_) = + self.index + .compare_exchange(self.last_index, 0, Ordering::SeqCst, Ordering::SeqCst) + { + return 0; + } + self.index.fetch_add(1, Ordering::SeqCst) + } +} diff --git a/src/threadpool_impl/iteratorimpl.rs b/src/threadpool_impl/iteratorimpl.rs deleted file mode 100644 index b2ee6a4..0000000 --- a/src/threadpool_impl/iteratorimpl.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::{QueueOperation, ThreadSafeQueue, Func}; - -impl Iterator for ThreadSafeQueue> { - type Item = QueueOperation; - - fn next(&mut self) -> Option { - let Some(value) = self.dequeue() else { - return Some(QueueOperation::NotYet); - }; - Some(value) - } -} diff --git a/src/threadpool_impl/mod.rs b/src/threadpool_impl/mod.rs index de45d18..5730af2 100644 --- a/src/threadpool_impl/mod.rs +++ b/src/threadpool_impl/mod.rs @@ -1,11 +1,9 @@ -mod iteratorimpl; -mod queue; -mod queueops; -mod threadpool; +mod channel; +mod index; mod thread; +mod threadpool; pub(crate) type Func = dyn FnOnce() + Send; -pub(crate) use queue::ThreadSafeQueue; -pub(crate) use queueops::QueueOperation; +pub(crate) use channel::Channel; pub(crate) use threadpool::ThreadPool; diff --git a/src/threadpool_impl/queue.rs b/src/threadpool_impl/queue.rs deleted file mode 100644 index 28dd70a..0000000 --- a/src/threadpool_impl/queue.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; - -#[derive(Default)] -pub(crate) struct ThreadSafeQueue { - buffer: Arc>>, -} - -impl ThreadSafeQueue { - pub fn enqueue(&self, value: ItemType) { - if let Ok(mut lock) = self.buffer.lock() { - lock.push_back(value); - } - } -} - -impl ThreadSafeQueue { - pub fn new() -> Self { - Self { - buffer: Arc::new(Mutex::new(VecDeque::new())), - } - } -} - -impl Clone for ThreadSafeQueue { - fn clone(&self) -> Self { - Self { - buffer: self.buffer.clone(), - } - } -} - -impl ThreadSafeQueue { - pub fn dequeue(&self) -> Option { - let Ok(mut buffer_lock) = self.buffer.lock() else { - return None; - }; - buffer_lock.pop_front() - } -} diff --git a/src/threadpool_impl/queueops.rs b/src/threadpool_impl/queueops.rs deleted file mode 100644 index d2a4ecd..0000000 --- a/src/threadpool_impl/queueops.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) enum QueueOperation { - Ready(Box), - NotYet, - Wait, -} \ No newline at end of file diff --git a/src/threadpool_impl/thread.rs b/src/threadpool_impl/thread.rs index e972ec1..a729136 100644 --- a/src/threadpool_impl/thread.rs +++ b/src/threadpool_impl/thread.rs @@ -1,23 +1,40 @@ -use std::thread; +use std::thread::{spawn, JoinHandle}; + +use super::{Channel, Func}; pub(crate) struct UniqueThread { - handle: thread::JoinHandle<()>, + channel: Channel>, + handle: JoinHandle<()>, } impl UniqueThread { - pub(crate) fn new(name: String, task: Task) -> Self { - let handle = thread::Builder::new() - .name(name) - .spawn(move || { - task(); - }) - .unwrap(); - UniqueThread { handle } + pub(crate) fn new() -> Self { + let channel: Channel> = Channel::new(); + let chan = channel.clone(); + let handle = spawn(move || { + while let Some(ops) = chan.dequeue() { + ops() + } + }); + UniqueThread { channel, handle } } } impl UniqueThread { pub(crate) fn join(self) { - _ = self.handle.join(); + self.channel.close(); + self.channel.clear(); + _ = self.handle.join().unwrap(); + } + + pub(crate) fn submit(&self, task: Task) + where + Task: FnOnce() + Send + 'static, + { + self.channel.enqueue(Box::new(task)); + } + + pub(crate) fn clear(&self) { + self.channel.clear(); } } diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index dda17e9..d15d6c6 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,62 +1,24 @@ -use std::{ - backtrace, panic, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Barrier, - }, - thread, -}; +use std::sync::{Arc, Barrier}; -use super::{queueops::QueueOperation, thread::UniqueThread, Func, ThreadSafeQueue}; +use super::{index::Indexer, thread::UniqueThread}; pub struct ThreadPool { handles: Vec, - count: usize, - queue: ThreadSafeQueue>, + indexer: Indexer, barrier: Arc, - stop_flag: Arc, -} - -impl Default for ThreadPool { - fn default() -> Self { - panic_hook(); - let queue = ThreadSafeQueue::new(); - let count: usize; - if let Ok(thread_count) = thread::available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } - let barrier = Arc::new(Barrier::new(count + 1)); - let stop_flag = Arc::new(AtomicBool::new(false)); - let handles = (0..count) - .map(|index| start(index, queue.clone(), barrier.clone(), stop_flag.clone())) - .collect(); - ThreadPool { - handles, - queue, - count, - barrier, - stop_flag, - } - } } impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - panic_hook(); - let queue = ThreadSafeQueue::new(); - let barrier = Arc::new(Barrier::new(count + 1)); - let stop_flag = Arc::new(AtomicBool::new(false)); - let handles = (0..count) - .map(|index| start(index, queue.clone(), barrier.clone(), stop_flag.clone())) - .collect(); + let mut handles = vec![]; + handles.reserve(count); + for _ in 1..=count { + handles.push(UniqueThread::new()); + } ThreadPool { handles, - queue, - count, - barrier, - stop_flag, + indexer: Indexer::new(count), + barrier: Arc::new(Barrier::new(count + 1)), } } } @@ -66,68 +28,30 @@ impl ThreadPool { where Task: FnOnce() + 'static + Send, { - self.queue.enqueue(QueueOperation::Ready(Box::new(task))); + self.handles[self.indexer.next()].submit(task); } -} -impl ThreadPool { - pub fn wait_for_all(&self) { - for _ in 0..self.count { - self.queue.enqueue(QueueOperation::Wait); - } - self.barrier.wait(); + pub fn clear(&self) { + self.handles.iter().for_each(|handles| handles.clear()); } } impl ThreadPool { - fn cancel_all(&self) { - self.stop_flag - .store(true, std::sync::atomic::Ordering::Release) + pub fn wait_for_all(&self) { + self.handles.iter().for_each(|handle| { + let barrier = self.barrier.clone(); + handle.submit(move || { + barrier.wait(); + }); + }); + self.barrier.wait(); } } impl Drop for ThreadPool { fn drop(&mut self) { - _ = panic::take_hook(); - self.cancel_all(); while let Some(handle) = self.handles.pop() { - handle.join(); + handle.join() } } } - -fn start( - index: usize, - queue: ThreadSafeQueue>, - barrier: Arc, - stop_flag: Arc, -) -> UniqueThread { - UniqueThread::new(format!("ThreadPool #{}", index), move || { - for op in queue { - match (op, stop_flag.load(Ordering::Acquire)) { - (QueueOperation::NotYet, false) => continue, - (QueueOperation::Ready(work), false) => { - work(); - } - (QueueOperation::Wait, false) => _ = barrier.wait(), - _ => { - return; - } - } - } - }) -} - -fn panic_hook() { - panic::set_hook(Box::new(move |info: &panic::PanicInfo<'_>| { - let msg = format!( - "{} panicked at location {} with {} \nBacktrace:\n{}", - thread::current().name().unwrap(), - info.location().unwrap(), - info.to_string().split('\n').collect::>()[1], - backtrace::Backtrace::capture() - ); - eprintln!("{}", msg); - _ = panic::take_hook(); - })); -} From 63118a0d26944d63798cb90f2f5433687b24d791 Mon Sep 17 00:00:00 2001 From: Genaro-Chris <37796152+Genaro-Chris@users.noreply.github.com> Date: Sun, 9 Jun 2024 14:50:53 +0100 Subject: [PATCH 02/10] Update workflow with miri testing --- .github/workflows/rust.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index bcb8cac..44d1060 100755 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -35,3 +35,5 @@ jobs: run: cargo build - name: Test run: cargo test + - name: Miri Test + run: cargo miri Test From 640e7823f7b0de877d5376e85adc6aad4a85f5c0 Mon Sep 17 00:00:00 2001 From: Genaro-Chris <37796152+Genaro-Chris@users.noreply.github.com> Date: Sun, 9 Jun 2024 14:57:11 +0100 Subject: [PATCH 03/10] Install miri --- .github/workflows/rust.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 44d1060..cc62359 100755 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -31,6 +31,11 @@ jobs: profile: minimal toolchain: nightly override: true + - name: Install Miri + run: | + rustup toolchain install nightly --component miri + rustup override set nightly + cargo miri test - name: Build run: cargo build - name: Test From 6f0ca5eea2e2ec6836158dd1f7f978671d4ad4d2 Mon Sep 17 00:00:00 2001 From: Genaro-Chris <37796152+Genaro-Chris@users.noreply.github.com> Date: Sun, 9 Jun 2024 15:08:37 +0100 Subject: [PATCH 04/10] Update rust.yml --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index cc62359..cb19208 100755 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -41,4 +41,4 @@ jobs: - name: Test run: cargo test - name: Miri Test - run: cargo miri Test + run: cargo miri test From 8c53efb468cbb17c4fd7c8d9400f4bc0c0a46d58 Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Mon, 10 Jun 2024 11:45:20 +0100 Subject: [PATCH 05/10] minor fixes --- Cargo.toml | 1 - src/async_runtime/{exec.rs => executor.rs} | 23 +++++++++++------ src/async_runtime/mod.rs | 2 +- src/async_runtime/task.rs | 7 +++--- src/async_stream/mod.rs | 2 +- src/discarding_spawn_group.rs | 12 +++------ src/err_spawn_group.rs | 12 ++++----- src/executors/task_executor.rs | 5 +++- src/shared/runtime.rs | 29 ++++++++++++---------- src/spawn_group.rs | 12 ++++----- src/threadpool_impl/index.rs | 10 +++----- src/threadpool_impl/thread.rs | 14 +++-------- src/threadpool_impl/threadpool.rs | 8 ++---- 13 files changed, 63 insertions(+), 74 deletions(-) rename src/async_runtime/{exec.rs => executor.rs} (77%) diff --git a/Cargo.toml b/Cargo.toml index ff15669..ac632f8 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,5 @@ publish = true [dependencies] async-trait = "0.1.73" cooked-waker = "5.0.0" -parking_lot = "0.12.1" futures-lite = "1.13.0" async-mutex = "1.4.0" diff --git a/src/async_runtime/exec.rs b/src/async_runtime/executor.rs similarity index 77% rename from src/async_runtime/exec.rs rename to src/async_runtime/executor.rs index 2dad39b..ad2b858 100644 --- a/src/async_runtime/exec.rs +++ b/src/async_runtime/executor.rs @@ -9,7 +9,10 @@ use cooked_waker::IntoWaker; use std::{ future::Future, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, task::{Context, Poll, Waker}, }; @@ -17,6 +20,7 @@ use std::{ pub struct Executor { pool: Arc, queue: Channel, + cancelled: Arc, } impl Executor { @@ -24,8 +28,9 @@ impl Executor { let result: Executor = Self { pool: Arc::new(ThreadPool::new(count)), queue: Channel::new(), + cancelled: Arc::new(AtomicBool::new(false)), }; - let result_clone = result.clone(); + let result_clone: Executor = result.clone(); std::thread::spawn(move || { result_clone.run(); }); @@ -45,22 +50,25 @@ impl Executor { where F: Future + Send + 'static, { - let task = Task::new(task); + let task: Task = Task::new(task); self.queue.enqueue(task.clone()); task } pub(crate) fn cancel(&self) { - self.pool.clear(); self.queue.clear(); + self.cancelled.store(true, Ordering::Relaxed); self.poll_all(); - self.pool.clear(); self.queue.clear(); + self.cancelled.store(false, Ordering::Relaxed); } fn run(&self) { while let Some(task) = self.queue.dequeue() { - let queue = self.queue.clone(); + if self.cancelled.load(Ordering::Acquire) { + continue; + } + let queue: Channel = self.queue.clone(); self.submit(move || { let waker: Waker = Arc::new(Notifier::default()).into_waker(); pin_future!(task); @@ -80,8 +88,7 @@ impl Executor { } pub(crate) fn end(&mut self) { - self.queue.clear(); self.queue.close(); - //self.pool.drop_pool(); + self.queue.clear(); } } diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs index e4e8d24..786620b 100755 --- a/src/async_runtime/mod.rs +++ b/src/async_runtime/mod.rs @@ -1,4 +1,4 @@ -pub(crate) mod exec; +pub(crate) mod executor; pub(crate) mod notifier; mod pin_macro; pub(crate) mod task; diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs index c303efd..33886eb 100755 --- a/src/async_runtime/task.rs +++ b/src/async_runtime/task.rs @@ -1,10 +1,9 @@ -use parking_lot::Mutex; use std::{ future::Future, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, - Arc, + Arc, Mutex, }, task::Poll, }; @@ -30,7 +29,7 @@ impl Task { } fn complete(&self) { - self.complete.store(true, Ordering::Release); + self.complete.store(true, Ordering::Relaxed); } } @@ -40,7 +39,7 @@ impl Future for Task { self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - match self.future.lock().as_mut().poll(cx) { + match self.future.lock().unwrap().as_mut().poll(cx) { Poll::Ready(()) => { self.complete(); Poll::Ready(()) diff --git a/src/async_stream/mod.rs b/src/async_stream/mod.rs index 58c7e52..e0b2b98 100755 --- a/src/async_stream/mod.rs +++ b/src/async_stream/mod.rs @@ -71,7 +71,7 @@ impl AsyncStream { pub(crate) fn cancel_tasks(&mut self) { self.cancelled = true; - self.counts.1.store(0, Ordering::Release); + self.counts.1.store(0, Ordering::Relaxed); } } diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 3b564c7..146d68a 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -1,6 +1,4 @@ -use crate::shared::{ - priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, -}; +use crate::shared::{priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared}; use std::future::Future; @@ -34,7 +32,7 @@ impl DiscardingSpawnGroup { impl DiscardingSpawnGroup { /// Instantiates `DiscardingSpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use @@ -42,7 +40,7 @@ impl DiscardingSpawnGroup { Self { is_cancelled: false, runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + wait_at_drop: true, } } } @@ -102,9 +100,8 @@ impl Drop for DiscardingSpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); - } else { - self.runtime.end() } + self.runtime.end() } } @@ -132,4 +129,3 @@ impl Shared for DiscardingSpawnGroup { self.is_cancelled = true; } } - diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index c4df1f0..7856c58 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -1,6 +1,5 @@ use crate::shared::{ - priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, - wait::Waitable, + priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, }; use async_trait::async_trait; use futures_lite::{Stream, StreamExt}; @@ -39,7 +38,7 @@ pub struct ErrSpawnGroup { impl ErrSpawnGroup { /// Instantiates `ErrSpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use @@ -48,7 +47,7 @@ impl ErrSpawnGroup { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + wait_at_drop: true, } } } @@ -129,7 +128,7 @@ impl ErrSpawnGroup { } fn decrement_count_to_zero(&self) { - self.count.store(0, Ordering::Release); + self.count.store(0, Ordering::Relaxed); } } @@ -199,9 +198,8 @@ impl Drop for ErrSpawnGroup, waker: &Waker) return; } let mut context: Context<'_> = Context::from_waker(waker); + let Ok(mut task) = task.future.lock() else { + return; + }; loop { - match task.future.lock().as_mut().poll(&mut context) { + match task.as_mut().poll(&mut context) { std::task::Poll::Ready(()) => return, std::task::Poll::Pending => notifier.wait(), } diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index a4d8f31..a20fc19 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,12 +1,13 @@ -use parking_lot::Mutex; - use crate::{ - async_runtime::{exec::Executor, task::Task}, + async_runtime::{executor::Executor, task::Task}, async_stream::AsyncStream, executors::block_task, shared::priority::Priority, }; -use std::{future::Future, sync::Arc}; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; type TaskQueue = Arc>>; @@ -29,7 +30,7 @@ impl RuntimeEngine { impl RuntimeEngine { pub(crate) fn cancel(&mut self) { self.runtime.cancel(); - self.tasks.lock().clear(); + self.tasks.lock().unwrap().clear(); self.stream.cancel_tasks(); self.poll(); } @@ -42,7 +43,7 @@ impl RuntimeEngine { pub(crate) fn end(&mut self) { self.runtime.cancel(); - self.tasks.lock().clear(); + self.tasks.lock().unwrap().clear(); self.runtime.end() } } @@ -51,13 +52,15 @@ impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { self.poll(); self.runtime.cancel(); - let mut lock = self.tasks.lock(); - lock.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - while let Some((_, handle)) = lock.pop() { - self.runtime.submit(move || { - block_task(handle); - }); + if let Ok(mut lock) = self.tasks.lock() { + lock.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); + while let Some((_, handle)) = lock.pop() { + self.runtime.submit(move || { + block_task(handle); + }); + } } + self.poll(); } } @@ -72,7 +75,7 @@ impl RuntimeEngine { let runtime = self.runtime.clone(); let tasks: Arc>> = self.tasks.clone(); self.runtime.submit(move || { - tasks.lock().push(( + tasks.lock().unwrap().push(( priority, runtime.spawn(async move { stream.insert_item(task.await).await; diff --git a/src/spawn_group.rs b/src/spawn_group.rs index 9611705..b8d757a 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -1,6 +1,5 @@ use crate::shared::{ - priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, - wait::Waitable, + priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, }; use async_trait::async_trait; use futures_lite::{Stream, StreamExt}; @@ -40,7 +39,7 @@ pub struct SpawnGroup { impl SpawnGroup { /// Instantiates `SpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use @@ -49,7 +48,7 @@ impl SpawnGroup { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + wait_at_drop: true, } } } @@ -118,7 +117,7 @@ impl SpawnGroup { } fn decrement_count_to_zero(&self) { - self.count.store(0, Ordering::Release); + self.count.store(0, Ordering::Relaxed); } } @@ -195,9 +194,8 @@ impl Drop for SpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); - } else { - self.runtime.end() } + self.runtime.end() } } diff --git a/src/threadpool_impl/index.rs b/src/threadpool_impl/index.rs index f27d31d..d630209 100644 --- a/src/threadpool_impl/index.rs +++ b/src/threadpool_impl/index.rs @@ -1,18 +1,14 @@ -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; +use std::sync::atomic::{AtomicUsize, Ordering}; -#[derive(Clone)] pub(crate) struct Indexer { - index: Arc, + index: AtomicUsize, last_index: usize, } impl Indexer { pub(crate) fn new(count: usize) -> Self { Indexer { - index: Arc::new(AtomicUsize::new(0)), + index: AtomicUsize::new(0), last_index: count - 1, } } diff --git a/src/threadpool_impl/thread.rs b/src/threadpool_impl/thread.rs index a729136..5e7579c 100644 --- a/src/threadpool_impl/thread.rs +++ b/src/threadpool_impl/thread.rs @@ -1,22 +1,21 @@ -use std::thread::{spawn, JoinHandle}; +use std::thread::spawn; use super::{Channel, Func}; pub(crate) struct UniqueThread { channel: Channel>, - handle: JoinHandle<()>, } impl UniqueThread { pub(crate) fn new() -> Self { let channel: Channel> = Channel::new(); - let chan = channel.clone(); - let handle = spawn(move || { + let chan: Channel> = channel.clone(); + spawn(move || { while let Some(ops) = chan.dequeue() { ops() } }); - UniqueThread { channel, handle } + UniqueThread { channel } } } @@ -24,7 +23,6 @@ impl UniqueThread { pub(crate) fn join(self) { self.channel.close(); self.channel.clear(); - _ = self.handle.join().unwrap(); } pub(crate) fn submit(&self, task: Task) @@ -33,8 +31,4 @@ impl UniqueThread { { self.channel.enqueue(Box::new(task)); } - - pub(crate) fn clear(&self) { - self.channel.clear(); - } } diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index d15d6c6..fad7e78 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -10,7 +10,7 @@ pub struct ThreadPool { impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - let mut handles = vec![]; + let mut handles: Vec = vec![]; handles.reserve(count); for _ in 1..=count { handles.push(UniqueThread::new()); @@ -30,16 +30,12 @@ impl ThreadPool { { self.handles[self.indexer.next()].submit(task); } - - pub fn clear(&self) { - self.handles.iter().for_each(|handles| handles.clear()); - } } impl ThreadPool { pub fn wait_for_all(&self) { self.handles.iter().for_each(|handle| { - let barrier = self.barrier.clone(); + let barrier: Arc = self.barrier.clone(); handle.submit(move || { barrier.wait(); }); From e4568db2a1a250c732feaf2628d070b47b0384a3 Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Wed, 12 Jun 2024 01:29:44 +0100 Subject: [PATCH 06/10] remove some items --- src/async_runtime/executor.rs | 69 ++++++++++++++----------------- src/async_runtime/task.rs | 4 +- src/executors/mod.rs | 2 +- src/shared/runtime.rs | 17 ++++---- src/threadpool_impl/index.rs | 27 ------------ src/threadpool_impl/mod.rs | 2 - src/threadpool_impl/thread.rs | 34 --------------- src/threadpool_impl/threadpool.rs | 58 +++++++++++++++++++------- 8 files changed, 83 insertions(+), 130 deletions(-) delete mode 100644 src/threadpool_impl/index.rs delete mode 100644 src/threadpool_impl/thread.rs diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs index ad2b858..1a913e7 100644 --- a/src/async_runtime/executor.rs +++ b/src/async_runtime/executor.rs @@ -1,7 +1,4 @@ -use crate::{ - pin_future, - threadpool_impl::{Channel, ThreadPool}, -}; +use crate::{pin_future, threadpool_impl::ThreadPool}; use super::{notifier::Notifier, task::Task}; @@ -16,10 +13,8 @@ use std::{ task::{Context, Poll, Waker}, }; -#[derive(Clone)] pub struct Executor { pool: Arc, - queue: Channel, cancelled: Arc, } @@ -27,13 +22,8 @@ impl Executor { pub(crate) fn new(count: usize) -> Self { let result: Executor = Self { pool: Arc::new(ThreadPool::new(count)), - queue: Channel::new(), cancelled: Arc::new(AtomicBool::new(false)), }; - let result_clone: Executor = result.clone(); - std::thread::spawn(move || { - result_clone.run(); - }); result } } @@ -51,44 +41,47 @@ impl Executor { F: Future + Send + 'static, { let task: Task = Task::new(task); - self.queue.enqueue(task.clone()); + self.spawn_task(task.clone()); task } + fn spawn_task(&self, task: Task) { + let executor = self.clone(); + if self.cancelled.load(Ordering::Acquire) { + return; + } + self.submit(move || { + let waker: Waker = Arc::new(Notifier::default()).into_waker(); + pin_future!(task); + let mut cx: Context<'_> = Context::from_waker(&waker); + match task.as_mut().poll(&mut cx) { + Poll::Ready(()) => task.complete(), + Poll::Pending => { + cx.waker().wake_by_ref(); + executor.spawn_task(task.clone()); + } + } + }); + } + + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + cancelled: self.cancelled.clone(), + } + } + pub(crate) fn cancel(&self) { - self.queue.clear(); + self.pool.clear(); self.cancelled.store(true, Ordering::Relaxed); self.poll_all(); - self.queue.clear(); + self.pool.clear(); self.cancelled.store(false, Ordering::Relaxed); } - fn run(&self) { - while let Some(task) = self.queue.dequeue() { - if self.cancelled.load(Ordering::Acquire) { - continue; - } - let queue: Channel = self.queue.clone(); - self.submit(move || { - let waker: Waker = Arc::new(Notifier::default()).into_waker(); - pin_future!(task); - let mut cx: Context<'_> = Context::from_waker(&waker); - match task.as_mut().poll(&mut cx) { - Poll::Ready(()) => (), - Poll::Pending => { - queue.enqueue(task.clone()); - } - } - }); - } - } - pub(crate) fn poll_all(&self) { self.pool.wait_for_all(); } - pub(crate) fn end(&mut self) { - self.queue.close(); - self.queue.clear(); - } + pub(crate) fn end(&mut self) {} } diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs index 33886eb..207cd33 100755 --- a/src/async_runtime/task.rs +++ b/src/async_runtime/task.rs @@ -13,7 +13,7 @@ type LocalBoxedFuture = Pin + Send + 'static>>; #[derive(Clone)] pub struct Task { pub(crate) future: Arc>, - pub(crate) complete: Arc, + complete: Arc, } impl Task { @@ -28,7 +28,7 @@ impl Task { self.complete.load(Ordering::Acquire) } - fn complete(&self) { + pub(crate) fn complete(&self) { self.complete.store(true, Ordering::Relaxed); } } diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 9de0074..bff9ca4 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -36,7 +36,7 @@ pub fn block_on(future: Fut) -> Fut::Output { pub(crate) fn block_task(task: Task) { let waker_pair: Result<(Arc, Waker), std::thread::AccessError> = - local_executor::WAKER_PAIR + task_executor::WAKER_PAIR .try_with(|waker_pair: &(Arc, Waker)| waker_pair.clone()); match waker_pair { Ok((notifier, waker)) => block_on_task(task, notifier, &waker), diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index a20fc19..eb2b10d 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -72,17 +72,14 @@ impl RuntimeEngine { { self.stream.increment(); let mut stream: AsyncStream = self.stream(); - let runtime = self.runtime.clone(); let tasks: Arc>> = self.tasks.clone(); - self.runtime.submit(move || { - tasks.lock().unwrap().push(( - priority, - runtime.spawn(async move { - stream.insert_item(task.await).await; - stream.decrement_task_count(); - }), - )); - }); + tasks.lock().unwrap().push(( + priority, + self.runtime.spawn(async move { + stream.insert_item(task.await).await; + stream.decrement_task_count(); + }), + )); } } diff --git a/src/threadpool_impl/index.rs b/src/threadpool_impl/index.rs deleted file mode 100644 index d630209..0000000 --- a/src/threadpool_impl/index.rs +++ /dev/null @@ -1,27 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; - -pub(crate) struct Indexer { - index: AtomicUsize, - last_index: usize, -} - -impl Indexer { - pub(crate) fn new(count: usize) -> Self { - Indexer { - index: AtomicUsize::new(0), - last_index: count - 1, - } - } -} - -impl Indexer { - pub(crate) fn next(&self) -> usize { - if let Ok(_) = - self.index - .compare_exchange(self.last_index, 0, Ordering::SeqCst, Ordering::SeqCst) - { - return 0; - } - self.index.fetch_add(1, Ordering::SeqCst) - } -} diff --git a/src/threadpool_impl/mod.rs b/src/threadpool_impl/mod.rs index 5730af2..fad412e 100644 --- a/src/threadpool_impl/mod.rs +++ b/src/threadpool_impl/mod.rs @@ -1,6 +1,4 @@ mod channel; -mod index; -mod thread; mod threadpool; pub(crate) type Func = dyn FnOnce() + Send; diff --git a/src/threadpool_impl/thread.rs b/src/threadpool_impl/thread.rs deleted file mode 100644 index 5e7579c..0000000 --- a/src/threadpool_impl/thread.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::thread::spawn; - -use super::{Channel, Func}; - -pub(crate) struct UniqueThread { - channel: Channel>, -} - -impl UniqueThread { - pub(crate) fn new() -> Self { - let channel: Channel> = Channel::new(); - let chan: Channel> = channel.clone(); - spawn(move || { - while let Some(ops) = chan.dequeue() { - ops() - } - }); - UniqueThread { channel } - } -} - -impl UniqueThread { - pub(crate) fn join(self) { - self.channel.close(); - self.channel.clear(); - } - - pub(crate) fn submit(&self, task: Task) - where - Task: FnOnce() + Send + 'static, - { - self.channel.enqueue(Box::new(task)); - } -} diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index fad7e78..4845719 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,23 +1,30 @@ -use std::sync::{Arc, Barrier}; +use std::{ + backtrace, panic, sync::{Arc, Barrier}, thread::spawn +}; -use super::{index::Indexer, thread::UniqueThread}; +use super::{Channel, Func}; pub struct ThreadPool { - handles: Vec, - indexer: Indexer, + task_channel: Channel>, barrier: Arc, + count: usize, } impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - let mut handles: Vec = vec![]; - handles.reserve(count); + let task_channel: Channel> = Channel::new(); for _ in 1..=count { - handles.push(UniqueThread::new()); + let channel: Channel> = task_channel.clone(); + spawn(move || { + panic_hook(); + while let Some(ops) = channel.dequeue() { + ops() + } + }); } ThreadPool { - handles, - indexer: Indexer::new(count), + task_channel, + count, barrier: Arc::new(Barrier::new(count + 1)), } } @@ -28,17 +35,17 @@ impl ThreadPool { where Task: FnOnce() + 'static + Send, { - self.handles[self.indexer.next()].submit(task); + self.task_channel.enqueue(Box::new(task)); } } impl ThreadPool { pub fn wait_for_all(&self) { - self.handles.iter().for_each(|handle| { + (1..=self.count).for_each(|_| { let barrier: Arc = self.barrier.clone(); - handle.submit(move || { + self.task_channel.enqueue(Box::new(move || { barrier.wait(); - }); + })); }); self.barrier.wait(); } @@ -46,8 +53,27 @@ impl ThreadPool { impl Drop for ThreadPool { fn drop(&mut self) { - while let Some(handle) = self.handles.pop() { - handle.join() - } + _ = panic::take_hook(); + self.task_channel.close(); + self.clear() } } + +impl ThreadPool { + pub fn clear(&self) { + self.task_channel.clear(); + } +} + +fn panic_hook() { + panic::set_hook(Box::new(move |info: &panic::PanicInfo<'_>| { + let msg = format!( + "Threadpool panicked at location {} with {} \nBacktrace:\n{}", + info.location().unwrap(), + info.to_string().split('\n').collect::>()[1], + backtrace::Backtrace::capture() + ); + eprintln!("{}", msg); + _ = panic::take_hook(); + })); +} From 79c2fb4022a07cdf66977da77327ce5eeb80019c Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Mon, 22 Jul 2024 22:11:31 +0100 Subject: [PATCH 07/10] Major fix for v2.0 --- Cargo.toml | 1 - src/async_runtime/executor.rs | 107 ++++++++++----- src/async_runtime/mod.rs | 1 - src/async_runtime/task.rs | 32 +++-- src/async_stream/mod.rs | 59 ++++---- src/discarding_spawn_group.rs | 12 ++ src/err_spawn_group.rs | 6 + src/executors/future_executor.rs | 62 +++++++++ src/executors/local_executor.rs | 34 ----- src/executors/mod.rs | 55 ++------ src/{async_runtime => executors}/notifier.rs | 22 ++- src/executors/parker.rs | 135 +++++++++++++++++++ src/executors/task_executor.rs | 72 +++++++--- src/executors/waker.rs | 15 +++ src/executors/waker_traits.rs | 106 +++++++++++++++ src/lib.rs | 1 + src/shared/runtime.rs | 49 ++++--- src/spawn_group.rs | 8 +- src/threadpool_impl/channel.rs | 29 ++-- src/threadpool_impl/mod.rs | 1 + src/threadpool_impl/threadpool.rs | 86 +++++++----- src/threadpool_impl/waitgroup.rs | 42 ++++++ src/yield_now/mod.rs | 14 ++ 23 files changed, 708 insertions(+), 241 deletions(-) mode change 100755 => 100644 Cargo.toml create mode 100644 src/executors/future_executor.rs delete mode 100755 src/executors/local_executor.rs rename src/{async_runtime => executors}/notifier.rs (63%) mode change 100755 => 100644 create mode 100755 src/executors/parker.rs mode change 100755 => 100644 src/executors/task_executor.rs create mode 100644 src/executors/waker.rs create mode 100644 src/executors/waker_traits.rs create mode 100644 src/threadpool_impl/waitgroup.rs diff --git a/Cargo.toml b/Cargo.toml old mode 100755 new mode 100644 index ac632f8..118de02 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,5 @@ publish = true [dependencies] async-trait = "0.1.73" -cooked-waker = "5.0.0" futures-lite = "1.13.0" async-mutex = "1.4.0" diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs index 1a913e7..2a3ace1 100644 --- a/src/async_runtime/executor.rs +++ b/src/async_runtime/executor.rs @@ -1,8 +1,9 @@ -use crate::{pin_future, threadpool_impl::ThreadPool}; - -use super::{notifier::Notifier, task::Task}; +use crate::{ + executors::{IntoWaker, Notifier}, + threadpool_impl::{Channel, ThreadPool}, +}; -use cooked_waker::IntoWaker; +use super::task::Task; use std::{ future::Future, @@ -16,6 +17,7 @@ use std::{ pub struct Executor { pool: Arc, cancelled: Arc, + task_queue: Channel, } impl Executor { @@ -23,12 +25,30 @@ impl Executor { let result: Executor = Self { pool: Arc::new(ThreadPool::new(count)), cancelled: Arc::new(AtomicBool::new(false)), + task_queue: Channel::new(), }; + result.start(); result } } impl Executor { + fn start(&self) { + let queue = self.task_queue.clone(); + let cancelled = self.cancelled.clone(); + let pool = self.pool.clone(); + std::thread::spawn(move || loop { + let Some(task) = queue.clone().dequeue() else { + return; + }; + if cancelled.load(Ordering::Relaxed) || task.is_cancelled() { + continue; + } + let queue = queue.clone(); + pool.submit(move || block_task_with(task, &queue)) + }); + } + pub(crate) fn submit(&self, task: Task) where Task: FnOnce() + Send + 'static, @@ -41,36 +61,10 @@ impl Executor { F: Future + Send + 'static, { let task: Task = Task::new(task); - self.spawn_task(task.clone()); + self.task_queue.enqueue(task.clone()); task } - fn spawn_task(&self, task: Task) { - let executor = self.clone(); - if self.cancelled.load(Ordering::Acquire) { - return; - } - self.submit(move || { - let waker: Waker = Arc::new(Notifier::default()).into_waker(); - pin_future!(task); - let mut cx: Context<'_> = Context::from_waker(&waker); - match task.as_mut().poll(&mut cx) { - Poll::Ready(()) => task.complete(), - Poll::Pending => { - cx.waker().wake_by_ref(); - executor.spawn_task(task.clone()); - } - } - }); - } - - fn clone(&self) -> Self { - Self { - pool: self.pool.clone(), - cancelled: self.cancelled.clone(), - } - } - pub(crate) fn cancel(&self) { self.pool.clear(); self.cancelled.store(true, Ordering::Relaxed); @@ -83,5 +77,54 @@ impl Executor { self.pool.wait_for_all(); } - pub(crate) fn end(&mut self) {} + pub(crate) fn end(&self) { + self.cancelled.store(true, Ordering::Release); + self.task_queue.close(); + self.pool.end(); + } +} + +thread_local! { + static TASK_WAKER: Waker = { + Arc::new(Notifier::default()).into_waker() + }; +} + +fn block_task_with(future: Task, queue: &Channel) { + if future.is_completed() || future.is_cancelled() { + return; + } + let task = future.clone(); + let waker_result = TASK_WAKER.try_with(|waker| waker.clone()); + match waker_result { + Ok(waker) => { + let mut context: Context<'_> = Context::from_waker(&waker); + let Ok(mut future) = future.lock() else { + return; + }; + match future.as_mut().poll(&mut context) { + Poll::Ready(()) => task.complete(), + Poll::Pending => { + if !task.is_cancelled() { + queue.enqueue(task.clone()); + } + } + } + } + Err(_) => { + let waker: Waker = Arc::new(Notifier::default()).into_waker(); + let mut context: Context<'_> = Context::from_waker(&waker); + let Ok(mut future) = future.lock() else { + return; + }; + match future.as_mut().poll(&mut context) { + Poll::Ready(()) => task.complete(), + Poll::Pending => { + if !task.is_cancelled() { + queue.enqueue(task.clone()); + } + } + } + } + } } diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs index 786620b..537875a 100755 --- a/src/async_runtime/mod.rs +++ b/src/async_runtime/mod.rs @@ -1,4 +1,3 @@ pub(crate) mod executor; -pub(crate) mod notifier; mod pin_macro; pub(crate) mod task; diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs index 207cd33..9b6b5d7 100755 --- a/src/async_runtime/task.rs +++ b/src/async_runtime/task.rs @@ -1,19 +1,21 @@ use std::{ future::Future, + ops::Deref, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, Arc, Mutex, }, - task::Poll, + task::{Context, Poll}, }; type LocalBoxedFuture = Pin + Send + 'static>>; #[derive(Clone)] pub struct Task { - pub(crate) future: Arc>, + future: Arc>, complete: Arc, + cancelled: Arc, } impl Task { @@ -21,24 +23,38 @@ impl Task { Self { future: Arc::new(Mutex::new(Box::pin(fut))), complete: Arc::new(AtomicBool::new(false)), + cancelled: Arc::new(AtomicBool::new(false)), } } pub(crate) fn is_completed(&self) -> bool { - self.complete.load(Ordering::Acquire) + self.complete.load(Ordering::Relaxed) } pub(crate) fn complete(&self) { - self.complete.store(true, Ordering::Relaxed); + self.complete.store(true, Ordering::Release); + } + + pub(crate) fn cancel(&self) { + self.cancelled.store(true, Ordering::Release); + } + + pub(crate) fn is_cancelled(&self) -> bool { + self.cancelled.load(Ordering::Acquire) + } +} + +impl Deref for Task { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { + &self.future } } impl Future for Task { type Output = (); - fn poll( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.future.lock().unwrap().as_mut().poll(cx) { Poll::Ready(()) => { self.complete(); diff --git a/src/async_stream/mod.rs b/src/async_stream/mod.rs index e0b2b98..8b19750 100755 --- a/src/async_stream/mod.rs +++ b/src/async_stream/mod.rs @@ -2,7 +2,7 @@ use std::{ collections::VecDeque, pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }, task::{Context, Poll}, @@ -15,16 +15,17 @@ use crate::executors::block_on; pub struct AsyncStream { buffer: Arc>>, - started: bool, - counts: (Arc, Arc), - cancelled: bool, + item_count: Arc, + task_count: Arc, + cancelled: Arc, } impl AsyncStream { - pub(crate) async fn insert_item(&mut self, value: ItemType) { - if !self.started { - self.started = true; - } + #[inline] + pub(crate) async fn insert_item(&self, value: ItemType) { + _ = self + .cancelled + .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed); self.buffer.lock().await.push_back(value); } } @@ -37,41 +38,33 @@ impl AsyncStream { impl AsyncStream { pub(crate) fn increment(&self) { - self.counts.0.fetch_add(1, Ordering::Acquire); - self.counts.1.fetch_add(1, Ordering::Acquire); + self.task_count.fetch_add(1, Ordering::SeqCst); + self.item_count.fetch_add(1, Ordering::SeqCst); } } impl AsyncStream { - pub(crate) async fn first(&mut self) -> Option { + pub async fn first(&mut self) -> Option { self.next().await } } impl AsyncStream { pub(crate) fn task_count(&self) -> usize { - self.counts.1.load(Ordering::Acquire) + self.task_count.load(Ordering::Acquire) } pub(crate) fn decrement_task_count(&self) { - if self.task_count() > 0 { - self.counts.1.fetch_sub(1, Ordering::Acquire); - } + self.task_count.fetch_sub(1, Ordering::SeqCst); } pub(crate) fn item_count(&self) -> usize { - self.counts.0.load(Ordering::Acquire) - } - - pub(crate) fn decrement_count(&self) { - if self.item_count() > 0 { - self.counts.0.fetch_sub(1, Ordering::Acquire); - } + self.item_count.load(Ordering::Acquire) } pub(crate) fn cancel_tasks(&mut self) { - self.cancelled = true; - self.counts.1.store(0, Ordering::Relaxed); + self.cancelled.store(true, Ordering::Release); + self.task_count.store(0, Ordering::Release); } } @@ -79,9 +72,9 @@ impl Clone for AsyncStream { fn clone(&self) -> Self { Self { buffer: self.buffer.clone(), - started: self.started, - counts: self.counts.clone(), - cancelled: self.cancelled, + task_count: self.task_count.clone(), + item_count: self.item_count.clone(), + cancelled: self.cancelled.clone(), } } } @@ -90,9 +83,9 @@ impl AsyncStream { pub(crate) fn new() -> Self { AsyncStream:: { buffer: Arc::new(Mutex::new(VecDeque::new())), - started: false, - counts: (Arc::new(AtomicUsize::new(0)), Arc::new(AtomicUsize::new(0))), - cancelled: false, + item_count: Arc::new(AtomicUsize::new(0)), + task_count: Arc::new(AtomicUsize::new(0)), + cancelled: Arc::new(AtomicBool::new(false)), } } } @@ -103,14 +96,16 @@ impl Stream for AsyncStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { block_on(async move { let mut inner_lock: MutexGuard<'_, VecDeque> = self.buffer.lock().await; - if self.cancelled && inner_lock.is_empty() || self.item_count() == 0 { + if self.cancelled.load(Ordering::Relaxed) && inner_lock.is_empty() + || self.item_count() == 0 + { return Poll::Ready(None); } let Some(value) = inner_lock.pop_front() else { cx.waker().wake_by_ref(); return Poll::Pending; }; - self.decrement_count(); + self.item_count.fetch_sub(1, Ordering::SeqCst); Poll::Ready(Some(value)) }) } diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 146d68a..1f2d897 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -96,6 +96,18 @@ impl DiscardingSpawnGroup { } } +impl DiscardingSpawnGroup { + /// Waits for all remaining child tasks for finish. + pub async fn wait_for_all(&mut self) { + self.runtime.wait_for_all_tasks(); + } + + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); + } +} + impl Drop for DiscardingSpawnGroup { fn drop(&mut self) { if self.wait_at_drop { diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index 7856c58..8a4e1b1 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -116,6 +116,12 @@ impl ErrSpawnGroup { pub async fn wait_for_all(&mut self) { self.wait().await; } + + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); + self.decrement_count_to_zero() + } } impl ErrSpawnGroup { diff --git a/src/executors/future_executor.rs b/src/executors/future_executor.rs new file mode 100644 index 0000000..31ab109 --- /dev/null +++ b/src/executors/future_executor.rs @@ -0,0 +1,62 @@ +use std::{ + cell::RefCell, + future::Future, + task::{Context, Poll, Waker}, +}; + +use crate::{executors::waker::waker_helper, pin_future}; + +use super::parker::{pair, Parker}; + +fn parker_and_waker() -> (Parker, Waker) { + let (parker, unparker) = pair(); + let waker = waker_helper(move || { + unparker.unpark(); + }); + (parker, waker) +} + +/// Blocks the current thread until the future is polled to finish. +/// +/// Example +/// ```rust +/// let result = spawn_groups::block_on(async { +/// println!("This is an async executor"); +/// 1 +/// }); +/// assert_eq!(result, 1); +/// ``` +/// +pub fn block_on(future: Fut) -> Fut::Output { + pin_future!(future); + thread_local! { + static WAKER_PAIR: RefCell<(Parker, Waker)> = { + RefCell::new(parker_and_waker()) + }; + } + return WAKER_PAIR.with(|waker_pair| match waker_pair.try_borrow_mut() { + Ok(waker_pair) => { + let (parker, waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + loop { + match future.as_mut().poll(&mut context) { + Poll::Ready(output) => return output, + Poll::Pending => parker.park(), + } + } + } + Err(_) => { + let (parker, unparker) = pair(); + let waker = waker_helper(move || { + unparker.unpark(); + }); + let mut context: Context<'_> = Context::from_waker(&waker); + loop { + match future.as_mut().poll(&mut context) { + Poll::Ready(output) => return output, + Poll::Pending => parker.park(), + } + } + } + }); +} diff --git a/src/executors/local_executor.rs b/src/executors/local_executor.rs deleted file mode 100755 index 6b7d67c..0000000 --- a/src/executors/local_executor.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::{ - future::Future, - sync::Arc, - task::{Context, Waker}, -}; - -use cooked_waker::IntoWaker; - -use crate::{async_runtime::notifier::Notifier, pin_future}; - -thread_local! { - pub(crate) static WAKER_PAIR: (Arc, Waker) = { - let notifier = Arc::new(Notifier::default()); - let waker = notifier.clone().into_waker(); - (notifier, waker) - }; -} - -pub(crate) fn block_future( - future: Fut, - notifier: Arc, - waker: &Waker, -) -> Fut::Output { - let mut context: Context<'_> = Context::from_waker(waker); - pin_future!(future); - loop { - match future.as_mut().poll(&mut context) { - std::task::Poll::Ready(output) => return output, - std::task::Poll::Pending => { - notifier.wait() - } - } - } -} diff --git a/src/executors/mod.rs b/src/executors/mod.rs index bff9ca4..8c4aefc 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -1,49 +1,12 @@ -use std::{future::Future, sync::Arc, task::Waker}; - -use cooked_waker::IntoWaker; - -use crate::async_runtime::{notifier::Notifier, task::Task}; - -use self::{local_executor::block_future, task_executor::block_on_task}; - -mod local_executor; +mod waker; +mod waker_traits; mod task_executor; +mod future_executor; +mod notifier; +mod parker; -/// Blocks the current thread until the future is polled to finish. -/// -/// Example -/// ```rust -/// let result = spawn_groups::block_on(async { -/// println!("This is an async executor"); -/// 1 -/// }); -/// assert_eq!(result, 1); -/// ``` -/// -pub fn block_on(future: Fut) -> Fut::Output { - let waker_pair: Result<(Arc, Waker), std::thread::AccessError> = - local_executor::WAKER_PAIR - .try_with(|waker_pair: &(Arc, Waker)| waker_pair.clone()); - match waker_pair { - Ok((notifier, waker)) => block_future(future, notifier, &waker), - Err(_) => { - let notifier: Arc = Arc::new(Notifier::default()); - let waker: Waker = notifier.clone().into_waker(); - block_future(future, notifier, &waker) - } - } -} +pub(crate) use task_executor::block_task; +pub(crate) use notifier::Notifier; +pub(crate) use waker_traits::IntoWaker; -pub(crate) fn block_task(task: Task) { - let waker_pair: Result<(Arc, Waker), std::thread::AccessError> = - task_executor::WAKER_PAIR - .try_with(|waker_pair: &(Arc, Waker)| waker_pair.clone()); - match waker_pair { - Ok((notifier, waker)) => block_on_task(task, notifier, &waker), - Err(_) => { - let notifier: Arc = Arc::new(Notifier::default()); - let waker: Waker = notifier.clone().into_waker(); - block_on_task(task, notifier, &waker) - } - } -} +pub use future_executor::block_on; diff --git a/src/async_runtime/notifier.rs b/src/executors/notifier.rs old mode 100755 new mode 100644 similarity index 63% rename from src/async_runtime/notifier.rs rename to src/executors/notifier.rs index a918817..656d703 --- a/src/async_runtime/notifier.rs +++ b/src/executors/notifier.rs @@ -1,22 +1,33 @@ -use cooked_waker::WakeRef; +#![allow(dead_code)] + use std::sync::{Condvar, Mutex, MutexGuard}; +use crate::executors::waker_traits::WakeRef; + #[derive(Default)] -pub struct Notifier { +pub(crate) struct Notifier { was_notified: Mutex, cv: Condvar, } -impl WakeRef for Notifier { - fn wake_by_ref(&self) { +impl Notifier { + pub(crate) fn wake(&self) { + let mut was_notified: MutexGuard<'_, bool> = self.was_notified.lock().unwrap(); + let was_notified: bool = - { std::mem::replace(&mut self.was_notified.lock().unwrap(), true) }; + { std::mem::replace(&mut was_notified, true) }; if !was_notified { self.cv.notify_one(); } } } +impl WakeRef for Notifier { + fn wake_by_ref(&self) { + self.wake() + } +} + impl Notifier { pub(crate) fn wait(&self) { let mut was_notified: MutexGuard<'_, bool> = self.was_notified.lock().unwrap(); @@ -27,3 +38,4 @@ impl Notifier { *was_notified = false; } } + diff --git a/src/executors/parker.rs b/src/executors/parker.rs new file mode 100755 index 0000000..014b660 --- /dev/null +++ b/src/executors/parker.rs @@ -0,0 +1,135 @@ +use std::{ + cell::Cell, + marker::PhantomData, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Arc, Condvar, Mutex, + }, + task::{Wake, Waker}, +}; + +pub(crate) fn pair() -> (Parker, Unparker) { + let p = Parker::new(); + let u = p.unparker(); + (p, u) +} + +pub(crate) struct Parker { + unparker: Unparker, + _marker: PhantomData>, +} + +impl Parker { + pub(crate) fn new() -> Parker { + Parker { + unparker: Unparker { + inner: Arc::new(Inner { + state: AtomicUsize::new(EMPTY), + lock: Mutex::new(()), + cvar: Condvar::new(), + }), + }, + _marker: PhantomData, + } + } + + pub(crate) fn park(&self) { + self.unparker.inner.park(); + } + + pub(crate) fn unparker(&self) -> Unparker { + self.unparker.clone() + } +} + +pub(crate) struct Unparker { + inner: Arc, +} + +impl Unparker { + pub(crate) fn unpark(&self) { + self.inner.unpark(); + } +} + +impl Clone for Unparker { + fn clone(&self) -> Unparker { + Unparker { + inner: self.inner.clone(), + } + } +} + +impl From for Waker { + fn from(up: Unparker) -> Self { + Waker::from(up.inner) + } +} + +const EMPTY: usize = 0; +const PARKED: usize = 1; +const NOTIFIED: usize = 2; + +struct Inner { + state: AtomicUsize, + lock: Mutex<()>, + cvar: Condvar, +} + +impl Inner { + fn park(&self) { + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + + let mut m = self.lock.lock().unwrap(); + + match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { + Ok(_) => {} + Err(NOTIFIED) => { + let old = self.state.swap(EMPTY, SeqCst); + assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); + return; + } + Err(n) => panic!("inconsistent park_timeout state: {}", n), + } + + loop { + m = self.cvar.wait(m).unwrap(); + + if self + .state + .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) + .is_ok() + { + return; + } + } + } + + pub(crate) fn unpark(&self) { + match self.state.swap(NOTIFIED, SeqCst) { + EMPTY => return, + NOTIFIED => return, + PARKED => {} + _ => panic!("inconsistent state in unpark"), + } + + drop(self.lock.lock().unwrap()); + self.cvar.notify_one(); + } +} + +impl Wake for Inner { + fn wake(self: Arc) { + self.unpark(); + } + + fn wake_by_ref(self: &Arc) { + self.unpark(); + } +} diff --git a/src/executors/task_executor.rs b/src/executors/task_executor.rs old mode 100755 new mode 100644 index dbd4eda..68b7049 --- a/src/executors/task_executor.rs +++ b/src/executors/task_executor.rs @@ -1,31 +1,63 @@ use std::{ - sync::Arc, - task::{Context, Waker}, + cell::RefCell, + task::{Context, Poll, Waker}, }; -use crate::async_runtime::{notifier::Notifier, task::Task}; -use cooked_waker::IntoWaker; +use crate::{ + async_runtime::task::Task, + executors::parker::{pair, Parker}, +}; + +use super::waker::waker_helper; -thread_local! { - pub(crate) static WAKER_PAIR: (Arc, Waker) = { - let notifier = Arc::new(Notifier::default()); - let waker = notifier.clone().into_waker(); - (notifier, waker) - }; +fn parker_and_waker() -> (Parker, Waker) { + let (parker, unparker) = pair(); + let waker = waker_helper(move || { + unparker.unpark(); + }); + (parker, waker) } -pub(crate) fn block_on_task(task: Task, notifier: Arc, waker: &Waker) { +pub(crate) fn block_task(task: Task) { if task.is_completed() { return; } - let mut context: Context<'_> = Context::from_waker(waker); - let Ok(mut task) = task.future.lock() else { - return; - }; - loop { - match task.as_mut().poll(&mut context) { - std::task::Poll::Ready(()) => return, - std::task::Poll::Pending => notifier.wait(), - } + + thread_local! { + static TASK_PAIR: RefCell<(Parker, Waker)> = { + RefCell::new(parker_and_waker()) + }; } + + TASK_PAIR.with(|waker_pair| match waker_pair.try_borrow_mut() { + Ok(waker_pair) => { + let (parker, ref waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + let Ok(mut task) = task.lock() else { + return; + }; + loop { + match task.as_mut().poll(&mut context) { + Poll::Ready(()) => return, + Poll::Pending => parker.park(), + } + } + } + Err(_) => { + let (parker, unparker) = pair(); + let waker = waker_helper(move || { + unparker.unpark(); + }); + let mut context: Context<'_> = Context::from_waker(&waker); + let Ok(mut task) = task.lock() else { + return; + }; + loop { + match task.as_mut().poll(&mut context) { + Poll::Ready(()) => return, + Poll::Pending => parker.park(), + } + } + } + }); } diff --git a/src/executors/waker.rs b/src/executors/waker.rs new file mode 100644 index 0000000..1c4b35b --- /dev/null +++ b/src/executors/waker.rs @@ -0,0 +1,15 @@ +use std::{sync::Arc, task::Waker}; + +use super::waker_traits::{IntoWaker, WakeRef}; + +struct WakerHelper(F); + +pub(crate) fn waker_helper(f: F) -> Waker { + Arc::new(WakerHelper(f)).into_waker() +} + +impl WakeRef for WakerHelper { + fn wake_by_ref(&self) { + (self.0)(); + } +} diff --git a/src/executors/waker_traits.rs b/src/executors/waker_traits.rs new file mode 100644 index 0000000..f3ee5af --- /dev/null +++ b/src/executors/waker_traits.rs @@ -0,0 +1,106 @@ +use std::{ + mem::ManuallyDrop, + sync::Arc, + task::{RawWaker, RawWakerVTable, Waker}, +}; + +/// # Safety +/// All safe here +pub(crate) unsafe trait ViaRawPointer { + type Target: ?Sized; + + fn into_raw(self) -> *mut Self::Target; + + unsafe fn from_raw(ptr: *mut Self::Target) -> Self; +} + +pub(crate) trait WakeRef { + fn wake_by_ref(&self); +} + +pub(crate) trait Wake: WakeRef + Sized { + #[inline] + fn wake(self) { + self.wake_by_ref() + } +} + +pub(crate) trait IntoWaker { + const VTABLE: &'static RawWakerVTable; + + #[must_use] + fn into_waker(self) -> Waker; +} + +impl IntoWaker for T +where + T: Wake + Clone + 'static + ViaRawPointer, + T::Target: Sized, +{ + const VTABLE: &'static RawWakerVTable = &RawWakerVTable::new( + // clone + |raw| { + let raw = raw as *mut T::Target; + + let waker = ManuallyDrop::::new(unsafe { ViaRawPointer::from_raw(raw) }); + let cloned: T = (*waker).clone(); + + // We can't save the `into_raw` back into the raw waker, so we must + // simply assert that the pointer has remained the same. This is + // part of the ViaRawPointer safety contract, so we only check it + // in debug builds. + debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw); + + let cloned_raw = cloned.into_raw(); + let cloned_raw = cloned_raw as *const (); + RawWaker::new(cloned_raw, T::VTABLE) + }, + // wake by value + |raw| { + let raw = raw as *mut T::Target; + let waker: T = unsafe { ViaRawPointer::from_raw(raw) }; + waker.wake(); + }, + // wake by ref + |raw| { + let raw = raw as *mut T::Target; + let waker = ManuallyDrop::::new(unsafe { ViaRawPointer::from_raw(raw) }); + waker.wake_by_ref(); + + debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw); + }, + // Drop + |raw| { + let raw = raw as *mut T::Target; + let _waker: T = unsafe { ViaRawPointer::from_raw(raw) }; + }, + ); + + fn into_waker(self) -> Waker { + let raw = self.into_raw(); + let raw = raw as *const (); + let raw_waker = RawWaker::new(raw, T::VTABLE); + unsafe { Waker::from_raw(raw_waker) } + } +} + +unsafe impl ViaRawPointer for Arc { + type Target = T; + + fn into_raw(self) -> *mut T { + Arc::into_raw(self) as *mut T + } + + unsafe fn from_raw(ptr: *mut T) -> Self { + Arc::from_raw(ptr as *const T) + } +} + +impl WakeRef for Arc { + #[inline] + fn wake_by_ref(&self) { + T::wake_by_ref(self.as_ref()) + } +} + +impl Wake for Arc {} diff --git a/src/lib.rs b/src/lib.rs index 01f95a2..f154fac 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,6 +185,7 @@ pub use meta_types::GetType; pub use shared::priority::Priority; pub use sleeper::sleep; pub use spawn_group::SpawnGroup; +pub use yield_now::ready; pub use yield_now::yield_now; use std::future::Future; diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index eb2b10d..93978bd 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,7 +1,7 @@ use crate::{ async_runtime::{executor::Executor, task::Task}, async_stream::AsyncStream, - executors::block_task, + executors::{block_on, block_task}, shared::priority::Priority, }; use std::{ @@ -29,8 +29,11 @@ impl RuntimeEngine { impl RuntimeEngine { pub(crate) fn cancel(&mut self) { + let Ok(mut tasks) = self.tasks.lock() else { + return; + }; self.runtime.cancel(); - self.tasks.lock().unwrap().clear(); + tasks.clear(); self.stream.cancel_tasks(); self.poll(); } @@ -42,26 +45,35 @@ impl RuntimeEngine { } pub(crate) fn end(&mut self) { - self.runtime.cancel(); - self.tasks.lock().unwrap().clear(); + self.cancel(); self.runtime.end() } } impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { - self.poll(); + let Ok(mut tasks) = self.tasks.lock() else { + return; + }; + if tasks.is_empty() { + return; + } self.runtime.cancel(); - if let Ok(mut lock) = self.tasks.lock() { - lock.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - while let Some((_, handle)) = lock.pop() { - self.runtime.submit(move || { - block_task(handle); - }); + tasks.retain(|(_, task)| { + task.cancel(); + !task.is_completed() + }); + tasks.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); + if tasks.is_empty() { + return; + } + while let Some((_, task)) = tasks.pop() { + if task.is_completed() { + continue; } + self.runtime.submit(move || block_task(task)); } - - self.poll(); + self.poll() } } @@ -70,13 +82,16 @@ impl RuntimeEngine { where F: Future + Send + 'static, { + let Ok(mut tasks) = self.tasks.lock() else { + return; + }; self.stream.increment(); - let mut stream: AsyncStream = self.stream(); - let tasks: Arc>> = self.tasks.clone(); - tasks.lock().unwrap().push(( + let stream: AsyncStream = self.stream(); + tasks.push(( priority, self.runtime.spawn(async move { - stream.insert_item(task.await).await; + let task_result = task.await; + block_on(async { stream.insert_item(task_result).await }); stream.decrement_task_count(); }), )); diff --git a/src/spawn_group.rs b/src/spawn_group.rs index b8d757a..9280153 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -102,9 +102,15 @@ impl SpawnGroup { impl SpawnGroup { /// Waits for all remaining child tasks for finish. - pub async fn wait_for_all(&self) { + pub async fn wait_for_all(&mut self) { self.wait().await; } + + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); + self.decrement_count_to_zero() + } } impl SpawnGroup { diff --git a/src/threadpool_impl/channel.rs b/src/threadpool_impl/channel.rs index 9a75186..9f25cfa 100644 --- a/src/threadpool_impl/channel.rs +++ b/src/threadpool_impl/channel.rs @@ -6,23 +6,23 @@ use std::{ }, }; -#[derive(Default)] pub struct Channel { pair: Arc<(Mutex>, Condvar)>, closed: Arc, } impl Channel { - pub fn enqueue(&self, value: ItemType) -> bool { + pub(crate) fn enqueue(&self, value: ItemType) { if self.closed.load(Ordering::Relaxed) { - return false; + return; } - if let Ok(mut lock) = self.pair.0.lock() { - lock.push_back(value); + let Ok(mut lock) = self.pair.0.lock() else { + return; + }; + lock.push_back(value); + if lock.len() == 1 { self.pair.1.notify_one(); - return true; } - false } } @@ -45,7 +45,7 @@ impl Clone for Channel { } impl Channel { - pub fn dequeue(&self) -> Option { + pub(crate) fn dequeue(&self) -> Option { if self.closed.load(Ordering::Relaxed) { return None; } @@ -63,18 +63,21 @@ impl Channel { } impl Channel { - /// pub fn close(&self) { + if self.closed.load(Ordering::Relaxed) { + return; + } if let Ok(_lock) = self.pair.0.lock() { self.closed.store(true, Ordering::Relaxed); self.pair.1.notify_all(); } } - pub fn clear(&self) { - if let Ok(mut lock) = self.pair.0.lock() { - lock.clear(); - } + pub(crate) fn clear(&self) { + let Ok(mut lock) = self.pair.0.lock() else { + return; + }; + lock.clear(); } } diff --git a/src/threadpool_impl/mod.rs b/src/threadpool_impl/mod.rs index fad412e..97f5b7e 100644 --- a/src/threadpool_impl/mod.rs +++ b/src/threadpool_impl/mod.rs @@ -1,5 +1,6 @@ mod channel; mod threadpool; +mod waitgroup; pub(crate) type Func = dyn FnOnce() + Send; diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index 4845719..0400849 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,67 +1,91 @@ use std::{ - backtrace, panic, sync::{Arc, Barrier}, thread::spawn + backtrace, panic, + sync::atomic::{AtomicUsize, Ordering}, + thread::spawn, }; -use super::{Channel, Func}; +use super::{waitgroup::WaitGroup, Channel, Func}; -pub struct ThreadPool { - task_channel: Channel>, - barrier: Arc, - count: usize, +pub(crate) struct ThreadPool { + task_channels: Vec>>, + index: AtomicUsize, + wait_group: WaitGroup, } impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - let task_channel: Channel> = Channel::new(); - for _ in 1..=count { - let channel: Channel> = task_channel.clone(); - spawn(move || { - panic_hook(); - while let Some(ops) = channel.dequeue() { - ops() - } - }); + let mut count = count; + if count < 1 { + count = 1; } ThreadPool { - task_channel, - count, - barrier: Arc::new(Barrier::new(count + 1)), + task_channels: (1..=count) + .map(|_| { + let channel: Channel> = Channel::new(); + let chan = channel.clone(); + spawn(move || { + panic_hook(); + for ops in channel { + ops(); + } + }); + chan + }) + .collect(), + index: AtomicUsize::new(0), + wait_group: WaitGroup::new(), } } } impl ThreadPool { - pub fn submit(&self, task: Task) + fn current_index(&self) -> usize { + self.index.swap( + (self.index.load(Ordering::Relaxed) + 1) % self.task_channels.len(), + Ordering::SeqCst, + ) + } + + pub(crate) fn submit(&self, task: Task) where Task: FnOnce() + 'static + Send, { - self.task_channel.enqueue(Box::new(task)); + self.task_channels[self.current_index()].enqueue(Box::new(task)); + } + + pub(crate) fn wait_for_all(&self) { + self.task_channels.iter().for_each(|channel| { + let wait_group = self.wait_group.clone(); + wait_group.enter(); + channel.enqueue(Box::new(move || { + wait_group.leave(); + })); + }); + self.wait_group.wait(); } } impl ThreadPool { - pub fn wait_for_all(&self) { - (1..=self.count).for_each(|_| { - let barrier: Arc = self.barrier.clone(); - self.task_channel.enqueue(Box::new(move || { - barrier.wait(); - })); + pub(crate) fn end(&self) { + self.task_channels.iter().for_each(|channel| { + channel.close(); + channel.clear(); }); - self.barrier.wait(); } } impl Drop for ThreadPool { fn drop(&mut self) { _ = panic::take_hook(); - self.task_channel.close(); - self.clear() + self.end(); } } impl ThreadPool { - pub fn clear(&self) { - self.task_channel.clear(); + pub(crate) fn clear(&self) { + self.task_channels + .iter() + .for_each(|channel| channel.clear()); } } diff --git a/src/threadpool_impl/waitgroup.rs b/src/threadpool_impl/waitgroup.rs new file mode 100644 index 0000000..f0292a2 --- /dev/null +++ b/src/threadpool_impl/waitgroup.rs @@ -0,0 +1,42 @@ +use std::sync::{Arc, Condvar, Mutex}; + +#[derive(Clone)] +pub(crate) struct WaitGroup { + pair: Arc<(Mutex, Condvar)>, +} + +impl WaitGroup { + pub(crate) fn new() -> Self { + Self { + pair: Arc::new((Mutex::new(0), Condvar::new())), + } + } +} + +impl WaitGroup { + pub(crate) fn enter(&self) { + let Ok(mut guard) = self.pair.0.lock() else { + return; + }; + (*guard) += 1; + } + + pub(crate) fn leave(&self) { + let Ok(mut guard) = self.pair.0.lock() else { + return; + }; + (*guard) -= 1; + if (*guard) == 0 { + self.pair.1.notify_all(); + } + } + + pub(crate) fn wait(&self) { + let Ok(mut guard) = self.pair.0.lock() else { + return; + }; + while *guard > 0 { + guard = self.pair.1.wait(guard).unwrap(); + } + } +} diff --git a/src/yield_now/mod.rs b/src/yield_now/mod.rs index e85848a..152dcdd 100755 --- a/src/yield_now/mod.rs +++ b/src/yield_now/mod.rs @@ -18,3 +18,17 @@ mod yielder; pub fn yield_now() -> Yielder { Yielder::default() } + +/// Resolves to the provided value. +/// +/// # Examples +/// ``` +/// use spawn_groups::{block_on, ready}; +/// block_on(async { +/// let ten = ready(10).await; +/// assert_eq!(ten, 10); +/// }); +/// ``` +pub async fn ready(val: ValueType) -> ValueType { + val +} From 9153be96c52b07671ebf6ee7e6f80dc8637a8042 Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Wed, 24 Jul 2024 12:53:53 +0100 Subject: [PATCH 08/10] Another major fix for v2.0 --- src/async_runtime/executor.rs | 70 ++++++++++------------------------ src/discarding_spawn_group.rs | 40 +++++-------------- src/err_spawn_group.rs | 63 +++++++----------------------- src/executors/mod.rs | 13 +++---- src/executors/notifier.rs | 41 -------------------- src/executors/task_executor.rs | 61 ++++++++--------------------- src/shared/mod.rs | 2 - src/shared/runtime.rs | 32 ++++++---------- src/shared/sharedfuncs.rs | 20 ---------- src/shared/wait.rs | 6 --- src/spawn_group.rs | 57 ++++++--------------------- 11 files changed, 88 insertions(+), 317 deletions(-) delete mode 100644 src/executors/notifier.rs delete mode 100755 src/shared/sharedfuncs.rs delete mode 100755 src/shared/wait.rs diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs index 2a3ace1..1d9517e 100644 --- a/src/async_runtime/executor.rs +++ b/src/async_runtime/executor.rs @@ -1,7 +1,4 @@ -use crate::{ - executors::{IntoWaker, Notifier}, - threadpool_impl::{Channel, ThreadPool}, -}; +use crate::{executors::waker_helper, threadpool_impl::{Channel, ThreadPool}}; use super::task::Task; @@ -34,17 +31,17 @@ impl Executor { impl Executor { fn start(&self) { - let queue = self.task_queue.clone(); - let cancelled = self.cancelled.clone(); - let pool = self.pool.clone(); + let queue: Channel = self.task_queue.clone(); + let cancelled: Arc = self.cancelled.clone(); + let pool: Arc = self.pool.clone(); std::thread::spawn(move || loop { let Some(task) = queue.clone().dequeue() else { return; }; - if cancelled.load(Ordering::Relaxed) || task.is_cancelled() { + if cancelled.load(Ordering::Acquire) || task.is_cancelled() { continue; } - let queue = queue.clone(); + let queue: Channel = queue.clone(); pool.submit(move || block_task_with(task, &queue)) }); } @@ -67,10 +64,10 @@ impl Executor { pub(crate) fn cancel(&self) { self.pool.clear(); - self.cancelled.store(true, Ordering::Relaxed); + self.cancelled.store(true, Ordering::Release); self.poll_all(); self.pool.clear(); - self.cancelled.store(false, Ordering::Relaxed); + self.cancelled.store(false, Ordering::Release); } pub(crate) fn poll_all(&self) { @@ -84,46 +81,21 @@ impl Executor { } } -thread_local! { - static TASK_WAKER: Waker = { - Arc::new(Notifier::default()).into_waker() - }; -} - -fn block_task_with(future: Task, queue: &Channel) { - if future.is_completed() || future.is_cancelled() { +fn block_task_with(task: Task, queue: &Channel) { + if task.is_completed() || task.is_cancelled() { return; } - let task = future.clone(); - let waker_result = TASK_WAKER.try_with(|waker| waker.clone()); - match waker_result { - Ok(waker) => { - let mut context: Context<'_> = Context::from_waker(&waker); - let Ok(mut future) = future.lock() else { - return; - }; - match future.as_mut().poll(&mut context) { - Poll::Ready(()) => task.complete(), - Poll::Pending => { - if !task.is_cancelled() { - queue.enqueue(task.clone()); - } - } - } - } - Err(_) => { - let waker: Waker = Arc::new(Notifier::default()).into_waker(); - let mut context: Context<'_> = Context::from_waker(&waker); - let Ok(mut future) = future.lock() else { - return; - }; - match future.as_mut().poll(&mut context) { - Poll::Ready(()) => task.complete(), - Poll::Pending => { - if !task.is_cancelled() { - queue.enqueue(task.clone()); - } - } + let task_clone = task.clone(); + let waker: Waker = waker_helper(|| {}); + let mut context: Context<'_> = Context::from_waker(&waker); + let Ok(mut future) = task.lock() else { + return; + }; + match future.as_mut().poll(&mut context) { + Poll::Ready(()) => task_clone.complete(), + Poll::Pending => { + if !task_clone.is_cancelled() { + queue.enqueue(task_clone); } } } diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 1f2d897..e14db42 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -1,4 +1,4 @@ -use crate::shared::{priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared}; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use std::future::Future; @@ -54,9 +54,9 @@ impl DiscardingSpawnGroup { /// * `closure`: an async closure that doesn't return anything pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task(priority, closure); + self.runtime.write_task(priority, closure); } /// Spawn a new task only if the group is not cancelled yet, @@ -68,14 +68,17 @@ impl DiscardingSpawnGroup { /// * `closure`: an async closure that return doesn't return anything pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.runtime.write_task(priority, closure); + } } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; } } @@ -116,28 +119,3 @@ impl Drop for DiscardingSpawnGroup { self.runtime.end() } } - -impl Shared for DiscardingSpawnGroup { - type Result = (); - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.runtime.write_task(priority, closure); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } - } - - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; - } -} diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index 8a4e1b1..a6d8efd 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -1,7 +1,4 @@ -use crate::shared::{ - priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, -}; -use async_trait::async_trait; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use futures_lite::{Stream, StreamExt}; use std::{ future::Future, @@ -68,16 +65,19 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + F: Future> + Send + 'static, { - self.add_task(priority, closure); + self.increment_count(); + self.runtime.write_task(priority, closure); } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; + self.decrement_count_to_zero(); } /// Spawn a new task only if the group is not cancelled yet, @@ -89,17 +89,19 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + F: Future> + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.spawn_task(priority, closure) + } } } impl ErrSpawnGroup { /// Returns the first element of the stream, or None if it is empty. - pub async fn first(&self) -> Option< as Shared>::Result> { + pub async fn first(&self) -> Option> { self.runtime.stream().first().await } } @@ -114,7 +116,7 @@ impl ErrSpawnGroup { impl ErrSpawnGroup { /// Waits for all remaining child tasks for finish. pub async fn wait_for_all(&mut self) { - self.wait().await; + self.wait_non_async() } /// Waits for all remaining child tasks for finish in non async context. @@ -209,35 +211,6 @@ impl Drop for ErrSpawnGroup Shared - for ErrSpawnGroup -{ - type Result = Result; - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.increment_count(); - self.runtime.write_task(priority, closure); - } - - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; - self.decrement_count_to_zero(); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } - } -} - impl Stream for ErrSpawnGroup { type Item = Result; @@ -245,13 +218,3 @@ impl Stream for ErrSpawnGroup Waitable - for ErrSpawnGroup -{ - async fn wait(&self) { - self.runtime.wait_for_all_tasks(); - self.decrement_count_to_zero(); - } -} diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 8c4aefc..99211db 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -1,12 +1,9 @@ -mod waker; -mod waker_traits; -mod task_executor; mod future_executor; -mod notifier; mod parker; - -pub(crate) use task_executor::block_task; -pub(crate) use notifier::Notifier; -pub(crate) use waker_traits::IntoWaker; +mod task_executor; +mod waker; +mod waker_traits; pub use future_executor::block_on; +pub(crate) use task_executor::block_task; +pub(crate) use waker::waker_helper; diff --git a/src/executors/notifier.rs b/src/executors/notifier.rs deleted file mode 100644 index 656d703..0000000 --- a/src/executors/notifier.rs +++ /dev/null @@ -1,41 +0,0 @@ -#![allow(dead_code)] - -use std::sync::{Condvar, Mutex, MutexGuard}; - -use crate::executors::waker_traits::WakeRef; - -#[derive(Default)] -pub(crate) struct Notifier { - was_notified: Mutex, - cv: Condvar, -} - -impl Notifier { - pub(crate) fn wake(&self) { - let mut was_notified: MutexGuard<'_, bool> = self.was_notified.lock().unwrap(); - - let was_notified: bool = - { std::mem::replace(&mut was_notified, true) }; - if !was_notified { - self.cv.notify_one(); - } - } -} - -impl WakeRef for Notifier { - fn wake_by_ref(&self) { - self.wake() - } -} - -impl Notifier { - pub(crate) fn wait(&self) { - let mut was_notified: MutexGuard<'_, bool> = self.was_notified.lock().unwrap(); - - while !*was_notified { - was_notified = self.cv.wait(was_notified).unwrap(); - } - *was_notified = false; - } -} - diff --git a/src/executors/task_executor.rs b/src/executors/task_executor.rs index 68b7049..2985d65 100644 --- a/src/executors/task_executor.rs +++ b/src/executors/task_executor.rs @@ -1,14 +1,11 @@ -use std::{ - cell::RefCell, - task::{Context, Poll, Waker}, -}; +use std::task::{Context, Poll, Waker}; -use crate::{ - async_runtime::task::Task, - executors::parker::{pair, Parker}, -}; +use crate::async_runtime::task::Task; -use super::waker::waker_helper; +use super::{ + parker::{pair, Parker}, + waker::waker_helper, +}; fn parker_and_waker() -> (Parker, Waker) { let (parker, unparker) = pair(); @@ -23,41 +20,15 @@ pub(crate) fn block_task(task: Task) { return; } - thread_local! { - static TASK_PAIR: RefCell<(Parker, Waker)> = { - RefCell::new(parker_and_waker()) - }; - } - - TASK_PAIR.with(|waker_pair| match waker_pair.try_borrow_mut() { - Ok(waker_pair) => { - let (parker, ref waker) = &*waker_pair; - let mut context: Context<'_> = Context::from_waker(waker); - let Ok(mut task) = task.lock() else { - return; - }; - loop { - match task.as_mut().poll(&mut context) { - Poll::Ready(()) => return, - Poll::Pending => parker.park(), - } - } - } - Err(_) => { - let (parker, unparker) = pair(); - let waker = waker_helper(move || { - unparker.unpark(); - }); - let mut context: Context<'_> = Context::from_waker(&waker); - let Ok(mut task) = task.lock() else { - return; - }; - loop { - match task.as_mut().poll(&mut context) { - Poll::Ready(()) => return, - Poll::Pending => parker.park(), - } - } + let (parker, waker) = parker_and_waker(); + let mut context: Context<'_> = Context::from_waker(&waker); + let Ok(mut future) = task.lock() else { + return; + }; + loop { + match future.as_mut().poll(&mut context) { + Poll::Ready(output) => return output, + Poll::Pending => parker.park(), } - }); + } } diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 0cc077a..82f8a61 100755 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1,4 +1,2 @@ pub(crate) mod priority; pub(crate) mod runtime; -pub(crate) mod sharedfuncs; -pub(crate) mod wait; diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index 93978bd..e44768a 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,15 +1,13 @@ use crate::{ async_runtime::{executor::Executor, task::Task}, async_stream::AsyncStream, - executors::{block_on, block_task}, + block_on, + executors::block_task, shared::priority::Priority, }; -use std::{ - future::Future, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, future::Future}; -type TaskQueue = Arc>>; +type TaskQueue = RefCell>; pub struct RuntimeEngine { tasks: TaskQueue, @@ -20,7 +18,7 @@ pub struct RuntimeEngine { impl RuntimeEngine { pub(crate) fn new(count: usize) -> Self { Self { - tasks: Arc::new(Mutex::new(vec![])), + tasks: RefCell::new(vec![]), stream: AsyncStream::new(), runtime: Executor::new(count), } @@ -29,11 +27,8 @@ impl RuntimeEngine { impl RuntimeEngine { pub(crate) fn cancel(&mut self) { - let Ok(mut tasks) = self.tasks.lock() else { - return; - }; self.runtime.cancel(); - tasks.clear(); + self.tasks.borrow_mut().clear(); self.stream.cancel_tasks(); self.poll(); } @@ -45,20 +40,20 @@ impl RuntimeEngine { } pub(crate) fn end(&mut self) { - self.cancel(); + self.runtime.cancel(); + self.tasks.borrow_mut().clear(); + self.stream.cancel_tasks(); self.runtime.end() } } impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { - let Ok(mut tasks) = self.tasks.lock() else { - return; - }; + self.runtime.cancel(); + let mut tasks = self.tasks.borrow_mut(); if tasks.is_empty() { return; } - self.runtime.cancel(); tasks.retain(|(_, task)| { task.cancel(); !task.is_completed() @@ -82,12 +77,9 @@ impl RuntimeEngine { where F: Future + Send + 'static, { - let Ok(mut tasks) = self.tasks.lock() else { - return; - }; self.stream.increment(); let stream: AsyncStream = self.stream(); - tasks.push(( + self.tasks.borrow_mut().push(( priority, self.runtime.spawn(async move { let task_result = task.await; diff --git a/src/shared/sharedfuncs.rs b/src/shared/sharedfuncs.rs deleted file mode 100755 index 5a96d40..0000000 --- a/src/shared/sharedfuncs.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::shared::priority::Priority; -use std::future::Future; - - -/// The basic functionalities between all kinds of spawn groups -pub trait Shared { - /// A value return when a task is being awaited for - type Result; - /// Add a new task into the engine - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static; - /// Cancels all running tasks in the engine - fn cancel_all_tasks(&mut self); - /// Add a new task only if the engine is not cancelled yet, - /// otherwise does nothing - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static; -} diff --git a/src/shared/wait.rs b/src/shared/wait.rs deleted file mode 100755 index cfa9613..0000000 --- a/src/shared/wait.rs +++ /dev/null @@ -1,6 +0,0 @@ -use async_trait::async_trait; - -#[async_trait] -pub trait Waitable { - async fn wait(&self); -} diff --git a/src/spawn_group.rs b/src/spawn_group.rs index 9280153..b715be6 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -1,7 +1,4 @@ -use crate::shared::{ - priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, wait::Waitable, -}; -use async_trait::async_trait; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use futures_lite::{Stream, StreamExt}; use std::{ future::Future, @@ -68,9 +65,10 @@ impl SpawnGroup { /// * `closure`: an async closure that return a value of type ``ValueType`` pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task(priority, closure); + self.increment_count(); + self.runtime.write_task(priority, closure); } /// Spawn a new task only if the group is not cancelled yet, @@ -82,14 +80,18 @@ impl SpawnGroup { /// * `closure`: an async closure that return a value of type ``ValueType`` pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.spawn_task(priority, closure) + } } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; + self.decrement_count_to_zero(); } } @@ -103,7 +105,7 @@ impl SpawnGroup { impl SpawnGroup { /// Waits for all remaining child tasks for finish. pub async fn wait_for_all(&mut self) { - self.wait().await; + self.wait_non_async() } /// Waits for all remaining child tasks for finish in non async context. @@ -205,33 +207,6 @@ impl Drop for SpawnGroup { } } -impl Shared for SpawnGroup { - type Result = ValueType; - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.increment_count(); - self.runtime.write_task(priority, closure); - } - - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; - self.decrement_count_to_zero(); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } - } -} - impl Stream for SpawnGroup { type Item = ValueType; @@ -239,11 +214,3 @@ impl Stream for SpawnGroup { self.runtime.stream().poll_next(cx) } } - -#[async_trait] -impl Waitable for SpawnGroup { - async fn wait(&self) { - self.runtime.wait_for_all_tasks(); - self.decrement_count_to_zero(); - } -} From 0a3dab6e62c91945d6aaff969db3321773c0a930 Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Fri, 9 Aug 2024 23:10:31 +0100 Subject: [PATCH 09/10] ready soon --- Cargo.toml | 1 - src/async_runtime/executor.rs | 100 +++++++++---------- src/async_runtime/mod.rs | 1 - src/async_runtime/pin_macro.rs | 8 -- src/async_runtime/task.rs | 91 ++++++++++-------- src/async_stream/mod.rs | 153 +++++++++++++++++++----------- src/discarding_spawn_group.rs | 2 +- src/err_spawn_group.rs | 34 +++---- src/executors/future_executor.rs | 44 ++++----- src/executors/mod.rs | 7 +- src/executors/parker.rs | 135 -------------------------- src/executors/suspender.rs | 109 +++++++++++++++++++++ src/executors/task_executor.rs | 62 +++++++----- src/executors/waker.rs | 49 ++++++++-- src/executors/waker_traits.rs | 106 --------------------- src/lib.rs | 56 +++++------ src/shared/mod.rs | 1 + src/shared/priority.rs | 3 +- src/shared/priority_task.rs | 51 ++++++++++ src/shared/runtime.rs | 89 ++++++++++------- src/sleeper/delay.rs | 36 ------- src/sleeper/mod.rs | 25 ----- src/spawn_group.rs | 28 +++--- src/threadpool_impl/channel.rs | 89 ++++++++++------- src/threadpool_impl/threadpool.rs | 46 +++------ src/yield_now/mod.rs | 34 ------- src/yield_now/yielder.rs | 25 ----- 27 files changed, 639 insertions(+), 746 deletions(-) delete mode 100755 src/async_runtime/pin_macro.rs delete mode 100755 src/executors/parker.rs create mode 100755 src/executors/suspender.rs delete mode 100644 src/executors/waker_traits.rs create mode 100644 src/shared/priority_task.rs delete mode 100755 src/sleeper/delay.rs delete mode 100755 src/sleeper/mod.rs delete mode 100755 src/yield_now/mod.rs delete mode 100755 src/yield_now/yielder.rs diff --git a/Cargo.toml b/Cargo.toml index 118de02..7c85c70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,5 @@ publish = true [dependencies] -async-trait = "0.1.73" futures-lite = "1.13.0" async-mutex = "1.4.0" diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs index 1d9517e..d4194c7 100644 --- a/src/async_runtime/executor.rs +++ b/src/async_runtime/executor.rs @@ -1,51 +1,28 @@ -use crate::{executors::waker_helper, threadpool_impl::{Channel, ThreadPool}}; +use crate::{ + executors::{pair, Suspender, WAKER_PAIR}, + threadpool_impl::ThreadPool, +}; use super::task::Task; use std::{ future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; pub struct Executor { - pool: Arc, - cancelled: Arc, - task_queue: Channel, + pool: ThreadPool, } impl Executor { pub(crate) fn new(count: usize) -> Self { - let result: Executor = Self { - pool: Arc::new(ThreadPool::new(count)), - cancelled: Arc::new(AtomicBool::new(false)), - task_queue: Channel::new(), - }; - result.start(); - result + Self { + pool: ThreadPool::new(count), + } } } impl Executor { - fn start(&self) { - let queue: Channel = self.task_queue.clone(); - let cancelled: Arc = self.cancelled.clone(); - let pool: Arc = self.pool.clone(); - std::thread::spawn(move || loop { - let Some(task) = queue.clone().dequeue() else { - return; - }; - if cancelled.load(Ordering::Acquire) || task.is_cancelled() { - continue; - } - let queue: Channel = queue.clone(); - pool.submit(move || block_task_with(task, &queue)) - }); - } - pub(crate) fn submit(&self, task: Task) where Task: FnOnce() + Send + 'static, @@ -55,19 +32,40 @@ impl Executor { pub(crate) fn spawn(&self, task: F) -> Task where - F: Future + Send + 'static, + F: Future + 'static, { let task: Task = Task::new(task); - self.task_queue.enqueue(task.clone()); + self.async_poll_task(task.clone()); task } + fn async_poll_task(&self, task: Task) { + if task.is_completed() || task.is_cancelled() { + return; + } + + self.submit(move || { + WAKER_PAIR.with(move |waker_pair| { + match waker_pair.try_borrow_mut() { + Ok(waker_pair) => { + let (suspender, waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + poll_task(task, suspender, &mut context) + } + Err(_) => { + let (suspender, waker) = pair(); + let mut context: Context<'_> = Context::from_waker(&waker); + poll_task(task, &suspender, &mut context) + } + }; + }); + }); + } + pub(crate) fn cancel(&self) { self.pool.clear(); - self.cancelled.store(true, Ordering::Release); self.poll_all(); self.pool.clear(); - self.cancelled.store(false, Ordering::Release); } pub(crate) fn poll_all(&self) { @@ -75,27 +73,23 @@ impl Executor { } pub(crate) fn end(&self) { - self.cancelled.store(true, Ordering::Release); - self.task_queue.close(); self.pool.end(); } } -fn block_task_with(task: Task, queue: &Channel) { - if task.is_completed() || task.is_cancelled() { - return; - } - let task_clone = task.clone(); - let waker: Waker = waker_helper(|| {}); - let mut context: Context<'_> = Context::from_waker(&waker); - let Ok(mut future) = task.lock() else { - return; - }; - match future.as_mut().poll(&mut context) { - Poll::Ready(()) => task_clone.complete(), - Poll::Pending => { - if !task_clone.is_cancelled() { - queue.enqueue(task_clone); +#[inline] +fn poll_task(task: Task, suspender: &Suspender, context: &mut Context<'_>) { + let mut task = task; + loop { + match task.poll_task(context) { + Poll::Ready(()) => { + return; + } + Poll::Pending => { + suspender.suspend(); + if task.is_cancelled() { + return; + } } } } diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs index 537875a..20991e8 100755 --- a/src/async_runtime/mod.rs +++ b/src/async_runtime/mod.rs @@ -1,3 +1,2 @@ pub(crate) mod executor; -mod pin_macro; pub(crate) mod task; diff --git a/src/async_runtime/pin_macro.rs b/src/async_runtime/pin_macro.rs deleted file mode 100755 index 74c3caa..0000000 --- a/src/async_runtime/pin_macro.rs +++ /dev/null @@ -1,8 +0,0 @@ -/// Pins the variable implementing the ``std::future::Future`` trait onto the stack -#[macro_export] -macro_rules! pin_future { - ($x:ident) => { - let mut $x = $x; - let mut $x = unsafe { std::pin::Pin::new_unchecked(&mut $x) }; - }; -} diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs index 9b6b5d7..43b5a65 100755 --- a/src/async_runtime/task.rs +++ b/src/async_runtime/task.rs @@ -1,69 +1,84 @@ use std::{ - future::Future, - ops::Deref, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, - }, - task::{Context, Poll}, + future::Future, hint, pin::Pin, rc::Rc, sync::atomic::{AtomicBool, AtomicU8, Ordering}, task::{Context, Poll} }; -type LocalBoxedFuture = Pin + Send + 'static>>; +type LocalFuture = dyn Future; #[derive(Clone)] -pub struct Task { - future: Arc>, - complete: Arc, - cancelled: Arc, +pub(crate) struct Task { + inner: Rc, } impl Task { - pub(crate) fn new + Send + 'static>(fut: Fut) -> Self { + pub(crate) fn new + 'static>(future: Fut) -> Self { Self { - future: Arc::new(Mutex::new(Box::pin(fut))), - complete: Arc::new(AtomicBool::new(false)), - cancelled: Arc::new(AtomicBool::new(false)), + inner: Rc::new(Inner::new(future)), } } +} +impl Task { pub(crate) fn is_completed(&self) -> bool { - self.complete.load(Ordering::Relaxed) + self.inner.complete.load(Ordering::Acquire) } pub(crate) fn complete(&self) { - self.complete.store(true, Ordering::Release); + self.inner.complete.store(true, Ordering::Release) } - pub(crate) fn cancel(&self) { - self.cancelled.store(true, Ordering::Release); + pub(crate) fn cancel_task(&self) { + self.inner.cancelled.store(true, Ordering::Release) } pub(crate) fn is_cancelled(&self) -> bool { - self.cancelled.load(Ordering::Acquire) + self.inner.cancelled.load(Ordering::Acquire) + } + + pub(crate) fn poll_task(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // ensures that only this method is polling the future right now regardless of all other cloned tasks + // basically a lightweight spinlock to prevent data race bugs while polling + while self + .inner + .poll_check + .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + hint::spin_loop(); + } + + let result = unsafe { Pin::new_unchecked(&mut (*self.inner.ptr)).poll(cx) }; + if result.is_ready() { + self.complete(); + } + self.inner.poll_check.store(0, Ordering::Release); + result } } -impl Deref for Task { - type Target = Arc>; +unsafe impl Send for Task {} + +struct Inner { + poll_check: AtomicU8, + ptr: *mut LocalFuture, + cancelled: AtomicBool, + complete: AtomicBool, +} - fn deref(&self) -> &Self::Target { - &self.future +impl Inner { + fn new(future: impl Future + 'static) -> Self { + Self { + poll_check: AtomicU8::new(0), + ptr: Box::into_raw(Box::new(future)), + complete: AtomicBool::new(false), + cancelled: AtomicBool::new(false), + } } } -impl Future for Task { - type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.future.lock().unwrap().as_mut().poll(cx) { - Poll::Ready(()) => { - self.complete(); - Poll::Ready(()) - } - Poll::Pending => { - cx.waker().wake_by_ref(); - Poll::Pending - } +impl Drop for Inner { + fn drop(&mut self) { + unsafe { + _ = Box::from_raw(self.ptr); } } } diff --git a/src/async_stream/mod.rs b/src/async_stream/mod.rs index 8b19750..82edc3e 100755 --- a/src/async_stream/mod.rs +++ b/src/async_stream/mod.rs @@ -1,112 +1,157 @@ use std::{ collections::VecDeque, + future::Future, pin::Pin, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, - task::{Context, Poll}, + task::{Context, Poll, Waker}, }; -use async_mutex::{Mutex, MutexGuard}; -use futures_lite::{Stream, StreamExt}; +use async_mutex::Mutex; +use futures_lite::Stream; -use crate::executors::block_on; +pub struct AsyncStream { + inner: Arc>, +} -pub struct AsyncStream { - buffer: Arc>>, - item_count: Arc, - task_count: Arc, - cancelled: Arc, +struct Inner { + inner_lock: Mutex>, + item_count: AtomicUsize, +} + +impl Inner { + fn new() -> Self { + Self { + inner_lock: Mutex::new(InnerState::new()), + item_count: AtomicUsize::new(0), + } + } } impl AsyncStream { #[inline] pub(crate) async fn insert_item(&self, value: ItemType) { - _ = self - .cancelled - .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed); - self.buffer.lock().await.push_back(value); + let mut inner_lock = self.inner.inner_lock.lock().await; + inner_lock.buffer.push_back(value); + // check if any waker was registered + let Some(waker) = inner_lock.waker.take() else { + return; + }; + // wakeup the waker + waker.wake(); } } impl AsyncStream { pub(crate) async fn buffer_count(&self) -> usize { - self.buffer.lock().await.len() + self.inner.inner_lock.lock().await.buffer.len() } } impl AsyncStream { pub(crate) fn increment(&self) { - self.task_count.fetch_add(1, Ordering::SeqCst); - self.item_count.fetch_add(1, Ordering::SeqCst); + self.inner.item_count.fetch_add(1, Ordering::Relaxed); } } impl AsyncStream { pub async fn first(&mut self) -> Option { - self.next().await + let mut inner_lock = self.inner.inner_lock.lock().await; + if inner_lock.buffer.is_empty() || self.item_count() == 0 { + return None; + } + + let value = inner_lock.buffer.pop_front()?; + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Some(value) } } impl AsyncStream { - pub(crate) fn task_count(&self) -> usize { - self.task_count.load(Ordering::Acquire) - } - - pub(crate) fn decrement_task_count(&self) { - self.task_count.fetch_sub(1, Ordering::SeqCst); - } - pub(crate) fn item_count(&self) -> usize { - self.item_count.load(Ordering::Acquire) - } - - pub(crate) fn cancel_tasks(&mut self) { - self.cancelled.store(true, Ordering::Release); - self.task_count.store(0, Ordering::Release); + self.inner.item_count.load(Ordering::Acquire) } } impl Clone for AsyncStream { fn clone(&self) -> Self { Self { - buffer: self.buffer.clone(), - task_count: self.task_count.clone(), - item_count: self.item_count.clone(), - cancelled: self.cancelled.clone(), + inner: self.inner.clone(), } } } impl AsyncStream { pub(crate) fn new() -> Self { - AsyncStream:: { - buffer: Arc::new(Mutex::new(VecDeque::new())), - item_count: Arc::new(AtomicUsize::new(0)), - task_count: Arc::new(AtomicUsize::new(0)), - cancelled: Arc::new(AtomicBool::new(false)), + AsyncStream { + inner: Arc::new(Inner::new()), + } + } +} + +enum Stages { + Empty, + Wait, + Ready(T), +} + +struct InnerState { + buffer: VecDeque, + waker: Option, +} + +impl InnerState { + fn new() -> InnerState { + Self { + buffer: VecDeque::with_capacity(1000), + waker: None, } } } +impl AsyncStream { + fn poll(&self, cx: &mut Context<'_>) -> Poll>> { + let waker = cx.waker().clone(); + let mut future = async move { + let mut inner_lock = self.inner.inner_lock.lock().await; + if self.item_count() == 0 && inner_lock.buffer.is_empty() { + return Stages::Empty; + } + let Some(value) = inner_lock.buffer.pop_front() else { + // register the waker so we can called it later + inner_lock.waker.replace(waker); + return Stages::Wait; + }; + + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Stages::Ready(Some(value)) + }; + unsafe { Future::poll(Pin::new_unchecked(&mut future), cx) } + } +} + impl Stream for AsyncStream { type Item = ItemType; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - block_on(async move { - let mut inner_lock: MutexGuard<'_, VecDeque> = self.buffer.lock().await; - if self.cancelled.load(Ordering::Relaxed) && inner_lock.is_empty() - || self.item_count() == 0 - { - return Poll::Ready(None); - } - let Some(value) = inner_lock.pop_front() else { + match self.poll(cx) { + Poll::Pending => { + // This means the lock has not been acquired yet + // so immediately wake up this waker cx.waker().wake_by_ref(); - return Poll::Pending; - }; - self.item_count.fetch_sub(1, Ordering::SeqCst); - Poll::Ready(Some(value)) - }) + Poll::Pending + } + Poll::Ready(stage) => match stage { + Stages::Empty => Poll::Ready(None), + Stages::Wait => Poll::Pending, + Stages::Ready(value) => Poll::Ready(value), + }, + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.item_count())) } } diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index e14db42..1e00456 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -92,7 +92,7 @@ impl DiscardingSpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.runtime.stream().task_count() == 0 { + if self.runtime.task_count() == 0 { return true; } false diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index a6d8efd..44c3f0f 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -25,7 +25,7 @@ use std::{ /// /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used -pub struct ErrSpawnGroup { +pub struct ErrSpawnGroup { /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, count: Arc, @@ -33,7 +33,7 @@ pub struct ErrSpawnGroup { wait_at_drop: bool, } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Instantiates `ErrSpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures /// /// # Parameters @@ -49,14 +49,14 @@ impl ErrSpawnGroup { } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Don't implicity wait for spawned child tasks to finish before being dropped pub fn dont_wait_at_drop(&mut self) { self.wait_at_drop = false; } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Spawns a new task into the spawn group /// /// # Parameters @@ -65,9 +65,7 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future> - + Send - + 'static, + F: Future> + Send + 'static, { self.increment_count(); self.runtime.write_task(priority, closure); @@ -89,9 +87,7 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future> - + Send - + 'static, + F: Future> + Send + 'static, { if !self.is_cancelled { self.spawn_task(priority, closure) @@ -99,21 +95,21 @@ impl ErrSpawnGroup { } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Returns the first element of the stream, or None if it is empty. pub async fn first(&self) -> Option> { self.runtime.stream().first().await } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Returns an instance of the `Stream` trait. pub fn stream(&self) -> impl Stream> { self.runtime.stream() } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Waits for all remaining child tasks for finish. pub async fn wait_for_all(&mut self) { self.wait_non_async() @@ -126,7 +122,7 @@ impl ErrSpawnGroup { } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { fn increment_count(&self) { self.count.fetch_add(1, Ordering::Acquire); } @@ -140,7 +136,7 @@ impl ErrSpawnGroup { } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// A Boolean value that indicates whether the group has any remaining tasks. /// /// At the start of the body of a ``with_err_spawn_group`` function call, or before calling ``spawn_task`` or ``spawn_task_unless_cancelled`` methods @@ -150,14 +146,14 @@ impl ErrSpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.count() == 0 || self.runtime.stream().task_count() == 0 { + if self.count() == 0 || self.runtime.task_count() == 0 { return true; } false } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Waits for a specific number of spawned child tasks to finish and returns their respectively result as a vector /// /// # Panics @@ -202,7 +198,7 @@ impl ErrSpawnGroup { } } -impl Drop for ErrSpawnGroup { +impl Drop for ErrSpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); @@ -211,7 +207,7 @@ impl Drop for ErrSpawnGroup Stream for ErrSpawnGroup { +impl Stream for ErrSpawnGroup { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/src/executors/future_executor.rs b/src/executors/future_executor.rs index 31ab109..8e92a41 100644 --- a/src/executors/future_executor.rs +++ b/src/executors/future_executor.rs @@ -4,17 +4,7 @@ use std::{ task::{Context, Poll, Waker}, }; -use crate::{executors::waker::waker_helper, pin_future}; - -use super::parker::{pair, Parker}; - -fn parker_and_waker() -> (Parker, Waker) { - let (parker, unparker) = pair(); - let waker = waker_helper(move || { - unparker.unpark(); - }); - (parker, waker) -} +use super::suspender::{pair, Suspender}; /// Blocks the current thread until the future is polled to finish. /// @@ -28,34 +18,34 @@ fn parker_and_waker() -> (Parker, Waker) { /// ``` /// pub fn block_on(future: Fut) -> Fut::Output { - pin_future!(future); + let mut future = future; + let mut future = unsafe { std::pin::Pin::new_unchecked(&mut future) }; thread_local! { - static WAKER_PAIR: RefCell<(Parker, Waker)> = { - RefCell::new(parker_and_waker()) + static WAKER_PAIR: RefCell<(Suspender, Waker)> = { + RefCell::new(pair()) }; } return WAKER_PAIR.with(|waker_pair| match waker_pair.try_borrow_mut() { Ok(waker_pair) => { - let (parker, waker) = &*waker_pair; + let (suspender, waker) = &*waker_pair; let mut context: Context<'_> = Context::from_waker(waker); loop { - match future.as_mut().poll(&mut context) { - Poll::Ready(output) => return output, - Poll::Pending => parker.park(), - } + let Poll::Ready(output) = Future::poll(future.as_mut(), &mut context) else { + suspender.suspend(); + continue; + }; + return output; } } Err(_) => { - let (parker, unparker) = pair(); - let waker = waker_helper(move || { - unparker.unpark(); - }); + let (suspender, waker) = pair(); let mut context: Context<'_> = Context::from_waker(&waker); loop { - match future.as_mut().poll(&mut context) { - Poll::Ready(output) => return output, - Poll::Pending => parker.park(), - } + let Poll::Ready(output) = Future::poll(future.as_mut(), &mut context) else { + suspender.suspend(); + continue; + }; + return output; } } }); diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 99211db..e02e9d1 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -1,9 +1,8 @@ mod future_executor; -mod parker; +mod suspender; mod task_executor; mod waker; -mod waker_traits; pub use future_executor::block_on; -pub(crate) use task_executor::block_task; -pub(crate) use waker::waker_helper; +pub(crate) use task_executor::{block_task, WAKER_PAIR}; +pub(crate) use suspender::{Suspender, pair}; \ No newline at end of file diff --git a/src/executors/parker.rs b/src/executors/parker.rs deleted file mode 100755 index 014b660..0000000 --- a/src/executors/parker.rs +++ /dev/null @@ -1,135 +0,0 @@ -use std::{ - cell::Cell, - marker::PhantomData, - sync::{ - atomic::{AtomicUsize, Ordering::SeqCst}, - Arc, Condvar, Mutex, - }, - task::{Wake, Waker}, -}; - -pub(crate) fn pair() -> (Parker, Unparker) { - let p = Parker::new(); - let u = p.unparker(); - (p, u) -} - -pub(crate) struct Parker { - unparker: Unparker, - _marker: PhantomData>, -} - -impl Parker { - pub(crate) fn new() -> Parker { - Parker { - unparker: Unparker { - inner: Arc::new(Inner { - state: AtomicUsize::new(EMPTY), - lock: Mutex::new(()), - cvar: Condvar::new(), - }), - }, - _marker: PhantomData, - } - } - - pub(crate) fn park(&self) { - self.unparker.inner.park(); - } - - pub(crate) fn unparker(&self) -> Unparker { - self.unparker.clone() - } -} - -pub(crate) struct Unparker { - inner: Arc, -} - -impl Unparker { - pub(crate) fn unpark(&self) { - self.inner.unpark(); - } -} - -impl Clone for Unparker { - fn clone(&self) -> Unparker { - Unparker { - inner: self.inner.clone(), - } - } -} - -impl From for Waker { - fn from(up: Unparker) -> Self { - Waker::from(up.inner) - } -} - -const EMPTY: usize = 0; -const PARKED: usize = 1; -const NOTIFIED: usize = 2; - -struct Inner { - state: AtomicUsize, - lock: Mutex<()>, - cvar: Condvar, -} - -impl Inner { - fn park(&self) { - if self - .state - .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) - .is_ok() - { - return; - } - - let mut m = self.lock.lock().unwrap(); - - match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { - Ok(_) => {} - Err(NOTIFIED) => { - let old = self.state.swap(EMPTY, SeqCst); - assert_eq!(old, NOTIFIED, "park state changed unexpectedly"); - return; - } - Err(n) => panic!("inconsistent park_timeout state: {}", n), - } - - loop { - m = self.cvar.wait(m).unwrap(); - - if self - .state - .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) - .is_ok() - { - return; - } - } - } - - pub(crate) fn unpark(&self) { - match self.state.swap(NOTIFIED, SeqCst) { - EMPTY => return, - NOTIFIED => return, - PARKED => {} - _ => panic!("inconsistent state in unpark"), - } - - drop(self.lock.lock().unwrap()); - self.cvar.notify_one(); - } -} - -impl Wake for Inner { - fn wake(self: Arc) { - self.unpark(); - } - - fn wake_by_ref(self: &Arc) { - self.unpark(); - } -} diff --git a/src/executors/suspender.rs b/src/executors/suspender.rs new file mode 100755 index 0000000..23cfd62 --- /dev/null +++ b/src/executors/suspender.rs @@ -0,0 +1,109 @@ +use std::{ + sync::{Arc, Condvar, Mutex}, + task::Waker, +}; + +use super::waker::waker_helper; + +pub(crate) fn pair() -> (Suspender, Waker) { + let suspender = Suspender::new(); + let resumer = suspender.resumer(); + (suspender, waker_helper(move || resumer.resume())) +} + +pub(crate) struct Suspender { + resumer: Resumer, +} + +impl Suspender { + pub(crate) fn new() -> Suspender { + Suspender { + resumer: Resumer { + inner: Arc::new(Inner { + lock: Mutex::new(State::Empty), + cvar: Condvar::new(), + }), + }, + } + } + + pub(crate) fn suspend(&self) { + self.resumer.inner.suspend(); + } + + pub(crate) fn resumer(&self) -> Resumer { + Resumer { + inner: self.resumer.inner.clone(), + } + } +} + +pub(crate) struct Resumer { + inner: Arc, +} + +impl Resumer { + pub(crate) fn resume(&self) { + self.inner.resume(); + } +} + +#[derive(PartialEq)] +enum State { + Empty, + Notified, + Suspended, +} + +struct Inner { + lock: Mutex, + cvar: Condvar, +} + +impl Inner { + #[inline] + fn suspend(&self) { + // Acquire the lock first + let Ok(mut lock) = self.lock.lock() else { + return; + }; + + // check the state the lock is in right now + match *lock { + // suspend the thread + State::Empty => *lock = State::Suspended, + // already notified this thread so just revert state back to empty + // then return + State::Notified => { + *lock = State::Empty; + return; + } + State::Suspended => panic!("cannot suspend a thread that is already in a suspended state"), + } + + // suspend this thread until we get a notification + while *lock == State::Suspended { + lock = self.cvar.wait(lock).unwrap(); + } + + // revert state back to empty so the next time this method is called + // it can suspend this callee thread + *lock = State::Empty; + } + + #[inline] + fn resume(&self) { + // Acquire the lock first + let Ok(mut lock) = self.lock.lock() else { + return; + }; + + // check if the state is not in a notified state yet + if *lock != State::Notified { + // send notification + *lock = State::Notified; + // resume the suspended thread + self.cvar.notify_one(); + } + } +} diff --git a/src/executors/task_executor.rs b/src/executors/task_executor.rs index 2985d65..6b7899f 100644 --- a/src/executors/task_executor.rs +++ b/src/executors/task_executor.rs @@ -1,34 +1,48 @@ -use std::task::{Context, Poll, Waker}; +use std::{ + cell::RefCell, + task::{Context, Poll, Waker}, +}; -use crate::async_runtime::task::Task; +use crate::shared::priority_task::PrioritizedTask; -use super::{ - parker::{pair, Parker}, - waker::waker_helper, -}; +use super::{pair, Suspender}; -fn parker_and_waker() -> (Parker, Waker) { - let (parker, unparker) = pair(); - let waker = waker_helper(move || { - unparker.unpark(); - }); - (parker, waker) +thread_local! { + pub(crate) static WAKER_PAIR: RefCell<(Suspender, Waker)> = { + RefCell::new(pair()) + }; } -pub(crate) fn block_task(task: Task) { +pub(crate) fn block_task(task: PrioritizedTask) { if task.is_completed() { return; } - let (parker, waker) = parker_and_waker(); - let mut context: Context<'_> = Context::from_waker(&waker); - let Ok(mut future) = task.lock() else { - return; - }; - loop { - match future.as_mut().poll(&mut context) { - Poll::Ready(output) => return output, - Poll::Pending => parker.park(), - } - } + WAKER_PAIR.with(move |waker_pair| { + let mut task = task; + match waker_pair.try_borrow_mut() { + Ok(waker_pair) => { + let (suspender, waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + loop { + let Poll::Ready(()) = task.poll_task(&mut context) else { + suspender.suspend(); + continue; + }; + return; + } + } + Err(_) => { + let (suspender, waker) = pair(); + let mut context: Context<'_> = Context::from_waker(&waker); + loop { + let Poll::Ready(()) = task.poll_task(&mut context) else { + suspender.suspend(); + continue; + }; + return; + } + } + }; + }); } diff --git a/src/executors/waker.rs b/src/executors/waker.rs index 1c4b35b..93fd6ba 100644 --- a/src/executors/waker.rs +++ b/src/executors/waker.rs @@ -1,15 +1,48 @@ -use std::{sync::Arc, task::Waker}; - -use super::waker_traits::{IntoWaker, WakeRef}; +use std::{ + mem, + sync::Arc, + task::{RawWaker, RawWakerVTable, Waker}, +}; +// Waker implementation struct WakerHelper(F); -pub(crate) fn waker_helper(f: F) -> Waker { - Arc::new(WakerHelper(f)).into_waker() +pub(crate) fn waker_helper(f: F) -> Waker { + let raw: *const () = Arc::into_raw(Arc::new(f)) as *const (); + let vtable: &RawWakerVTable = &WakerHelper::::VTABLE; + unsafe { Waker::from_raw(RawWaker::new(raw, vtable)) } } -impl WakeRef for WakerHelper { - fn wake_by_ref(&self) { - (self.0)(); +impl WakerHelper { + // A virtual function table (vtable) that specifies the behavior of a RawWaker + const VTABLE: RawWakerVTable = RawWakerVTable::new( + Self::clone_waker, + Self::wake, + Self::wake_by_ref, + Self::drop_waker, + ); + + // clones the waker + unsafe fn clone_waker(ptr: *const ()) -> RawWaker { + let arc: mem::ManuallyDrop> = mem::ManuallyDrop::new(Arc::from_raw(ptr as *const F)); + _ = arc.clone(); + RawWaker::new(ptr, &Self::VTABLE) + } + + // wakes up by consuming it + unsafe fn wake(ptr: *const ()) { + let arc: Arc = Arc::from_raw(ptr as *const F); + (arc)(); + } + + // wakes up by reference + unsafe fn wake_by_ref(ptr: *const ()) { + let arc: mem::ManuallyDrop> = mem::ManuallyDrop::new(Arc::from_raw(ptr as *const F)); + (arc)(); + } + + // drops the waker + unsafe fn drop_waker(ptr: *const ()) { + drop(Arc::from_raw(ptr as *const F)) } } diff --git a/src/executors/waker_traits.rs b/src/executors/waker_traits.rs deleted file mode 100644 index f3ee5af..0000000 --- a/src/executors/waker_traits.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::{ - mem::ManuallyDrop, - sync::Arc, - task::{RawWaker, RawWakerVTable, Waker}, -}; - -/// # Safety -/// All safe here -pub(crate) unsafe trait ViaRawPointer { - type Target: ?Sized; - - fn into_raw(self) -> *mut Self::Target; - - unsafe fn from_raw(ptr: *mut Self::Target) -> Self; -} - -pub(crate) trait WakeRef { - fn wake_by_ref(&self); -} - -pub(crate) trait Wake: WakeRef + Sized { - #[inline] - fn wake(self) { - self.wake_by_ref() - } -} - -pub(crate) trait IntoWaker { - const VTABLE: &'static RawWakerVTable; - - #[must_use] - fn into_waker(self) -> Waker; -} - -impl IntoWaker for T -where - T: Wake + Clone + 'static + ViaRawPointer, - T::Target: Sized, -{ - const VTABLE: &'static RawWakerVTable = &RawWakerVTable::new( - // clone - |raw| { - let raw = raw as *mut T::Target; - - let waker = ManuallyDrop::::new(unsafe { ViaRawPointer::from_raw(raw) }); - let cloned: T = (*waker).clone(); - - // We can't save the `into_raw` back into the raw waker, so we must - // simply assert that the pointer has remained the same. This is - // part of the ViaRawPointer safety contract, so we only check it - // in debug builds. - debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw); - - let cloned_raw = cloned.into_raw(); - let cloned_raw = cloned_raw as *const (); - RawWaker::new(cloned_raw, T::VTABLE) - }, - // wake by value - |raw| { - let raw = raw as *mut T::Target; - let waker: T = unsafe { ViaRawPointer::from_raw(raw) }; - waker.wake(); - }, - // wake by ref - |raw| { - let raw = raw as *mut T::Target; - let waker = ManuallyDrop::::new(unsafe { ViaRawPointer::from_raw(raw) }); - waker.wake_by_ref(); - - debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw); - }, - // Drop - |raw| { - let raw = raw as *mut T::Target; - let _waker: T = unsafe { ViaRawPointer::from_raw(raw) }; - }, - ); - - fn into_waker(self) -> Waker { - let raw = self.into_raw(); - let raw = raw as *const (); - let raw_waker = RawWaker::new(raw, T::VTABLE); - unsafe { Waker::from_raw(raw_waker) } - } -} - -unsafe impl ViaRawPointer for Arc { - type Target = T; - - fn into_raw(self) -> *mut T { - Arc::into_raw(self) as *mut T - } - - unsafe fn from_raw(ptr: *mut T) -> Self { - Arc::from_raw(ptr as *const T) - } -} - -impl WakeRef for Arc { - #[inline] - fn wake_by_ref(&self) { - T::wake_by_ref(self.as_ref()) - } -} - -impl Wake for Arc {} diff --git a/src/lib.rs b/src/lib.rs index f154fac..0c302f2 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,9 +91,6 @@ //! See [`with_discarding_spawn_group`](self::with_discarding_spawn_group) //! for more information //! -//! * ``sleep`` similar to ``std::thread::sleep`` but for sleeping in asynchronous environments. See [`sleep`](self::sleep) -//! for more information -//! //! * ``block_on`` polls future to finish. See [`block_on`](self::block_on) //! for more information //! @@ -174,19 +171,14 @@ mod async_stream; mod executors; mod meta_types; mod shared; -mod sleeper; mod threadpool_impl; -mod yield_now; pub use discarding_spawn_group::DiscardingSpawnGroup; pub use err_spawn_group::ErrSpawnGroup; pub use executors::block_on; pub use meta_types::GetType; pub use shared::priority::Priority; -pub use sleeper::sleep; pub use spawn_group::SpawnGroup; -pub use yield_now::ready; -pub use yield_now::yield_now; use std::future::Future; use std::marker::PhantomData; @@ -197,7 +189,7 @@ use std::thread::available_parallelism; /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``SpawnGroup`` struct. /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`SpawnGroup`](spawn_group::SpawnGroup) /// for more. @@ -236,14 +228,14 @@ use std::thread::available_parallelism; /// assert_eq!(final_result, 55); /// # }); /// ``` -pub async fn with_type_spawn_group( +pub async fn with_type_spawn_group<'a, Closure, Fut, ResultType, ReturnType>( of_type: PhantomData, body: Closure, ) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + 'a, Fut: Future + Send + 'static, - ResultType: Send + 'static, + ResultType: 'static, { let count: usize; if let Ok(thread_count) = available_parallelism() { @@ -261,7 +253,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``SpawnGroup`` struct. /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`SpawnGroup`](spawn_group::SpawnGroup) /// for more. @@ -299,11 +291,11 @@ where /// assert_eq!(final_result, 55); /// # }); /// ``` -pub async fn with_spawn_group(body: Closure) -> ReturnType +pub async fn with_spawn_group<'a, Closure, Fut, ResultType, ReturnType>(body: Closure) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + 'a, Fut: Future + Send + 'static, - ResultType: Send + 'static, + ResultType: 'static, { let count: usize; if let Ok(thread_count) = available_parallelism() { @@ -321,7 +313,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``ErrSpawnGroup`` struct /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`ErrSpawnGroup`](err_spawn_group::ErrSpawnGroup) /// for more. @@ -404,16 +396,16 @@ where /// assert_eq!(final_results.2, 2); /// # }); /// ``` -pub async fn with_err_type_spawn_group( +pub async fn with_err_type_spawn_group<'a, Closure, Fut, ResultType, ErrorType, ReturnType>( of_type: PhantomData, error_type: PhantomData, body: Closure, ) -> ReturnType where - ErrorType: Send + 'static, - Fut: Future, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, - ResultType: Send + 'static, + ErrorType: 'static, + Fut: Future + Send, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + 'a, + ResultType: 'static, { let count: usize; if let Ok(thread_count) = available_parallelism() { @@ -432,7 +424,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``ErrSpawnGroup`` struct /// -/// This function use a threadpoolof the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpoolof the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`ErrSpawnGroup`](err_spawn_group::ErrSpawnGroup) /// for more. @@ -513,14 +505,14 @@ where /// assert_eq!(final_results.2, 2); /// # }); /// ``` -pub async fn with_err_spawn_group( +pub async fn with_err_spawn_group<'a, Closure, Fut, ResultType, ErrorType, ReturnType>( body: Closure, ) -> ReturnType where - ErrorType: Send + 'static, - Fut: Future, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, - ResultType: Send + 'static, + ErrorType: 'static, + Fut: Future + Send, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + 'a, + ResultType: 'static, { let count: usize; if let Ok(thread_count) = available_parallelism() { @@ -536,7 +528,7 @@ where /// /// Ensures that before the function call ends, all spawned tasks are implicitly waited for /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`DiscardingSpawnGroup`](discarding_spawn_group::DiscardingSpawnGroup) /// for more. @@ -569,10 +561,10 @@ where /// }).await; /// # }); /// ``` -pub async fn with_discarding_spawn_group(body: Closure) -> ReturnType +pub async fn with_discarding_spawn_group<'a, Closure, Fut, ReturnType>(body: Closure) -> ReturnType where - Fut: Future, - Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut + Send + 'static, + Fut: Future + Send, + Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut + 'a, { let count: usize; if let Ok(thread_count) = available_parallelism() { diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 82f8a61..665e86e 100755 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1,2 +1,3 @@ pub(crate) mod priority; pub(crate) mod runtime; +pub(crate) mod priority_task; \ No newline at end of file diff --git a/src/shared/priority.rs b/src/shared/priority.rs index 3c66a34..3c98fd1 100755 --- a/src/shared/priority.rs +++ b/src/shared/priority.rs @@ -1,6 +1,7 @@ /// Task Priority /// -/// Spawn groups uses it to rank the importance of their spawned tasks and order of returned values only when waited for. +/// Spawn groups uses it to rank the importance of their spawned tasks and order of returned values only when waited for +/// that is when the ``wait_for_all`` or ``wait_non_async`` methods are called #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Default)] pub enum Priority { BACKGROUND = 0, diff --git a/src/shared/priority_task.rs b/src/shared/priority_task.rs new file mode 100644 index 0000000..be26b6a --- /dev/null +++ b/src/shared/priority_task.rs @@ -0,0 +1,51 @@ +use std::{ + cmp::Ordering, + ops::{Deref, DerefMut}, +}; + +use crate::{async_runtime::task::Task, Priority}; + +pub(crate) struct PrioritizedTask { + task: Task, + priority: Priority, +} + +impl Deref for PrioritizedTask { + type Target = Task; + + fn deref(&self) -> &Self::Target { + &self.task + } +} + +impl DerefMut for PrioritizedTask { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.task + } +} + +impl PartialEq for PrioritizedTask { + fn eq(&self, other: &Self) -> bool { + self.priority == other.priority + } +} + +impl Eq for PrioritizedTask {} + +impl PrioritizedTask { + pub(crate) fn new(priority: Priority, task: Task) -> Self { + Self { task, priority } + } +} + +impl Ord for PrioritizedTask { + fn cmp(&self, other: &Self) -> Ordering { + other.priority.cmp(&self.priority) + } +} + +impl PartialOrd for PrioritizedTask { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index e44768a..6036f2e 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,26 +1,34 @@ use crate::{ - async_runtime::{executor::Executor, task::Task}, - async_stream::AsyncStream, - block_on, - executors::block_task, + async_runtime::executor::Executor, async_stream::AsyncStream, executors::block_task, shared::priority::Priority, }; -use std::{cell::RefCell, future::Future}; +use std::{ + collections::BinaryHeap, + future::Future, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, +}; + +use super::priority_task::PrioritizedTask; -type TaskQueue = RefCell>; +type TaskQueue = Mutex>; -pub struct RuntimeEngine { - tasks: TaskQueue, +pub(crate) struct RuntimeEngine { + prioritized_tasks: TaskQueue, runtime: Executor, + task_count: Arc, stream: AsyncStream, } impl RuntimeEngine { pub(crate) fn new(count: usize) -> Self { Self { - tasks: RefCell::new(vec![]), - stream: AsyncStream::new(), + prioritized_tasks: Mutex::new(BinaryHeap::with_capacity(1000)), runtime: Executor::new(count), + task_count: Arc::new(AtomicUsize::default()), + stream: AsyncStream::new(), } } } @@ -28,8 +36,11 @@ impl RuntimeEngine { impl RuntimeEngine { pub(crate) fn cancel(&mut self) { self.runtime.cancel(); - self.tasks.borrow_mut().clear(); - self.stream.cancel_tasks(); + let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { + return; + }; + prioritized_tasks.clear(); + self.task_count.store(0, Ordering::Release); self.poll(); } } @@ -41,50 +52,60 @@ impl RuntimeEngine { pub(crate) fn end(&mut self) { self.runtime.cancel(); - self.tasks.borrow_mut().clear(); - self.stream.cancel_tasks(); + let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { + return; + }; + prioritized_tasks.clear(); + self.task_count.store(0, Ordering::Release); self.runtime.end() } } -impl RuntimeEngine { +impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { - self.runtime.cancel(); - let mut tasks = self.tasks.borrow_mut(); - if tasks.is_empty() { + let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { + return; + }; + if prioritized_tasks.is_empty() { return; } - tasks.retain(|(_, task)| { - task.cancel(); - !task.is_completed() + self.runtime.cancel(); + prioritized_tasks.retain(|prioritized_task| { + prioritized_task.cancel_task(); + !prioritized_task.is_completed() }); - tasks.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - if tasks.is_empty() { + if prioritized_tasks.is_empty() { return; } - while let Some((_, task)) = tasks.pop() { - if task.is_completed() { + while let Some(prioritized_task) = prioritized_tasks.pop() { + if prioritized_task.is_completed() { continue; } - self.runtime.submit(move || block_task(task)); + self.runtime.submit(move || block_task(prioritized_task)); } + self.task_count.store(0, Ordering::Release); + drop(prioritized_tasks); self.poll() } } -impl RuntimeEngine { +impl RuntimeEngine { pub(crate) fn write_task(&self, priority: Priority, task: F) where - F: Future + Send + 'static, + F: Future + 'static, { + let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { + return; + }; self.stream.increment(); - let stream: AsyncStream = self.stream(); - self.tasks.borrow_mut().push(( + self.task_count.fetch_add(1, Ordering::SeqCst); + let (stream, task_counter) = (self.stream(), self.task_count.clone()); + prioritized_tasks.push(PrioritizedTask::new( priority, self.runtime.spawn(async move { let task_result = task.await; - block_on(async { stream.insert_item(task_result).await }); - stream.decrement_task_count(); + stream.insert_item(task_result).await; + task_counter.fetch_sub(1, Ordering::SeqCst); }), )); } @@ -94,4 +115,8 @@ impl RuntimeEngine { pub(crate) fn poll(&self) { self.runtime.poll_all(); } + + pub(crate) fn task_count(&self) -> usize { + self.task_count.load(Ordering::Acquire) + } } diff --git a/src/sleeper/delay.rs b/src/sleeper/delay.rs deleted file mode 100755 index 53a651d..0000000 --- a/src/sleeper/delay.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Delay { - duration: Duration, - now: Instant, -} - -impl Delay { - pub(crate) fn new(duration: Duration) -> Self { - Delay { - duration, - now: Instant::now(), - } - } -} - -impl Future for Delay { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.now.elapsed() >= self.duration { - true => Poll::Ready(()), - false => { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} diff --git a/src/sleeper/mod.rs b/src/sleeper/mod.rs deleted file mode 100755 index ff019a8..0000000 --- a/src/sleeper/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod delay; - -use std::time::Duration; - -use self::delay::Delay; - -/// Sleeps for the specified amount of time. -/// -/// This function might sleep for slightly longer than the specified duration but never less. -/// -/// This function is an async version of ``std::thread::sleep``. -/// -/// Example -/// -/// ```rust -/// use spawn_groups::{block_on, sleep}; -/// use std::time::Duration; -/// -/// block_on(async { -/// sleep(Duration::from_secs(2)).await; -/// }); -/// ``` -pub fn sleep(duration: Duration) -> Delay { - Delay::new(duration) -} diff --git a/src/spawn_group.rs b/src/spawn_group.rs index b715be6..6287025 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -26,15 +26,15 @@ use std::{ /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used -pub struct SpawnGroup { +pub struct SpawnGroup { /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, - wait_at_drop: bool, count: Arc, runtime: RuntimeEngine, + wait_at_drop: bool, } -impl SpawnGroup { +impl SpawnGroup { /// Instantiates `SpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures /// /// # Parameters @@ -50,14 +50,14 @@ impl SpawnGroup { } } -impl SpawnGroup { +impl SpawnGroup { /// Don't implicity wait for spawned child tasks to finish before being dropped pub fn dont_wait_at_drop(&mut self) { self.wait_at_drop = false; } } -impl SpawnGroup { +impl SpawnGroup { /// Spawns a new task into the spawn group /// # Parameters /// @@ -95,14 +95,14 @@ impl SpawnGroup { } } -impl SpawnGroup { +impl SpawnGroup { /// Returns the first element of the stream, or None if it is empty. pub async fn first(&self) -> Option { self.runtime.stream().first().await } } -impl SpawnGroup { +impl SpawnGroup { /// Waits for all remaining child tasks for finish. pub async fn wait_for_all(&mut self) { self.wait_non_async() @@ -115,7 +115,7 @@ impl SpawnGroup { } } -impl SpawnGroup { +impl SpawnGroup { fn increment_count(&self) { self.count.fetch_add(1, Ordering::Acquire); } @@ -129,7 +129,7 @@ impl SpawnGroup { } } -impl SpawnGroup { +impl SpawnGroup { /// A Boolean value that indicates whether the group has any remaining tasks. /// /// At the start of the body of a ``with_spawn_group()`` call, , or before calling ``spawn_task`` or ``spawn_task_unless_cancelled`` methods @@ -139,21 +139,21 @@ impl SpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.count() == 0 || self.runtime.stream().task_count() == 0 { + if self.count() == 0 || self.runtime.task_count() == 0 { return true; } false } } -impl SpawnGroup { +impl SpawnGroup { /// Returns an instance of the `Stream` trait. pub fn stream(&self) -> impl Stream { self.runtime.stream() } } -impl SpawnGroup { +impl SpawnGroup { /// Waits for a specific number of spawned child tasks to finish and returns their respectively result as a vector /// /// # Panics @@ -198,7 +198,7 @@ impl SpawnGroup { } } -impl Drop for SpawnGroup { +impl Drop for SpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); @@ -207,7 +207,7 @@ impl Drop for SpawnGroup { } } -impl Stream for SpawnGroup { +impl Stream for SpawnGroup { type Item = ValueType; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/src/threadpool_impl/channel.rs b/src/threadpool_impl/channel.rs index 9f25cfa..ddf4483 100644 --- a/src/threadpool_impl/channel.rs +++ b/src/threadpool_impl/channel.rs @@ -6,31 +6,20 @@ use std::{ }, }; -pub struct Channel { - pair: Arc<(Mutex>, Condvar)>, - closed: Arc, +pub(crate) struct Channel { + inner: Arc>, } impl Channel { pub(crate) fn enqueue(&self, value: ItemType) { - if self.closed.load(Ordering::Relaxed) { - return; - } - let Ok(mut lock) = self.pair.0.lock() else { - return; - }; - lock.push_back(value); - if lock.len() == 1 { - self.pair.1.notify_one(); - } + self.inner.enqueue(value) } } impl Channel { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { - pair: Arc::new((Mutex::new(VecDeque::new()), Condvar::new())), - closed: Arc::new(AtomicBool::new(false)), + inner: Arc::new(Inner::new()), } } } @@ -38,14 +27,54 @@ impl Channel { impl Clone for Channel { fn clone(&self) -> Self { Self { - pair: self.pair.clone(), - closed: self.closed.clone(), + inner: self.inner.clone(), } } } impl Channel { pub(crate) fn dequeue(&self) -> Option { + self.inner.dequeue() + } +} + +impl Channel { + pub(crate) fn close(&self) { + self.inner.close() + } + + pub(crate) fn clear(&self) { + self.inner.clear() + } +} + +struct Inner { + closed: AtomicBool, + pair: (Mutex>, Condvar), +} + +impl Inner { + fn new() -> Self { + Self { + closed: AtomicBool::new(false), + pair: (Mutex::new(VecDeque::new()), Condvar::new()), + } + } + + fn enqueue(&self, value: ItemType) { + if self.closed.load(Ordering::Relaxed) { + return; + } + let Ok(mut lock) = self.pair.0.lock() else { + return; + }; + lock.push_back(value); + if lock.len() == 1 { + self.pair.1.notify_one(); + } + } + + fn dequeue(&self) -> Option { if self.closed.load(Ordering::Relaxed) { return None; } @@ -60,31 +89,21 @@ impl Channel { } lock.pop_front() } -} -impl Channel { - pub fn close(&self) { - if self.closed.load(Ordering::Relaxed) { + fn close(&self) { + if self.closed.swap(true, Ordering::Relaxed) { return; } - if let Ok(_lock) = self.pair.0.lock() { - self.closed.store(true, Ordering::Relaxed); - self.pair.1.notify_all(); - } + let Ok(_lock) = self.pair.0.lock() else { + return; + }; + self.pair.1.notify_all(); } - pub(crate) fn clear(&self) { + fn clear(&self) { let Ok(mut lock) = self.pair.0.lock() else { return; }; lock.clear(); } } - -impl Iterator for Channel { - type Item = ItemType; - - fn next(&mut self) -> Option { - self.dequeue() - } -} diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index 0400849..e464e89 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,5 +1,4 @@ use std::{ - backtrace, panic, sync::atomic::{AtomicUsize, Ordering}, thread::spawn, }; @@ -7,32 +6,27 @@ use std::{ use super::{waitgroup::WaitGroup, Channel, Func}; pub(crate) struct ThreadPool { - task_channels: Vec>>, index: AtomicUsize, + task_channels: Vec>>, wait_group: WaitGroup, } impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - let mut count = count; - if count < 1 { - count = 1; + let mut task_channels = Vec::with_capacity(count); + for _ in 1..=count { + let channel: Channel> = Channel::new(); + let chan = channel.clone(); + spawn(move || { + while let Some(ops) = channel.dequeue() { + ops(); + } + }); + task_channels.push(chan); } ThreadPool { - task_channels: (1..=count) - .map(|_| { - let channel: Channel> = Channel::new(); - let chan = channel.clone(); - spawn(move || { - panic_hook(); - for ops in channel { - ops(); - } - }); - chan - }) - .collect(), index: AtomicUsize::new(0), + task_channels, wait_group: WaitGroup::new(), } } @@ -42,7 +36,7 @@ impl ThreadPool { fn current_index(&self) -> usize { self.index.swap( (self.index.load(Ordering::Relaxed) + 1) % self.task_channels.len(), - Ordering::SeqCst, + Ordering::Relaxed, ) } @@ -76,7 +70,6 @@ impl ThreadPool { impl Drop for ThreadPool { fn drop(&mut self) { - _ = panic::take_hook(); self.end(); } } @@ -88,16 +81,3 @@ impl ThreadPool { .for_each(|channel| channel.clear()); } } - -fn panic_hook() { - panic::set_hook(Box::new(move |info: &panic::PanicInfo<'_>| { - let msg = format!( - "Threadpool panicked at location {} with {} \nBacktrace:\n{}", - info.location().unwrap(), - info.to_string().split('\n').collect::>()[1], - backtrace::Backtrace::capture() - ); - eprintln!("{}", msg); - _ = panic::take_hook(); - })); -} diff --git a/src/yield_now/mod.rs b/src/yield_now/mod.rs deleted file mode 100755 index 152dcdd..0000000 --- a/src/yield_now/mod.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::yield_now::yielder::Yielder; - -mod yielder; - -/// Wakes the current task and returns [`std::task::Poll::Pending`] once. -/// -/// This function is useful when we want to cooperatively give time to the task executor. It is -/// generally a good idea to yield inside loops because that way we make sure long-running tasks -/// don't prevent other tasks from running. -/// -/// # Examples -/// ``` -/// use spawn_groups::{block_on, yield_now}; -/// block_on(async { -/// yield_now().await; -/// }); -/// ``` -pub fn yield_now() -> Yielder { - Yielder::default() -} - -/// Resolves to the provided value. -/// -/// # Examples -/// ``` -/// use spawn_groups::{block_on, ready}; -/// block_on(async { -/// let ten = ready(10).await; -/// assert_eq!(ten, 10); -/// }); -/// ``` -pub async fn ready(val: ValueType) -> ValueType { - val -} diff --git a/src/yield_now/yielder.rs b/src/yield_now/yielder.rs deleted file mode 100755 index 7436af5..0000000 --- a/src/yield_now/yielder.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -#[derive(Debug, Default)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Yielder { - yield_now: bool, -} - -impl Future for Yielder { - type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.yield_now { - true => Poll::Ready(()), - false => { - self.yield_now = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} From 9234a999c562a0381bb6b9d39d3789c033514dac Mon Sep 17 00:00:00 2001 From: Genaro-Chris Date: Tue, 25 Feb 2025 22:57:35 +0100 Subject: [PATCH 10/10] Ready for v2.0 --- Cargo.toml | 4 +- src/async_runtime/executor.rs | 96 ---------------- src/async_runtime/mod.rs | 2 - src/async_runtime/task.rs | 84 -------------- src/async_stream/mod.rs | 158 +------------------------- src/async_stream/stream.rs | 162 +++++++++++++++++++++++++++ src/discarding_spawn_group.rs | 13 ++- src/err_spawn_group.rs | 24 +++- src/executors/future_executor.rs | 45 +++----- src/executors/mod.rs | 5 - src/executors/suspender.rs | 109 ------------------ src/executors/task_executor.rs | 48 -------- src/executors/waker.rs | 48 -------- src/lib.rs | 74 ++++-------- src/meta_types.rs | 9 +- src/shared/mod.rs | 13 ++- src/shared/mutex.rs | 15 +++ src/shared/priority.rs | 4 +- src/shared/priority_task.rs | 51 +++++---- src/shared/runtime.rs | 101 ++++++----------- src/shared/suspender.rs | 96 ++++++++++++++++ src/shared/task.rs | 30 +++++ src/shared/task_enum.rs | 9 ++ src/shared/waker.rs | 56 +++++++++ src/shared/waker_pair.rs | 8 ++ src/spawn_group.rs | 25 +++-- src/threadpool_impl/channel.rs | 96 ++++++---------- src/threadpool_impl/mod.rs | 7 +- src/threadpool_impl/task_priority.rs | 25 +++++ src/threadpool_impl/thread.rs | 63 +++++++++++ src/threadpool_impl/threadpool.rs | 86 ++++++-------- src/threadpool_impl/waitgroup.rs | 42 ------- 32 files changed, 707 insertions(+), 901 deletions(-) delete mode 100644 src/async_runtime/executor.rs delete mode 100755 src/async_runtime/mod.rs delete mode 100755 src/async_runtime/task.rs create mode 100644 src/async_stream/stream.rs delete mode 100755 src/executors/suspender.rs delete mode 100644 src/executors/task_executor.rs delete mode 100644 src/executors/waker.rs create mode 100644 src/shared/mutex.rs create mode 100755 src/shared/suspender.rs create mode 100755 src/shared/task.rs create mode 100644 src/shared/task_enum.rs create mode 100644 src/shared/waker.rs create mode 100644 src/shared/waker_pair.rs create mode 100644 src/threadpool_impl/task_priority.rs create mode 100644 src/threadpool_impl/thread.rs delete mode 100644 src/threadpool_impl/waitgroup.rs diff --git a/Cargo.toml b/Cargo.toml index 7c85c70..574fd8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,5 +14,5 @@ publish = true [dependencies] -futures-lite = "1.13.0" -async-mutex = "1.4.0" +async-lock = "3.4.0" +futures-lite = { version = "2.5.0", features = ["race"] } diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs deleted file mode 100644 index d4194c7..0000000 --- a/src/async_runtime/executor.rs +++ /dev/null @@ -1,96 +0,0 @@ -use crate::{ - executors::{pair, Suspender, WAKER_PAIR}, - threadpool_impl::ThreadPool, -}; - -use super::task::Task; - -use std::{ - future::Future, - task::{Context, Poll}, -}; - -pub struct Executor { - pool: ThreadPool, -} - -impl Executor { - pub(crate) fn new(count: usize) -> Self { - Self { - pool: ThreadPool::new(count), - } - } -} - -impl Executor { - pub(crate) fn submit(&self, task: Task) - where - Task: FnOnce() + Send + 'static, - { - self.pool.submit(task); - } - - pub(crate) fn spawn(&self, task: F) -> Task - where - F: Future + 'static, - { - let task: Task = Task::new(task); - self.async_poll_task(task.clone()); - task - } - - fn async_poll_task(&self, task: Task) { - if task.is_completed() || task.is_cancelled() { - return; - } - - self.submit(move || { - WAKER_PAIR.with(move |waker_pair| { - match waker_pair.try_borrow_mut() { - Ok(waker_pair) => { - let (suspender, waker) = &*waker_pair; - let mut context: Context<'_> = Context::from_waker(waker); - poll_task(task, suspender, &mut context) - } - Err(_) => { - let (suspender, waker) = pair(); - let mut context: Context<'_> = Context::from_waker(&waker); - poll_task(task, &suspender, &mut context) - } - }; - }); - }); - } - - pub(crate) fn cancel(&self) { - self.pool.clear(); - self.poll_all(); - self.pool.clear(); - } - - pub(crate) fn poll_all(&self) { - self.pool.wait_for_all(); - } - - pub(crate) fn end(&self) { - self.pool.end(); - } -} - -#[inline] -fn poll_task(task: Task, suspender: &Suspender, context: &mut Context<'_>) { - let mut task = task; - loop { - match task.poll_task(context) { - Poll::Ready(()) => { - return; - } - Poll::Pending => { - suspender.suspend(); - if task.is_cancelled() { - return; - } - } - } - } -} diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs deleted file mode 100755 index 20991e8..0000000 --- a/src/async_runtime/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod executor; -pub(crate) mod task; diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs deleted file mode 100755 index 43b5a65..0000000 --- a/src/async_runtime/task.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::{ - future::Future, hint, pin::Pin, rc::Rc, sync::atomic::{AtomicBool, AtomicU8, Ordering}, task::{Context, Poll} -}; - -type LocalFuture = dyn Future; - -#[derive(Clone)] -pub(crate) struct Task { - inner: Rc, -} - -impl Task { - pub(crate) fn new + 'static>(future: Fut) -> Self { - Self { - inner: Rc::new(Inner::new(future)), - } - } -} - -impl Task { - pub(crate) fn is_completed(&self) -> bool { - self.inner.complete.load(Ordering::Acquire) - } - - pub(crate) fn complete(&self) { - self.inner.complete.store(true, Ordering::Release) - } - - pub(crate) fn cancel_task(&self) { - self.inner.cancelled.store(true, Ordering::Release) - } - - pub(crate) fn is_cancelled(&self) -> bool { - self.inner.cancelled.load(Ordering::Acquire) - } - - pub(crate) fn poll_task(&mut self, cx: &mut Context<'_>) -> Poll<()> { - // ensures that only this method is polling the future right now regardless of all other cloned tasks - // basically a lightweight spinlock to prevent data race bugs while polling - while self - .inner - .poll_check - .compare_exchange(0, 1, Ordering::Relaxed, Ordering::Relaxed) - .is_err() - { - hint::spin_loop(); - } - - let result = unsafe { Pin::new_unchecked(&mut (*self.inner.ptr)).poll(cx) }; - if result.is_ready() { - self.complete(); - } - self.inner.poll_check.store(0, Ordering::Release); - result - } -} - -unsafe impl Send for Task {} - -struct Inner { - poll_check: AtomicU8, - ptr: *mut LocalFuture, - cancelled: AtomicBool, - complete: AtomicBool, -} - -impl Inner { - fn new(future: impl Future + 'static) -> Self { - Self { - poll_check: AtomicU8::new(0), - ptr: Box::into_raw(Box::new(future)), - complete: AtomicBool::new(false), - cancelled: AtomicBool::new(false), - } - } -} - -impl Drop for Inner { - fn drop(&mut self) { - unsafe { - _ = Box::from_raw(self.ptr); - } - } -} diff --git a/src/async_stream/mod.rs b/src/async_stream/mod.rs index 82edc3e..4d648f4 100755 --- a/src/async_stream/mod.rs +++ b/src/async_stream/mod.rs @@ -1,157 +1,3 @@ -use std::{ - collections::VecDeque, - future::Future, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, -}; +mod stream; -use async_mutex::Mutex; -use futures_lite::Stream; - -pub struct AsyncStream { - inner: Arc>, -} - -struct Inner { - inner_lock: Mutex>, - item_count: AtomicUsize, -} - -impl Inner { - fn new() -> Self { - Self { - inner_lock: Mutex::new(InnerState::new()), - item_count: AtomicUsize::new(0), - } - } -} - -impl AsyncStream { - #[inline] - pub(crate) async fn insert_item(&self, value: ItemType) { - let mut inner_lock = self.inner.inner_lock.lock().await; - inner_lock.buffer.push_back(value); - // check if any waker was registered - let Some(waker) = inner_lock.waker.take() else { - return; - }; - // wakeup the waker - waker.wake(); - } -} - -impl AsyncStream { - pub(crate) async fn buffer_count(&self) -> usize { - self.inner.inner_lock.lock().await.buffer.len() - } -} - -impl AsyncStream { - pub(crate) fn increment(&self) { - self.inner.item_count.fetch_add(1, Ordering::Relaxed); - } -} - -impl AsyncStream { - pub async fn first(&mut self) -> Option { - let mut inner_lock = self.inner.inner_lock.lock().await; - if inner_lock.buffer.is_empty() || self.item_count() == 0 { - return None; - } - - let value = inner_lock.buffer.pop_front()?; - self.inner.item_count.fetch_sub(1, Ordering::Relaxed); - Some(value) - } -} - -impl AsyncStream { - pub(crate) fn item_count(&self) -> usize { - self.inner.item_count.load(Ordering::Acquire) - } -} - -impl Clone for AsyncStream { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } -} - -impl AsyncStream { - pub(crate) fn new() -> Self { - AsyncStream { - inner: Arc::new(Inner::new()), - } - } -} - -enum Stages { - Empty, - Wait, - Ready(T), -} - -struct InnerState { - buffer: VecDeque, - waker: Option, -} - -impl InnerState { - fn new() -> InnerState { - Self { - buffer: VecDeque::with_capacity(1000), - waker: None, - } - } -} - -impl AsyncStream { - fn poll(&self, cx: &mut Context<'_>) -> Poll>> { - let waker = cx.waker().clone(); - let mut future = async move { - let mut inner_lock = self.inner.inner_lock.lock().await; - if self.item_count() == 0 && inner_lock.buffer.is_empty() { - return Stages::Empty; - } - let Some(value) = inner_lock.buffer.pop_front() else { - // register the waker so we can called it later - inner_lock.waker.replace(waker); - return Stages::Wait; - }; - - self.inner.item_count.fetch_sub(1, Ordering::Relaxed); - Stages::Ready(Some(value)) - }; - unsafe { Future::poll(Pin::new_unchecked(&mut future), cx) } - } -} - -impl Stream for AsyncStream { - type Item = ItemType; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.poll(cx) { - Poll::Pending => { - // This means the lock has not been acquired yet - // so immediately wake up this waker - cx.waker().wake_by_ref(); - Poll::Pending - } - Poll::Ready(stage) => match stage { - Stages::Empty => Poll::Ready(None), - Stages::Wait => Poll::Pending, - Stages::Ready(value) => Poll::Ready(value), - }, - } - } - - fn size_hint(&self) -> (usize, Option) { - (0, Some(self.item_count())) - } -} +pub(crate) use stream::AsyncStream; diff --git a/src/async_stream/stream.rs b/src/async_stream/stream.rs new file mode 100644 index 0000000..ac480e9 --- /dev/null +++ b/src/async_stream/stream.rs @@ -0,0 +1,162 @@ +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll, Waker}, +}; + +use async_lock::Mutex; +use futures_lite::Stream; + +pub struct AsyncStream { + inner: Arc>, +} + +impl AsyncStream { + pub(crate) async fn insert_item(&self, value: ItemType) { + let mut inner_lock = self.inner.inner_lock.lock().await; + inner_lock.buffer.push_back(value); + // check if any waker was registered + let Some(waker) = inner_lock.wakers.take() else { + return; + }; + + drop(inner_lock); + + // wakeup the waker + waker.wake(); + } +} + +impl AsyncStream { + pub(crate) async fn buffer_count(&self) -> usize { + self.inner.inner_lock.lock().await.buffer.len() + } +} + +impl AsyncStream { + pub(crate) fn increment(&self) { + self.inner.item_count.fetch_add(1, Ordering::Relaxed); + } +} + +impl AsyncStream { + pub async fn first(&self) -> Option { + let mut inner_lock = self.inner.inner_lock.lock().await; + if inner_lock.buffer.is_empty() || self.item_count() == 0 { + return None; + } + + let value = inner_lock.buffer.pop_front()?; + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Some(value) + } +} + +impl AsyncStream { + pub(crate) fn item_count(&self) -> usize { + self.inner.item_count.load(Ordering::Acquire) + } +} + +impl Clone for AsyncStream { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl AsyncStream { + pub(crate) fn new() -> Self { + AsyncStream { + inner: Arc::new(Inner::new()), + } + } +} + +struct Inner { + inner_lock: Mutex>, + item_count: AtomicUsize, +} + +impl Inner { + fn new() -> Self { + Self { + inner_lock: Mutex::new(InnerState::new()), + item_count: AtomicUsize::new(0), + } + } +} + +enum Stages { + Empty, + Wait, + Ready(T), +} + +struct InnerState { + buffer: VecDeque, + wakers: Option, +} + +impl InnerState { + fn new() -> InnerState { + Self { + buffer: VecDeque::new(), + wakers: None, + } + } +} + +impl AsyncStream { + fn poll_item(&self, cx: &mut Context<'_>) -> Poll> { + if self.item_count() == 0 { + return Poll::Ready(Stages::Empty); + } + let waker = cx.waker().clone(); + let mut future = async move { + let mut inner_lock = self.inner.inner_lock.lock().await; + if self.item_count() == 0 && inner_lock.buffer.is_empty() { + return Stages::Empty; + } + let Some(value) = inner_lock.buffer.pop_front() else { + // register the waker so we can called it later + inner_lock.wakers.replace(waker); + return Stages::Wait; + }; + + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Stages::Ready(value) + }; + unsafe { Pin::new_unchecked(&mut future) }.poll(cx) + } +} + +impl Stream for AsyncStream { + type Item = ItemType; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.poll_item(cx) { + Poll::Pending => { + // This means the lock has not been acquired yet + // so immediately wake up this waker + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(stage) => match stage { + Stages::Empty => Poll::Ready(None), + Stages::Wait => Poll::Pending, + Stages::Ready(value) => Poll::Ready(Some(value)), + }, + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.item_count())) + } +} diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 1e00456..d3fd393 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -17,9 +17,9 @@ use std::future::Future; /// any order. /// pub struct DiscardingSpawnGroup { + runtime: RuntimeEngine<()>, /// A field that indicates if the spawn group has been cancelled pub is_cancelled: bool, - runtime: RuntimeEngine<()>, wait_at_drop: bool, } @@ -45,6 +45,17 @@ impl DiscardingSpawnGroup { } } +impl Default for DiscardingSpawnGroup { + /// Instantiates `DiscardingSpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { + Self { + is_cancelled: false, + runtime: RuntimeEngine::default(), + wait_at_drop: true, + } + } +} + impl DiscardingSpawnGroup { /// Spawns a new task into the spawn group /// diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index 44c3f0f..e5359f5 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -26,10 +26,10 @@ use std::{ /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used pub struct ErrSpawnGroup { + runtime: RuntimeEngine>, + count: Arc, /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, - count: Arc, - runtime: RuntimeEngine>, wait_at_drop: bool, } @@ -40,10 +40,22 @@ impl ErrSpawnGroup { /// /// * `num_of_threads`: number of threads to use pub fn new(num_of_threads: usize) -> Self { + Self { + runtime: RuntimeEngine::new(num_of_threads), + count: Arc::new(AtomicUsize::new(0)), + is_cancelled: false, + wait_at_drop: true, + } + } +} + +impl Default for ErrSpawnGroup { + /// Instantiates `ErrSpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { Self { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), - runtime: RuntimeEngine::new(num_of_threads), + runtime: RuntimeEngine::default(), wait_at_drop: true, } } @@ -90,7 +102,7 @@ impl ErrSpawnGroup { F: Future> + Send + 'static, { if !self.is_cancelled { - self.spawn_task(priority, closure) + self.runtime.write_task(priority, closure) } } } @@ -124,11 +136,11 @@ impl ErrSpawnGroup { impl ErrSpawnGroup { fn increment_count(&self) { - self.count.fetch_add(1, Ordering::Acquire); + self.count.fetch_add(1, Ordering::Relaxed); } fn count(&self) -> usize { - self.count.load(Ordering::Acquire) + self.count.load(Ordering::Relaxed) } fn decrement_count_to_zero(&self) { diff --git a/src/executors/future_executor.rs b/src/executors/future_executor.rs index 8e92a41..8c3f059 100644 --- a/src/executors/future_executor.rs +++ b/src/executors/future_executor.rs @@ -1,10 +1,11 @@ use std::{ - cell::RefCell, future::Future, + pin::Pin, + sync::Arc, task::{Context, Poll, Waker}, }; -use super::suspender::{pair, Suspender}; +use crate::shared::{pair, Suspender}; /// Blocks the current thread until the future is polled to finish. /// @@ -17,36 +18,24 @@ use super::suspender::{pair, Suspender}; /// assert_eq!(result, 1); /// ``` /// +#[inline] pub fn block_on(future: Fut) -> Fut::Output { - let mut future = future; - let mut future = unsafe { std::pin::Pin::new_unchecked(&mut future) }; thread_local! { - static WAKER_PAIR: RefCell<(Suspender, Waker)> = { - RefCell::new(pair()) + static PAIR: (Arc, Waker) = { + pair() }; } - return WAKER_PAIR.with(|waker_pair| match waker_pair.try_borrow_mut() { - Ok(waker_pair) => { - let (suspender, waker) = &*waker_pair; - let mut context: Context<'_> = Context::from_waker(waker); - loop { - let Poll::Ready(output) = Future::poll(future.as_mut(), &mut context) else { - suspender.suspend(); - continue; - }; - return output; - } - } - Err(_) => { - let (suspender, waker) = pair(); - let mut context: Context<'_> = Context::from_waker(&waker); - loop { - let Poll::Ready(output) = Future::poll(future.as_mut(), &mut context) else { - suspender.suspend(); - continue; - }; - return output; + + PAIR.with(move |waker_pair| { + let mut future = future; + let mut future = unsafe { Pin::new_unchecked(&mut future) }; + let (suspender, waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + loop { + match future.as_mut().poll(&mut context) { + Poll::Pending => suspender.suspend(), + Poll::Ready(output) => return output, } } - }); + }) } diff --git a/src/executors/mod.rs b/src/executors/mod.rs index e02e9d1..c1a8848 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -1,8 +1,3 @@ mod future_executor; -mod suspender; -mod task_executor; -mod waker; pub use future_executor::block_on; -pub(crate) use task_executor::{block_task, WAKER_PAIR}; -pub(crate) use suspender::{Suspender, pair}; \ No newline at end of file diff --git a/src/executors/suspender.rs b/src/executors/suspender.rs deleted file mode 100755 index 23cfd62..0000000 --- a/src/executors/suspender.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::{ - sync::{Arc, Condvar, Mutex}, - task::Waker, -}; - -use super::waker::waker_helper; - -pub(crate) fn pair() -> (Suspender, Waker) { - let suspender = Suspender::new(); - let resumer = suspender.resumer(); - (suspender, waker_helper(move || resumer.resume())) -} - -pub(crate) struct Suspender { - resumer: Resumer, -} - -impl Suspender { - pub(crate) fn new() -> Suspender { - Suspender { - resumer: Resumer { - inner: Arc::new(Inner { - lock: Mutex::new(State::Empty), - cvar: Condvar::new(), - }), - }, - } - } - - pub(crate) fn suspend(&self) { - self.resumer.inner.suspend(); - } - - pub(crate) fn resumer(&self) -> Resumer { - Resumer { - inner: self.resumer.inner.clone(), - } - } -} - -pub(crate) struct Resumer { - inner: Arc, -} - -impl Resumer { - pub(crate) fn resume(&self) { - self.inner.resume(); - } -} - -#[derive(PartialEq)] -enum State { - Empty, - Notified, - Suspended, -} - -struct Inner { - lock: Mutex, - cvar: Condvar, -} - -impl Inner { - #[inline] - fn suspend(&self) { - // Acquire the lock first - let Ok(mut lock) = self.lock.lock() else { - return; - }; - - // check the state the lock is in right now - match *lock { - // suspend the thread - State::Empty => *lock = State::Suspended, - // already notified this thread so just revert state back to empty - // then return - State::Notified => { - *lock = State::Empty; - return; - } - State::Suspended => panic!("cannot suspend a thread that is already in a suspended state"), - } - - // suspend this thread until we get a notification - while *lock == State::Suspended { - lock = self.cvar.wait(lock).unwrap(); - } - - // revert state back to empty so the next time this method is called - // it can suspend this callee thread - *lock = State::Empty; - } - - #[inline] - fn resume(&self) { - // Acquire the lock first - let Ok(mut lock) = self.lock.lock() else { - return; - }; - - // check if the state is not in a notified state yet - if *lock != State::Notified { - // send notification - *lock = State::Notified; - // resume the suspended thread - self.cvar.notify_one(); - } - } -} diff --git a/src/executors/task_executor.rs b/src/executors/task_executor.rs deleted file mode 100644 index 6b7899f..0000000 --- a/src/executors/task_executor.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::{ - cell::RefCell, - task::{Context, Poll, Waker}, -}; - -use crate::shared::priority_task::PrioritizedTask; - -use super::{pair, Suspender}; - -thread_local! { - pub(crate) static WAKER_PAIR: RefCell<(Suspender, Waker)> = { - RefCell::new(pair()) - }; -} - -pub(crate) fn block_task(task: PrioritizedTask) { - if task.is_completed() { - return; - } - - WAKER_PAIR.with(move |waker_pair| { - let mut task = task; - match waker_pair.try_borrow_mut() { - Ok(waker_pair) => { - let (suspender, waker) = &*waker_pair; - let mut context: Context<'_> = Context::from_waker(waker); - loop { - let Poll::Ready(()) = task.poll_task(&mut context) else { - suspender.suspend(); - continue; - }; - return; - } - } - Err(_) => { - let (suspender, waker) = pair(); - let mut context: Context<'_> = Context::from_waker(&waker); - loop { - let Poll::Ready(()) = task.poll_task(&mut context) else { - suspender.suspend(); - continue; - }; - return; - } - } - }; - }); -} diff --git a/src/executors/waker.rs b/src/executors/waker.rs deleted file mode 100644 index 93fd6ba..0000000 --- a/src/executors/waker.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::{ - mem, - sync::Arc, - task::{RawWaker, RawWakerVTable, Waker}, -}; - -// Waker implementation -struct WakerHelper(F); - -pub(crate) fn waker_helper(f: F) -> Waker { - let raw: *const () = Arc::into_raw(Arc::new(f)) as *const (); - let vtable: &RawWakerVTable = &WakerHelper::::VTABLE; - unsafe { Waker::from_raw(RawWaker::new(raw, vtable)) } -} - -impl WakerHelper { - // A virtual function table (vtable) that specifies the behavior of a RawWaker - const VTABLE: RawWakerVTable = RawWakerVTable::new( - Self::clone_waker, - Self::wake, - Self::wake_by_ref, - Self::drop_waker, - ); - - // clones the waker - unsafe fn clone_waker(ptr: *const ()) -> RawWaker { - let arc: mem::ManuallyDrop> = mem::ManuallyDrop::new(Arc::from_raw(ptr as *const F)); - _ = arc.clone(); - RawWaker::new(ptr, &Self::VTABLE) - } - - // wakes up by consuming it - unsafe fn wake(ptr: *const ()) { - let arc: Arc = Arc::from_raw(ptr as *const F); - (arc)(); - } - - // wakes up by reference - unsafe fn wake_by_ref(ptr: *const ()) { - let arc: mem::ManuallyDrop> = mem::ManuallyDrop::new(Arc::from_raw(ptr as *const F)); - (arc)(); - } - - // drops the waker - unsafe fn drop_waker(ptr: *const ()) { - drop(Arc::from_raw(ptr as *const F)) - } -} diff --git a/src/lib.rs b/src/lib.rs index 0c302f2..37b696a 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -154,7 +154,7 @@ //! //! # Note //! * Import ``StreamExt`` trait from ``futures_lite::StreamExt`` or ``futures::stream::StreamExt`` or ``async_std::stream::StreamExt`` to provide a variety of convenient combinator functions on the various spawn groups. -//! * To await all running child tasks to finish their execution, call ``wait_for_all`` method on the spawn group instance unless using the [`with_discarding_spawn_group`](self::with_discarding_spawn_group) function. +//! * To await all running child tasks to finish their execution, call ``wait_for_all`` or ``wait_non_async`` methods on the various group instances //! //! # Warning //! * This crate relies on atomics @@ -166,7 +166,6 @@ mod discarding_spawn_group; mod err_spawn_group; mod spawn_group; -mod async_runtime; mod async_stream; mod executors; mod meta_types; @@ -182,7 +181,6 @@ pub use spawn_group::SpawnGroup; use std::future::Future; use std::marker::PhantomData; -use std::thread::available_parallelism; /// Starts a scoped closure that takes a mutable ``SpawnGroup`` instance as an argument which can execute any number of child tasks which its result values are of the generic ``ResultType`` type. /// @@ -228,23 +226,17 @@ use std::thread::available_parallelism; /// assert_eq!(final_result, 55); /// # }); /// ``` -pub async fn with_type_spawn_group<'a, Closure, Fut, ResultType, ReturnType>( +pub async fn with_type_spawn_group( of_type: PhantomData, body: Closure, ) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + 'a, - Fut: Future + Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut, + Fut: Future + 'static, ResultType: 'static, { - let count: usize; - if let Ok(thread_count) = available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } _ = of_type; - let task_group = spawn_group::SpawnGroup::::new(count); + let task_group = spawn_group::SpawnGroup::::default(); body(task_group).await } @@ -291,19 +283,13 @@ where /// assert_eq!(final_result, 55); /// # }); /// ``` -pub async fn with_spawn_group<'a, Closure, Fut, ResultType, ReturnType>(body: Closure) -> ReturnType +pub async fn with_spawn_group(body: Closure) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + 'a, - Fut: Future + Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut, + Fut: Future + 'static, ResultType: 'static, { - let count: usize; - if let Ok(thread_count) = available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } - let task_group = spawn_group::SpawnGroup::::new(count); + let task_group = spawn_group::SpawnGroup::::default(); body(task_group).await } @@ -396,25 +382,19 @@ where /// assert_eq!(final_results.2, 2); /// # }); /// ``` -pub async fn with_err_type_spawn_group<'a, Closure, Fut, ResultType, ErrorType, ReturnType>( +pub async fn with_err_type_spawn_group( of_type: PhantomData, error_type: PhantomData, body: Closure, ) -> ReturnType where ErrorType: 'static, - Fut: Future + Send, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + 'a, + Fut: Future, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut, ResultType: 'static, { - let count: usize; - if let Ok(thread_count) = available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } _ = (of_type, error_type); - let task_group = err_spawn_group::ErrSpawnGroup::::new(count); + let task_group = err_spawn_group::ErrSpawnGroup::::default(); body(task_group).await } @@ -505,22 +485,16 @@ where /// assert_eq!(final_results.2, 2); /// # }); /// ``` -pub async fn with_err_spawn_group<'a, Closure, Fut, ResultType, ErrorType, ReturnType>( +pub async fn with_err_spawn_group( body: Closure, ) -> ReturnType where ErrorType: 'static, - Fut: Future + Send, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + 'a, + Fut: Future, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut, ResultType: 'static, { - let count: usize; - if let Ok(thread_count) = available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } - let task_group = err_spawn_group::ErrSpawnGroup::::new(count); + let task_group = err_spawn_group::ErrSpawnGroup::::default(); body(task_group).await } @@ -561,17 +535,11 @@ where /// }).await; /// # }); /// ``` -pub async fn with_discarding_spawn_group<'a, Closure, Fut, ReturnType>(body: Closure) -> ReturnType +pub async fn with_discarding_spawn_group(body: Closure) -> ReturnType where - Fut: Future + Send, - Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut + 'a, + Fut: Future, + Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut, { - let count: usize; - if let Ok(thread_count) = available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } - let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::new(count); + let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::default(); body(discarding_tg).await } diff --git a/src/meta_types.rs b/src/meta_types.rs index be56ba7..85f3631 100755 --- a/src/meta_types.rs +++ b/src/meta_types.rs @@ -2,23 +2,22 @@ /// /// `GetType` provides a metatype that's a type of a type, /// it also enables a developer to pass a type as a value to specify a generic type of a parameter -/// +/// /// # Examples /// ``` /// use spawn_groups::GetType; /// use std::marker::PhantomData; -/// +/// /// fn closure_taker(with_value: T, returning_type: PhantomData, closure: FUNC) -> U /// where FUNC: Fn(T) -> U { /// closure(with_value) /// } /// /// let string_result = closure_taker(32, String::TYPE, |val| format!("{}", val) ); -/// +/// /// assert_eq!(string_result, String::from("32")); /// ``` -/// - +/// use std::marker::PhantomData; /// `GetType` trait implements asssociated constant for every type and this associated constant provides a metatype value that's a type's type value diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 665e86e..32c911c 100755 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1,3 +1,14 @@ +pub(crate) mod mutex; pub(crate) mod priority; +pub(crate) mod priority_task; pub(crate) mod runtime; -pub(crate) mod priority_task; \ No newline at end of file +mod suspender; +mod task; +mod waker; +mod waker_pair; +mod task_enum; + +pub(crate) use suspender::{pair, Suspender}; +pub(crate) use waker_pair::WAKER_PAIR; +pub(crate) use task_enum::TaskOrBarrier; +pub(crate) use task::Task; \ No newline at end of file diff --git a/src/shared/mutex.rs b/src/shared/mutex.rs new file mode 100644 index 0000000..44a9c94 --- /dev/null +++ b/src/shared/mutex.rs @@ -0,0 +1,15 @@ +use std::sync::{Mutex, MutexGuard}; + +#[derive(Default)] +pub(crate) struct StdMutex(Mutex); +pub(crate) type StdMutexGuard<'a, T> = MutexGuard<'a, T>; + +impl StdMutex { + pub(crate) fn new(t: T) -> Self { + Self(Mutex::new(t)) + } + + pub(crate) fn lock(&self) -> StdMutexGuard { + self.0.lock().unwrap_or_else(|e| e.into_inner()) + } +} diff --git a/src/shared/priority.rs b/src/shared/priority.rs index 3c98fd1..e5cd08e 100755 --- a/src/shared/priority.rs +++ b/src/shared/priority.rs @@ -1,8 +1,8 @@ /// Task Priority /// /// Spawn groups uses it to rank the importance of their spawned tasks and order of returned values only when waited for -/// that is when the ``wait_for_all`` or ``wait_non_async`` methods are called -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Default)] +/// that is when the ``wait_for_all`` or ``wait_non_async`` method is called +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default)] pub enum Priority { BACKGROUND = 0, LOW, diff --git a/src/shared/priority_task.rs b/src/shared/priority_task.rs index be26b6a..2faf46d 100644 --- a/src/shared/priority_task.rs +++ b/src/shared/priority_task.rs @@ -1,51 +1,50 @@ use std::{ cmp::Ordering, - ops::{Deref, DerefMut}, + future::Future, + sync::{Arc, Barrier}, }; -use crate::{async_runtime::task::Task, Priority}; +use crate::threadpool_impl::TaskPriority; -pub(crate) struct PrioritizedTask { - task: Task, - priority: Priority, -} - -impl Deref for PrioritizedTask { - type Target = Task; - - fn deref(&self) -> &Self::Target { - &self.task - } -} +use super::{task::Task, task_enum::TaskOrBarrier}; -impl DerefMut for PrioritizedTask { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.task - } +pub(crate) struct PrioritizedTask { + pub(crate) task: TaskOrBarrier, + priority: TaskPriority, } -impl PartialEq for PrioritizedTask { +impl PartialEq for PrioritizedTask { fn eq(&self, other: &Self) -> bool { self.priority == other.priority } } -impl Eq for PrioritizedTask {} +impl Eq for PrioritizedTask {} + +impl PrioritizedTask { + pub(crate) fn new + 'static>(priority: TaskPriority, future: F) -> Self { + Self { + task: TaskOrBarrier::Task(Task::new(future)), + priority, + } + } -impl PrioritizedTask { - pub(crate) fn new(priority: Priority, task: Task) -> Self { - Self { task, priority } + pub(crate) fn new_with(priority: TaskPriority, barrier: Arc) -> Self { + Self { + task: TaskOrBarrier::Barrier(barrier), + priority, + } } } -impl Ord for PrioritizedTask { +impl Ord for PrioritizedTask { fn cmp(&self, other: &Self) -> Ordering { other.priority.cmp(&self.priority) } } -impl PartialOrd for PrioritizedTask { +impl PartialOrd for PrioritizedTask { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + Some(other.cmp(self)) } } diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index 6036f2e..3a14de5 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,47 +1,45 @@ -use crate::{ - async_runtime::executor::Executor, async_stream::AsyncStream, executors::block_task, - shared::priority::Priority, -}; +use crate::{async_stream::AsyncStream, shared::priority::Priority, threadpool_impl::ThreadPool}; use std::{ - collections::BinaryHeap, future::Future, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, }; use super::priority_task::PrioritizedTask; -type TaskQueue = Mutex>; - pub(crate) struct RuntimeEngine { - prioritized_tasks: TaskQueue, - runtime: Executor, - task_count: Arc, stream: AsyncStream, + pool: ThreadPool, + task_count: Arc, } impl RuntimeEngine { pub(crate) fn new(count: usize) -> Self { Self { - prioritized_tasks: Mutex::new(BinaryHeap::with_capacity(1000)), - runtime: Executor::new(count), + pool: ThreadPool::new(count), + stream: AsyncStream::new(), task_count: Arc::new(AtomicUsize::default()), + } + } +} + +impl Default for RuntimeEngine { + fn default() -> Self { + Self { + pool: ThreadPool::default(), stream: AsyncStream::new(), + task_count: Arc::new(AtomicUsize::default()), } } } impl RuntimeEngine { - pub(crate) fn cancel(&mut self) { - self.runtime.cancel(); - let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { - return; - }; - prioritized_tasks.clear(); - self.task_count.store(0, Ordering::Release); - self.poll(); + pub(crate) fn cancel(&self) { + self.pool.clear(); + self.pool.wait_for_all(); + self.task_count.store(0, Ordering::Relaxed); } } @@ -50,70 +48,41 @@ impl RuntimeEngine { self.stream.clone() } - pub(crate) fn end(&mut self) { - self.runtime.cancel(); - let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { - return; - }; - prioritized_tasks.clear(); - self.task_count.store(0, Ordering::Release); - self.runtime.end() + pub(crate) fn end(&self) { + self.pool.clear(); + self.pool.wait_for_all(); + self.task_count.store(0, Ordering::Relaxed); + self.pool.end() } } impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { - let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { - return; - }; - if prioritized_tasks.is_empty() { - return; - } - self.runtime.cancel(); - prioritized_tasks.retain(|prioritized_task| { - prioritized_task.cancel_task(); - !prioritized_task.is_completed() - }); - if prioritized_tasks.is_empty() { - return; - } - while let Some(prioritized_task) = prioritized_tasks.pop() { - if prioritized_task.is_completed() { - continue; - } - self.runtime.submit(move || block_task(prioritized_task)); - } - self.task_count.store(0, Ordering::Release); - drop(prioritized_tasks); - self.poll() + self.poll(); + self.task_count.store(0, Ordering::Relaxed); } } impl RuntimeEngine { - pub(crate) fn write_task(&self, priority: Priority, task: F) + pub(crate) fn write_task(&mut self, priority: Priority, task: F) where F: Future + 'static, { - let Ok(mut prioritized_tasks) = self.prioritized_tasks.lock() else { - return; - }; - self.stream.increment(); - self.task_count.fetch_add(1, Ordering::SeqCst); let (stream, task_counter) = (self.stream(), self.task_count.clone()); - prioritized_tasks.push(PrioritizedTask::new( - priority, - self.runtime.spawn(async move { + stream.increment(); + task_counter.fetch_add(1, Ordering::Relaxed); + self.pool + .submit(PrioritizedTask::new(priority.into(), async move { let task_result = task.await; stream.insert_item(task_result).await; - task_counter.fetch_sub(1, Ordering::SeqCst); - }), - )); + task_counter.fetch_sub(1, Ordering::Relaxed); + })); } } impl RuntimeEngine { - pub(crate) fn poll(&self) { - self.runtime.poll_all(); + fn poll(&self) { + self.pool.wait_for_all(); } pub(crate) fn task_count(&self) -> usize { diff --git a/src/shared/suspender.rs b/src/shared/suspender.rs new file mode 100755 index 0000000..fe874f3 --- /dev/null +++ b/src/shared/suspender.rs @@ -0,0 +1,96 @@ +use std::{ + sync::{Arc, Condvar}, + task::Waker, +}; + +use crate::shared::mutex::StdMutex; + +use super::waker::waker_helper; + +pub(crate) fn pair() -> (Arc, Waker) { + let suspender = Arc::new(Suspender::new()); + let resumer = suspender.clone(); + (suspender, waker_helper(resumer)) +} + +pub(crate) struct Suspender { + inner: Inner, +} + +impl Suspender { + pub(crate) fn new() -> Suspender { + Suspender { + inner: Inner { + lock: StdMutex::new(State::Initial), + cvar: Condvar::new(), + }, + } + } + + pub(crate) fn suspend(&self) { + self.inner.suspend(); + } + + pub(crate) fn resume(&self) { + self.inner.resume(); + } +} + +#[derive(PartialEq)] +enum State { + Initial, + Notified, + Suspended, +} + +struct Inner { + lock: StdMutex, + cvar: Condvar, +} + +impl Inner { + fn suspend(&self) { + // Acquire the lock first + let mut lock = self.lock.lock(); + + // check the state the lock is in right now + match *lock { + // suspend the thread + State::Initial => { + *lock = State::Suspended; + // suspend this thread until we get a notification + while *lock == State::Suspended { + lock = self.cvar.wait(lock).unwrap(); + } + } + // already notified this thread so just revert state back to empty + // then return + State::Notified => { + *lock = State::Initial; + } + State::Suspended => { + panic!("cannot suspend a thread that is already in a suspended state") + } + } + } + + fn resume(&self) { + // Acquire the lock first + let mut lock = self.lock.lock(); + + // check if the state is empty or suspended + match *lock { + State::Initial => { + // send notification + *lock = State::Notified; + } + State::Suspended => { + // send notification + *lock = State::Notified; + // resume the suspended thread + self.cvar.notify_one(); + } + _ => {} + } + } +} diff --git a/src/shared/task.rs b/src/shared/task.rs new file mode 100755 index 0000000..15eeffb --- /dev/null +++ b/src/shared/task.rs @@ -0,0 +1,30 @@ +use std::{ + future::Future, + panic::{RefUnwindSafe, UnwindSafe}, + pin::Pin, + task::{Context, Poll}, +}; + +type LocalFuture = dyn Future; + +pub(crate) struct Task { + future: Pin>>, +} + +impl Task { + pub(crate) fn new + 'static>(future: Fut) -> Self { + Self { + future: unsafe { Pin::new_unchecked(Box::new(future)) }, + } + } + + #[inline] + pub(crate) fn poll_task(&mut self, cx: &mut Context<'_>) -> Poll { + self.future.as_mut().poll(cx) + } +} + +impl UnwindSafe for Task {} +impl RefUnwindSafe for Task {} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} diff --git a/src/shared/task_enum.rs b/src/shared/task_enum.rs new file mode 100644 index 0000000..aa1f580 --- /dev/null +++ b/src/shared/task_enum.rs @@ -0,0 +1,9 @@ +use std::sync::{Arc, Barrier}; + +use super::task::Task; + +// Naming is hard guys +pub(crate) enum TaskOrBarrier { + Task(Task), + Barrier(Arc), +} diff --git a/src/shared/waker.rs b/src/shared/waker.rs new file mode 100644 index 0000000..d9cb381 --- /dev/null +++ b/src/shared/waker.rs @@ -0,0 +1,56 @@ +use std::{ + mem::ManuallyDrop, + sync::Arc, + task::{RawWaker, RawWakerVTable, Waker}, +}; + +use super::suspender::Suspender; + +pub(crate) fn waker_helper(suspender: Arc) -> Waker { + let raw: *const () = Arc::into_raw(suspender) as *const (); + unsafe { + Waker::from_raw(RawWaker::new( + raw, + &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), + )) + } +} + +// clones the waker +pub(crate) unsafe fn clone_waker(ptr: *const ()) -> RawWaker { + let ptr = ptr as *mut Suspender; + let waker = ManuallyDrop::new(Arc::from_raw(ptr)); + let cloned = (*waker).clone(); + + debug_assert_eq!(Arc::into_raw(ManuallyDrop::into_inner(waker)), ptr); + + let cloned_ptr = Arc::into_raw(cloned) as *const (); + + RawWaker::new( + cloned_ptr, + &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), + ) +} + +// wakes up by consuming it +pub(crate) unsafe fn wake(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + let waker = Arc::from_raw(ptr); + waker.resume(); +} + +// wakes up by reference +pub(crate) unsafe fn wake_by_ref(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + + let waker = ManuallyDrop::new(Arc::from_raw(ptr)); + waker.resume(); + + debug_assert_eq!(Arc::into_raw(ManuallyDrop::into_inner(waker)), ptr); +} + +// drops the waker +pub(crate) unsafe fn drop_waker(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + drop(Arc::from_raw(ptr)); +} diff --git a/src/shared/waker_pair.rs b/src/shared/waker_pair.rs new file mode 100644 index 0000000..e8af5ac --- /dev/null +++ b/src/shared/waker_pair.rs @@ -0,0 +1,8 @@ +use super::{pair, Suspender}; +use std::{sync::Arc, task::Waker}; + +thread_local! { + pub(crate) static WAKER_PAIR: (Arc, Waker) = { + pair() + }; +} diff --git a/src/spawn_group.rs b/src/spawn_group.rs index 6287025..682db00 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -25,12 +25,11 @@ use std::{ /// /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used - pub struct SpawnGroup { + runtime: RuntimeEngine, + count: Arc, /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, - count: Arc, - runtime: RuntimeEngine, wait_at_drop: bool, } @@ -41,10 +40,22 @@ impl SpawnGroup { /// /// * `num_of_threads`: number of threads to use pub fn new(num_of_threads: usize) -> Self { + Self { + runtime: RuntimeEngine::new(num_of_threads), + count: Arc::new(AtomicUsize::new(0)), + is_cancelled: false, + wait_at_drop: true, + } + } +} + +impl Default for SpawnGroup { + /// Instantiates `SpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { Self { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), - runtime: RuntimeEngine::new(num_of_threads), + runtime: RuntimeEngine::default(), wait_at_drop: true, } } @@ -83,7 +94,7 @@ impl SpawnGroup { F: Future + Send + 'static, { if !self.is_cancelled { - self.spawn_task(priority, closure) + self.runtime.write_task(priority, closure) } } @@ -117,11 +128,11 @@ impl SpawnGroup { impl SpawnGroup { fn increment_count(&self) { - self.count.fetch_add(1, Ordering::Acquire); + self.count.fetch_add(1, Ordering::Relaxed); } fn count(&self) -> usize { - self.count.load(Ordering::Acquire) + self.count.load(Ordering::Relaxed) } fn decrement_count_to_zero(&self) { diff --git a/src/threadpool_impl/channel.rs b/src/threadpool_impl/channel.rs index ddf4483..f597d57 100644 --- a/src/threadpool_impl/channel.rs +++ b/src/threadpool_impl/channel.rs @@ -1,109 +1,87 @@ use std::{ - collections::VecDeque, + collections::BinaryHeap, sync::{ atomic::{AtomicBool, Ordering}, - Arc, Condvar, Mutex, + Condvar, }, }; -pub(crate) struct Channel { - inner: Arc>, +use crate::shared::mutex::StdMutex; + +pub(crate) struct Channel { + inner: Inner, } -impl Channel { - pub(crate) fn enqueue(&self, value: ItemType) { +impl Channel { + pub(crate) fn enqueue(&self, value: T) { self.inner.enqueue(value) } } -impl Channel { +impl Channel { pub(crate) fn new() -> Self { Self { - inner: Arc::new(Inner::new()), - } - } -} - -impl Clone for Channel { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), + inner: Inner::new(), } } } -impl Channel { - pub(crate) fn dequeue(&self) -> Option { +impl Channel { + pub(crate) fn dequeue(&self) -> Option { self.inner.dequeue() } } -impl Channel { - pub(crate) fn close(&self) { - self.inner.close() - } - +impl Channel { pub(crate) fn clear(&self) { self.inner.clear() } + + pub(crate) fn end(&self) { + self.inner.end() + } } -struct Inner { +struct Inner { + mtx: StdMutex>, + cvar: Condvar, closed: AtomicBool, - pair: (Mutex>, Condvar), } -impl Inner { +impl Inner { fn new() -> Self { Self { + mtx: StdMutex::new(BinaryHeap::new()), + cvar: Condvar::new(), closed: AtomicBool::new(false), - pair: (Mutex::new(VecDeque::new()), Condvar::new()), } } - fn enqueue(&self, value: ItemType) { - if self.closed.load(Ordering::Relaxed) { - return; - } - let Ok(mut lock) = self.pair.0.lock() else { - return; - }; - lock.push_back(value); - if lock.len() == 1 { - self.pair.1.notify_one(); - } + fn enqueue(&self, value: T) { + let mut lock = self.mtx.lock(); + lock.push(value); + self.cvar.notify_one(); } - fn dequeue(&self) -> Option { - if self.closed.load(Ordering::Relaxed) { - return None; - } - let Ok(mut lock) = self.pair.0.lock() else { - return None; - }; + fn dequeue(&self) -> Option { + let mut lock = self.mtx.lock(); while lock.is_empty() { if self.closed.load(Ordering::Relaxed) { return None; } - lock = self.pair.1.wait(lock).unwrap(); + lock = self.cvar.wait(lock).unwrap(); } - lock.pop_front() + lock.pop() } - fn close(&self) { - if self.closed.swap(true, Ordering::Relaxed) { - return; - } - let Ok(_lock) = self.pair.0.lock() else { - return; - }; - self.pair.1.notify_all(); + fn clear(&self) { + self.mtx.lock().clear(); } - fn clear(&self) { - let Ok(mut lock) = self.pair.0.lock() else { - return; - }; + fn end(&self) { + let mut lock = self.mtx.lock(); + self.closed.store(true, Ordering::Relaxed); lock.clear(); + self.cvar.notify_all(); } } diff --git a/src/threadpool_impl/mod.rs b/src/threadpool_impl/mod.rs index 97f5b7e..ae06898 100644 --- a/src/threadpool_impl/mod.rs +++ b/src/threadpool_impl/mod.rs @@ -1,8 +1,7 @@ mod channel; +mod task_priority; +mod thread; mod threadpool; -mod waitgroup; -pub(crate) type Func = dyn FnOnce() + Send; - -pub(crate) use channel::Channel; +pub(crate) use task_priority::TaskPriority; pub(crate) use threadpool::ThreadPool; diff --git a/src/threadpool_impl/task_priority.rs b/src/threadpool_impl/task_priority.rs new file mode 100644 index 0000000..ca786d6 --- /dev/null +++ b/src/threadpool_impl/task_priority.rs @@ -0,0 +1,25 @@ +use crate::Priority; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) enum TaskPriority { + Wait, + Background, + Low, + Utility, + Medium, + High, + UserInitiated, +} + +impl From for TaskPriority { + fn from(value: Priority) -> Self { + match value { + Priority::BACKGROUND => TaskPriority::Background, + Priority::LOW => TaskPriority::Low, + Priority::UTILITY => TaskPriority::Utility, + Priority::MEDIUM => TaskPriority::Medium, + Priority::HIGH => TaskPriority::High, + Priority::USERINITIATED => TaskPriority::UserInitiated, + } + } +} diff --git a/src/threadpool_impl/thread.rs b/src/threadpool_impl/thread.rs new file mode 100644 index 0000000..473eed3 --- /dev/null +++ b/src/threadpool_impl/thread.rs @@ -0,0 +1,63 @@ +use std::{ + panic::catch_unwind, + sync::Arc, + task::{Context, Poll}, + thread::spawn, +}; + +use crate::shared::{priority_task::PrioritizedTask, Suspender, Task, TaskOrBarrier, WAKER_PAIR}; + +use super::channel::Channel; + +pub(crate) struct UniqueThread { + task_channel: Arc>>, +} + +impl UniqueThread { + pub(crate) fn submit_task(&self, task: PrioritizedTask<()>) { + self.task_channel.enqueue(task); + } + + pub(crate) fn clear(&self) { + self.task_channel.clear() + } + + pub(crate) fn end(&self) { + self.task_channel.end() + } +} + +impl Default for UniqueThread { + fn default() -> Self { + let task_channel: Arc>> = Arc::new(Channel::new()); + let chan = task_channel.clone(); + spawn(move || { + WAKER_PAIR.with(|waker_pair| loop { + let Some(task) = chan.dequeue() else { return }; + let mut context = Context::from_waker(&waker_pair.1); + match task.task { + TaskOrBarrier::Task(mut task) => { + drop(catch_unwind(move || { + poll_task(&mut task, &waker_pair.0, &mut context) + })); + } + TaskOrBarrier::Barrier(barrier) => { + barrier.wait(); + } + } + }); + }); + Self { task_channel } + } +} + +fn poll_task(task: &mut Task<()>, suspender: &Arc, context: &mut Context<'_>) { + loop { + match task.poll_task(context) { + Poll::Ready(()) => return, + Poll::Pending => { + suspender.suspend(); + } + } + } +} diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index e464e89..c647233 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,83 +1,67 @@ use std::{ - sync::atomic::{AtomicUsize, Ordering}, - thread::spawn, + sync::{Arc, Barrier}, + thread::available_parallelism, }; -use super::{waitgroup::WaitGroup, Channel, Func}; +use crate::shared::priority_task::PrioritizedTask; + +use super::{task_priority::TaskPriority, thread::UniqueThread}; pub(crate) struct ThreadPool { - index: AtomicUsize, - task_channels: Vec>>, - wait_group: WaitGroup, + handles: Vec, + index: usize, } impl ThreadPool { pub(crate) fn new(count: usize) -> Self { - let mut task_channels = Vec::with_capacity(count); - for _ in 1..=count { - let channel: Channel> = Channel::new(); - let chan = channel.clone(); - spawn(move || { - while let Some(ops) = channel.dequeue() { - ops(); - } - }); - task_channels.push(chan); - } + assert!(count > 0); ThreadPool { - index: AtomicUsize::new(0), - task_channels, - wait_group: WaitGroup::new(), + index: 0, + handles: (1..=count).map(|_| UniqueThread::default()).collect(), } } } -impl ThreadPool { - fn current_index(&self) -> usize { - self.index.swap( - (self.index.load(Ordering::Relaxed) + 1) % self.task_channels.len(), - Ordering::Relaxed, - ) +impl Default for ThreadPool { + fn default() -> Self { + let count: usize = available_parallelism() + .unwrap_or(unsafe { std::num::NonZeroUsize::new_unchecked(1) }) + .get(); + + ThreadPool { + handles: (1..=count).map(|_| UniqueThread::default()).collect(), + index: 0, + } } +} - pub(crate) fn submit(&self, task: Task) - where - Task: FnOnce() + 'static + Send, - { - self.task_channels[self.current_index()].enqueue(Box::new(task)); +impl ThreadPool { + pub(crate) fn submit(&mut self, task: PrioritizedTask<()>) { + let old_index = self.index; + self.index = (self.index + 1) % self.handles.len(); + self.handles[old_index].submit_task(task); } pub(crate) fn wait_for_all(&self) { - self.task_channels.iter().for_each(|channel| { - let wait_group = self.wait_group.clone(); - wait_group.enter(); - channel.enqueue(Box::new(move || { - wait_group.leave(); - })); + let barrier = Arc::new(Barrier::new(self.handles.len() + 1)); + self.handles.iter().for_each(|channel| { + channel.submit_task(PrioritizedTask::new_with( + TaskPriority::Wait, + barrier.clone(), + )); }); - self.wait_group.wait(); + barrier.wait(); } } impl ThreadPool { pub(crate) fn end(&self) { - self.task_channels.iter().for_each(|channel| { - channel.close(); - channel.clear(); - }); - } -} - -impl Drop for ThreadPool { - fn drop(&mut self) { - self.end(); + self.handles.iter().for_each(|channel| channel.end()); } } impl ThreadPool { pub(crate) fn clear(&self) { - self.task_channels - .iter() - .for_each(|channel| channel.clear()); + self.handles.iter().for_each(|channel| channel.clear()); } } diff --git a/src/threadpool_impl/waitgroup.rs b/src/threadpool_impl/waitgroup.rs deleted file mode 100644 index f0292a2..0000000 --- a/src/threadpool_impl/waitgroup.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::sync::{Arc, Condvar, Mutex}; - -#[derive(Clone)] -pub(crate) struct WaitGroup { - pair: Arc<(Mutex, Condvar)>, -} - -impl WaitGroup { - pub(crate) fn new() -> Self { - Self { - pair: Arc::new((Mutex::new(0), Condvar::new())), - } - } -} - -impl WaitGroup { - pub(crate) fn enter(&self) { - let Ok(mut guard) = self.pair.0.lock() else { - return; - }; - (*guard) += 1; - } - - pub(crate) fn leave(&self) { - let Ok(mut guard) = self.pair.0.lock() else { - return; - }; - (*guard) -= 1; - if (*guard) == 0 { - self.pair.1.notify_all(); - } - } - - pub(crate) fn wait(&self) { - let Ok(mut guard) = self.pair.0.lock() else { - return; - }; - while *guard > 0 { - guard = self.pair.1.wait(guard).unwrap(); - } - } -}