diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 12c6eee..cb19208 100755 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -2,9 +2,9 @@ name: Rust on: push: - branches: [ "master" ] + branches: [ "*" ] pull_request: - branches: [ "master" ] + branches: [ "*" ] env: CARGO_TERM_COLOR: always @@ -31,7 +31,14 @@ jobs: profile: minimal toolchain: nightly override: true + - name: Install Miri + run: | + rustup toolchain install nightly --component miri + rustup override set nightly + cargo miri test - name: Build run: cargo build - name: Test run: cargo test + - name: Miri Test + run: cargo miri test diff --git a/Cargo.toml b/Cargo.toml old mode 100755 new mode 100644 index ff15669..574fd8e --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,5 @@ publish = true [dependencies] -async-trait = "0.1.73" -cooked-waker = "5.0.0" -parking_lot = "0.12.1" -futures-lite = "1.13.0" -async-mutex = "1.4.0" +async-lock = "3.4.0" +futures-lite = { version = "2.5.0", features = ["race"] } diff --git a/src/async_runtime/executor.rs b/src/async_runtime/executor.rs deleted file mode 100755 index 2cb1689..0000000 --- a/src/async_runtime/executor.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::{pin_future, threadpool_impl::ThreadPool}; - -use super::{notifier::Notifier, task::Task, task_queue::TaskQueue}; - -use cooked_waker::IntoWaker; -use parking_lot::{lock_api::MutexGuard, Condvar, Mutex, RawMutex}; - -use std::{ - future::Future, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, -}; - -#[derive(Clone)] -pub struct Executor { - cancel: Arc, - lock_pair: Arc<(Mutex, Condvar)>, - pool: Arc, - queue: TaskQueue, - started: Arc, -} - -impl Default for Executor { - fn default() -> Self { - let result: Executor = Self { - cancel: Arc::new(AtomicBool::new(false)), - lock_pair: Arc::new((Mutex::new(false), Condvar::new())), - pool: Arc::new(ThreadPool::default()), - queue: TaskQueue::default(), - started: Arc::new(AtomicBool::new(false)), - }; - result.start(); - result - } -} - -impl Executor { - pub(crate) fn new(count: usize) -> Self { - let result: Executor = Self { - cancel: Arc::new(AtomicBool::new(false)), - lock_pair: Arc::new((Mutex::new(false), Condvar::new())), - pool: Arc::new(ThreadPool::new(count)), - queue: TaskQueue::default(), - started: Arc::new(AtomicBool::new(false)), - }; - result.start(); - result - } -} - -impl Executor { - fn started(&self) -> bool { - self.started.load(Ordering::Acquire) - } - - fn update(&self, val: bool) { - self.started.store(val, Ordering::Release); - } -} - -impl Executor { - pub(crate) fn submit(&self, task: Task) - where - Task: FnOnce() + Send + 'static, - { - self.pool.submit(task); - } - - pub(crate) fn spawn(&self, task: Fut) -> Task - where - Fut: Future + 'static + Send, - { - let task: Task = Task::new(task); - self.queue.push(&task); - - if !self.started() { - self.notify(); - } - task - } - - fn notify(&self) { - self.update(true); - let pair2: Arc<(Mutex, Condvar)> = self.lock_pair.clone(); - std::thread::spawn(move || { - let (lock, cvar) = &*pair2; - let mut started: MutexGuard<'_, RawMutex, bool> = lock.lock(); - *started = true; - cvar.notify_one(); - }); - } - - pub(crate) fn cancel(&self) { - self.cancel.store(true, Ordering::Release); - *self.lock_pair.0.lock() = false; - self.update(false); - self.queue.drain_all(); - self.cancel.store(false, Ordering::Release); - } - - pub(crate) fn run(&self) { - while !self.cancel.load(Ordering::Acquire) { - self.queue.clone().for_each(|task| { - let queue: TaskQueue = self.queue.clone(); - self.submit(move || { - let waker: Waker = Arc::new(Notifier::default()).into_waker(); - pin_future!(task); - let mut cx: Context<'_> = Context::from_waker(&waker); - match task.as_mut().poll(&mut cx) { - Poll::Ready(()) => (), - Poll::Pending => { - queue.push(&task); - } - } - }); - }); - } - self.poll_all(); - self.queue.drain_all(); - } - - pub(crate) fn poll_all(&self) { - self.pool.wait_for_all(); - } - - pub(crate) fn start(&self) { - let lock_pair: Arc<(Mutex, Condvar)> = self.lock_pair.clone(); - let executor: Executor = self.clone(); - std::thread::spawn(move || { - let (lock, cvar) = &*lock_pair; - let mut started: MutexGuard<'_, RawMutex, bool> = lock.lock(); - while !*started { - cvar.wait(&mut started); - } - executor.run(); - }); - } -} diff --git a/src/async_runtime/mod.rs b/src/async_runtime/mod.rs deleted file mode 100755 index 023ec5a..0000000 --- a/src/async_runtime/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) mod executor; -pub(crate) mod notifier; -pub(crate) mod task; -mod task_queue; -mod pin_macro; \ No newline at end of file diff --git a/src/async_runtime/notifier.rs b/src/async_runtime/notifier.rs deleted file mode 100755 index a918817..0000000 --- a/src/async_runtime/notifier.rs +++ /dev/null @@ -1,29 +0,0 @@ -use cooked_waker::WakeRef; -use std::sync::{Condvar, Mutex, MutexGuard}; - -#[derive(Default)] -pub struct Notifier { - was_notified: Mutex, - cv: Condvar, -} - -impl WakeRef for Notifier { - fn wake_by_ref(&self) { - let was_notified: bool = - { std::mem::replace(&mut self.was_notified.lock().unwrap(), true) }; - if !was_notified { - self.cv.notify_one(); - } - } -} - -impl Notifier { - pub(crate) fn wait(&self) { - let mut was_notified: MutexGuard<'_, bool> = self.was_notified.lock().unwrap(); - - while !*was_notified { - was_notified = self.cv.wait(was_notified).unwrap(); - } - *was_notified = false; - } -} diff --git a/src/async_runtime/pin_macro.rs b/src/async_runtime/pin_macro.rs deleted file mode 100755 index 74c3caa..0000000 --- a/src/async_runtime/pin_macro.rs +++ /dev/null @@ -1,8 +0,0 @@ -/// Pins the variable implementing the ``std::future::Future`` trait onto the stack -#[macro_export] -macro_rules! pin_future { - ($x:ident) => { - let mut $x = $x; - let mut $x = unsafe { std::pin::Pin::new_unchecked(&mut $x) }; - }; -} diff --git a/src/async_runtime/task.rs b/src/async_runtime/task.rs deleted file mode 100755 index c303efd..0000000 --- a/src/async_runtime/task.rs +++ /dev/null @@ -1,54 +0,0 @@ -use parking_lot::Mutex; -use std::{ - future::Future, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::Poll, -}; - -type LocalBoxedFuture = Pin + Send + 'static>>; - -#[derive(Clone)] -pub struct Task { - pub(crate) future: Arc>, - pub(crate) complete: Arc, -} - -impl Task { - pub(crate) fn new + Send + 'static>(fut: Fut) -> Self { - Self { - future: Arc::new(Mutex::new(Box::pin(fut))), - complete: Arc::new(AtomicBool::new(false)), - } - } - - pub(crate) fn is_completed(&self) -> bool { - self.complete.load(Ordering::Acquire) - } - - fn complete(&self) { - self.complete.store(true, Ordering::Release); - } -} - -impl Future for Task { - type Output = (); - fn poll( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - match self.future.lock().as_mut().poll(cx) { - Poll::Ready(()) => { - self.complete(); - Poll::Ready(()) - } - Poll::Pending => { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} diff --git a/src/async_runtime/task_queue.rs b/src/async_runtime/task_queue.rs deleted file mode 100755 index f49f16d..0000000 --- a/src/async_runtime/task_queue.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::task::Task; -use parking_lot::Mutex; -use std::{collections::VecDeque, iter::Iterator, sync::Arc}; - -#[derive(Clone, Default)] -pub struct TaskQueue { - buffer: Arc>>, -} - -impl TaskQueue { - pub(crate) fn push(&self, task: &Task) { - self.buffer.lock().push_back(task.clone()); - } -} - -impl TaskQueue { - pub(crate) fn drain_all(&self) { - self.buffer.lock().clear(); - } -} - -impl Iterator for TaskQueue { - type Item = Task; - - fn next(&mut self) -> Option { - let Some(task) = self.buffer.lock().pop_front() else { - return None; - }; - if !task.is_completed() { - return Some(task); - } - None - } -} diff --git a/src/async_stream/mod.rs b/src/async_stream/mod.rs index 58c7e52..4d648f4 100755 --- a/src/async_stream/mod.rs +++ b/src/async_stream/mod.rs @@ -1,117 +1,3 @@ -use std::{ - collections::VecDeque, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll}, -}; +mod stream; -use async_mutex::{Mutex, MutexGuard}; -use futures_lite::{Stream, StreamExt}; - -use crate::executors::block_on; - -pub struct AsyncStream { - buffer: Arc>>, - started: bool, - counts: (Arc, Arc), - cancelled: bool, -} - -impl AsyncStream { - pub(crate) async fn insert_item(&mut self, value: ItemType) { - if !self.started { - self.started = true; - } - self.buffer.lock().await.push_back(value); - } -} - -impl AsyncStream { - pub(crate) async fn buffer_count(&self) -> usize { - self.buffer.lock().await.len() - } -} - -impl AsyncStream { - pub(crate) fn increment(&self) { - self.counts.0.fetch_add(1, Ordering::Acquire); - self.counts.1.fetch_add(1, Ordering::Acquire); - } -} - -impl AsyncStream { - pub(crate) async fn first(&mut self) -> Option { - self.next().await - } -} - -impl AsyncStream { - pub(crate) fn task_count(&self) -> usize { - self.counts.1.load(Ordering::Acquire) - } - - pub(crate) fn decrement_task_count(&self) { - if self.task_count() > 0 { - self.counts.1.fetch_sub(1, Ordering::Acquire); - } - } - - pub(crate) fn item_count(&self) -> usize { - self.counts.0.load(Ordering::Acquire) - } - - pub(crate) fn decrement_count(&self) { - if self.item_count() > 0 { - self.counts.0.fetch_sub(1, Ordering::Acquire); - } - } - - pub(crate) fn cancel_tasks(&mut self) { - self.cancelled = true; - self.counts.1.store(0, Ordering::Release); - } -} - -impl Clone for AsyncStream { - fn clone(&self) -> Self { - Self { - buffer: self.buffer.clone(), - started: self.started, - counts: self.counts.clone(), - cancelled: self.cancelled, - } - } -} - -impl AsyncStream { - pub(crate) fn new() -> Self { - AsyncStream:: { - buffer: Arc::new(Mutex::new(VecDeque::new())), - started: false, - counts: (Arc::new(AtomicUsize::new(0)), Arc::new(AtomicUsize::new(0))), - cancelled: false, - } - } -} - -impl Stream for AsyncStream { - type Item = ItemType; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - block_on(async move { - let mut inner_lock: MutexGuard<'_, VecDeque> = self.buffer.lock().await; - if self.cancelled && inner_lock.is_empty() || self.item_count() == 0 { - return Poll::Ready(None); - } - let Some(value) = inner_lock.pop_front() else { - cx.waker().wake_by_ref(); - return Poll::Pending; - }; - self.decrement_count(); - Poll::Ready(Some(value)) - }) - } -} +pub(crate) use stream::AsyncStream; diff --git a/src/async_stream/stream.rs b/src/async_stream/stream.rs new file mode 100644 index 0000000..ac480e9 --- /dev/null +++ b/src/async_stream/stream.rs @@ -0,0 +1,162 @@ +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll, Waker}, +}; + +use async_lock::Mutex; +use futures_lite::Stream; + +pub struct AsyncStream { + inner: Arc>, +} + +impl AsyncStream { + pub(crate) async fn insert_item(&self, value: ItemType) { + let mut inner_lock = self.inner.inner_lock.lock().await; + inner_lock.buffer.push_back(value); + // check if any waker was registered + let Some(waker) = inner_lock.wakers.take() else { + return; + }; + + drop(inner_lock); + + // wakeup the waker + waker.wake(); + } +} + +impl AsyncStream { + pub(crate) async fn buffer_count(&self) -> usize { + self.inner.inner_lock.lock().await.buffer.len() + } +} + +impl AsyncStream { + pub(crate) fn increment(&self) { + self.inner.item_count.fetch_add(1, Ordering::Relaxed); + } +} + +impl AsyncStream { + pub async fn first(&self) -> Option { + let mut inner_lock = self.inner.inner_lock.lock().await; + if inner_lock.buffer.is_empty() || self.item_count() == 0 { + return None; + } + + let value = inner_lock.buffer.pop_front()?; + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Some(value) + } +} + +impl AsyncStream { + pub(crate) fn item_count(&self) -> usize { + self.inner.item_count.load(Ordering::Acquire) + } +} + +impl Clone for AsyncStream { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl AsyncStream { + pub(crate) fn new() -> Self { + AsyncStream { + inner: Arc::new(Inner::new()), + } + } +} + +struct Inner { + inner_lock: Mutex>, + item_count: AtomicUsize, +} + +impl Inner { + fn new() -> Self { + Self { + inner_lock: Mutex::new(InnerState::new()), + item_count: AtomicUsize::new(0), + } + } +} + +enum Stages { + Empty, + Wait, + Ready(T), +} + +struct InnerState { + buffer: VecDeque, + wakers: Option, +} + +impl InnerState { + fn new() -> InnerState { + Self { + buffer: VecDeque::new(), + wakers: None, + } + } +} + +impl AsyncStream { + fn poll_item(&self, cx: &mut Context<'_>) -> Poll> { + if self.item_count() == 0 { + return Poll::Ready(Stages::Empty); + } + let waker = cx.waker().clone(); + let mut future = async move { + let mut inner_lock = self.inner.inner_lock.lock().await; + if self.item_count() == 0 && inner_lock.buffer.is_empty() { + return Stages::Empty; + } + let Some(value) = inner_lock.buffer.pop_front() else { + // register the waker so we can called it later + inner_lock.wakers.replace(waker); + return Stages::Wait; + }; + + self.inner.item_count.fetch_sub(1, Ordering::Relaxed); + Stages::Ready(value) + }; + unsafe { Pin::new_unchecked(&mut future) }.poll(cx) + } +} + +impl Stream for AsyncStream { + type Item = ItemType; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.poll_item(cx) { + Poll::Pending => { + // This means the lock has not been acquired yet + // so immediately wake up this waker + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Ready(stage) => match stage { + Stages::Empty => Poll::Ready(None), + Stages::Wait => Poll::Pending, + Stages::Ready(value) => Poll::Ready(Some(value)), + }, + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.item_count())) + } +} diff --git a/src/discarding_spawn_group.rs b/src/discarding_spawn_group.rs index 4b0a7b6..d3fd393 100755 --- a/src/discarding_spawn_group.rs +++ b/src/discarding_spawn_group.rs @@ -1,6 +1,4 @@ -use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, -}; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use std::future::Future; @@ -19,9 +17,9 @@ use std::future::Future; /// any order. /// pub struct DiscardingSpawnGroup { + runtime: RuntimeEngine<()>, /// A field that indicates if the spawn group has been cancelled pub is_cancelled: bool, - runtime: RuntimeEngine<()>, wait_at_drop: bool, } @@ -34,7 +32,7 @@ impl DiscardingSpawnGroup { impl DiscardingSpawnGroup { /// Instantiates `DiscardingSpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use @@ -42,7 +40,18 @@ impl DiscardingSpawnGroup { Self { is_cancelled: false, runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + wait_at_drop: true, + } + } +} + +impl Default for DiscardingSpawnGroup { + /// Instantiates `DiscardingSpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { + Self { + is_cancelled: false, + runtime: RuntimeEngine::default(), + wait_at_drop: true, } } } @@ -56,9 +65,9 @@ impl DiscardingSpawnGroup { /// * `closure`: an async closure that doesn't return anything pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task(priority, closure); + self.runtime.write_task(priority, closure); } /// Spawn a new task only if the group is not cancelled yet, @@ -70,14 +79,17 @@ impl DiscardingSpawnGroup { /// * `closure`: an async closure that return doesn't return anything pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.runtime.write_task(priority, closure); + } } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; } } @@ -91,54 +103,30 @@ impl DiscardingSpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.runtime.stream().task_count() == 0 { + if self.runtime.task_count() == 0 { return true; } false } } -impl Drop for DiscardingSpawnGroup { - fn drop(&mut self) { - if self.wait_at_drop { - self.runtime.wait_for_all_tasks(); - } else { - self.runtime.end() - } - } -} - -impl Shared for DiscardingSpawnGroup { - type Result = (); - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.runtime.write_task(priority, closure); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } +impl DiscardingSpawnGroup { + /// Waits for all remaining child tasks for finish. + pub async fn wait_for_all(&mut self) { + self.runtime.wait_for_all_tasks(); } - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); } } -impl Initializible for DiscardingSpawnGroup { - fn init() -> Self { - DiscardingSpawnGroup { - is_cancelled: false, - runtime: RuntimeEngine::init(), - wait_at_drop: true, +impl Drop for DiscardingSpawnGroup { + fn drop(&mut self) { + if self.wait_at_drop { + self.runtime.wait_for_all_tasks(); } + self.runtime.end() } } diff --git a/src/err_spawn_group.rs b/src/err_spawn_group.rs index 2c96140..e5359f5 100755 --- a/src/err_spawn_group.rs +++ b/src/err_spawn_group.rs @@ -1,8 +1,4 @@ -use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, - wait::Waitable, -}; -use async_trait::async_trait; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use futures_lite::{Stream, StreamExt}; use std::{ future::Future, @@ -29,38 +25,50 @@ use std::{ /// /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used -pub struct ErrSpawnGroup { +pub struct ErrSpawnGroup { + runtime: RuntimeEngine>, + count: Arc, /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, - count: Arc, - runtime: RuntimeEngine>, wait_at_drop: bool, } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Instantiates `ErrSpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use pub fn new(num_of_threads: usize) -> Self { + Self { + runtime: RuntimeEngine::new(num_of_threads), + count: Arc::new(AtomicUsize::new(0)), + is_cancelled: false, + wait_at_drop: true, + } + } +} + +impl Default for ErrSpawnGroup { + /// Instantiates `ErrSpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { Self { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), - runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + runtime: RuntimeEngine::default(), + wait_at_drop: true, } } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Don't implicity wait for spawned child tasks to finish before being dropped pub fn dont_wait_at_drop(&mut self) { self.wait_at_drop = false; } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Spawns a new task into the spawn group /// /// # Parameters @@ -69,16 +77,17 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> - + Send - + 'static, + F: Future> + Send + 'static, { - self.add_task(priority, closure); + self.increment_count(); + self.runtime.write_task(priority, closure); } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; + self.decrement_count_to_zero(); } /// Spawn a new task only if the group is not cancelled yet, @@ -90,50 +99,56 @@ impl ErrSpawnGroup { /// * `closure`: an async closure that return a value of type ``Result`` pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> - + Send - + 'static, + F: Future> + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.runtime.write_task(priority, closure) + } } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Returns the first element of the stream, or None if it is empty. - pub async fn first(&self) -> Option< as Shared>::Result> { + pub async fn first(&self) -> Option> { self.runtime.stream().first().await } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Returns an instance of the `Stream` trait. pub fn stream(&self) -> impl Stream> { self.runtime.stream() } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Waits for all remaining child tasks for finish. pub async fn wait_for_all(&mut self) { - self.wait().await; + self.wait_non_async() + } + + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); + self.decrement_count_to_zero() } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { fn increment_count(&self) { - self.count.fetch_add(1, Ordering::Acquire); + self.count.fetch_add(1, Ordering::Relaxed); } fn count(&self) -> usize { - self.count.load(Ordering::Acquire) + self.count.load(Ordering::Relaxed) } fn decrement_count_to_zero(&self) { - self.count.store(0, Ordering::Release); + self.count.store(0, Ordering::Relaxed); } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// A Boolean value that indicates whether the group has any remaining tasks. /// /// At the start of the body of a ``with_err_spawn_group`` function call, or before calling ``spawn_task`` or ``spawn_task_unless_cancelled`` methods @@ -143,14 +158,14 @@ impl ErrSpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.count() == 0 || self.runtime.stream().task_count() == 0 { + if self.count() == 0 || self.runtime.task_count() == 0 { return true; } false } } -impl ErrSpawnGroup { +impl ErrSpawnGroup { /// Waits for a specific number of spawned child tasks to finish and returns their respectively result as a vector /// /// # Panics @@ -195,70 +210,19 @@ impl ErrSpawnGroup { } } -impl Drop for ErrSpawnGroup { +impl Drop for ErrSpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); - } else { - self.runtime.end() - } - } -} - -impl Initializible for ErrSpawnGroup { - fn init() -> Self { - ErrSpawnGroup:: { - count: Arc::new(AtomicUsize::new(0)), - is_cancelled: false, - runtime: RuntimeEngine::init(), - wait_at_drop: true, } + self.runtime.end() } } -impl Shared - for ErrSpawnGroup -{ - type Result = Result; - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.increment_count(); - self.runtime.write_task(priority, closure); - } - - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; - self.decrement_count_to_zero(); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } - } -} - -impl Stream for ErrSpawnGroup { +impl Stream for ErrSpawnGroup { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.runtime.stream().poll_next(cx) } } - -#[async_trait] -impl Waitable - for ErrSpawnGroup -{ - async fn wait(&self) { - self.runtime.wait_for_all_tasks(); - self.decrement_count_to_zero(); - } -} diff --git a/src/executors/future_executor.rs b/src/executors/future_executor.rs new file mode 100644 index 0000000..8c3f059 --- /dev/null +++ b/src/executors/future_executor.rs @@ -0,0 +1,41 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; + +use crate::shared::{pair, Suspender}; + +/// Blocks the current thread until the future is polled to finish. +/// +/// Example +/// ```rust +/// let result = spawn_groups::block_on(async { +/// println!("This is an async executor"); +/// 1 +/// }); +/// assert_eq!(result, 1); +/// ``` +/// +#[inline] +pub fn block_on(future: Fut) -> Fut::Output { + thread_local! { + static PAIR: (Arc, Waker) = { + pair() + }; + } + + PAIR.with(move |waker_pair| { + let mut future = future; + let mut future = unsafe { Pin::new_unchecked(&mut future) }; + let (suspender, waker) = &*waker_pair; + let mut context: Context<'_> = Context::from_waker(waker); + loop { + match future.as_mut().poll(&mut context) { + Poll::Pending => suspender.suspend(), + Poll::Ready(output) => return output, + } + } + }) +} diff --git a/src/executors/local_executor.rs b/src/executors/local_executor.rs deleted file mode 100755 index 6b7d67c..0000000 --- a/src/executors/local_executor.rs +++ /dev/null @@ -1,34 +0,0 @@ -use std::{ - future::Future, - sync::Arc, - task::{Context, Waker}, -}; - -use cooked_waker::IntoWaker; - -use crate::{async_runtime::notifier::Notifier, pin_future}; - -thread_local! { - pub(crate) static WAKER_PAIR: (Arc, Waker) = { - let notifier = Arc::new(Notifier::default()); - let waker = notifier.clone().into_waker(); - (notifier, waker) - }; -} - -pub(crate) fn block_future( - future: Fut, - notifier: Arc, - waker: &Waker, -) -> Fut::Output { - let mut context: Context<'_> = Context::from_waker(waker); - pin_future!(future); - loop { - match future.as_mut().poll(&mut context) { - std::task::Poll::Ready(output) => return output, - std::task::Poll::Pending => { - notifier.wait() - } - } - } -} diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 9de0074..c1a8848 100755 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -1,49 +1,3 @@ -use std::{future::Future, sync::Arc, task::Waker}; +mod future_executor; -use cooked_waker::IntoWaker; - -use crate::async_runtime::{notifier::Notifier, task::Task}; - -use self::{local_executor::block_future, task_executor::block_on_task}; - -mod local_executor; -mod task_executor; - -/// Blocks the current thread until the future is polled to finish. -/// -/// Example -/// ```rust -/// let result = spawn_groups::block_on(async { -/// println!("This is an async executor"); -/// 1 -/// }); -/// assert_eq!(result, 1); -/// ``` -/// -pub fn block_on(future: Fut) -> Fut::Output { - let waker_pair: Result<(Arc, Waker), std::thread::AccessError> = - local_executor::WAKER_PAIR - .try_with(|waker_pair: &(Arc, Waker)| waker_pair.clone()); - match waker_pair { - Ok((notifier, waker)) => block_future(future, notifier, &waker), - Err(_) => { - let notifier: Arc = Arc::new(Notifier::default()); - let waker: Waker = notifier.clone().into_waker(); - block_future(future, notifier, &waker) - } - } -} - -pub(crate) fn block_task(task: Task) { - let waker_pair: Result<(Arc, Waker), std::thread::AccessError> = - local_executor::WAKER_PAIR - .try_with(|waker_pair: &(Arc, Waker)| waker_pair.clone()); - match waker_pair { - Ok((notifier, waker)) => block_on_task(task, notifier, &waker), - Err(_) => { - let notifier: Arc = Arc::new(Notifier::default()); - let waker: Waker = notifier.clone().into_waker(); - block_on_task(task, notifier, &waker) - } - } -} +pub use future_executor::block_on; diff --git a/src/executors/task_executor.rs b/src/executors/task_executor.rs deleted file mode 100755 index 3835499..0000000 --- a/src/executors/task_executor.rs +++ /dev/null @@ -1,28 +0,0 @@ -use std::{ - sync::Arc, - task::{Context, Waker}, -}; - -use crate::async_runtime::{notifier::Notifier, task::Task}; -use cooked_waker::IntoWaker; - -thread_local! { - pub(crate) static WAKER_PAIR: (Arc, Waker) = { - let notifier = Arc::new(Notifier::default()); - let waker = notifier.clone().into_waker(); - (notifier, waker) - }; -} - -pub(crate) fn block_on_task(task: Task, notifier: Arc, waker: &Waker) { - if task.is_completed() { - return; - } - let mut context: Context<'_> = Context::from_waker(waker); - loop { - match task.future.lock().as_mut().poll(&mut context) { - std::task::Poll::Ready(()) => return, - std::task::Poll::Pending => notifier.wait(), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index a7de5e0..37b696a 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,9 +91,6 @@ //! See [`with_discarding_spawn_group`](self::with_discarding_spawn_group) //! for more information //! -//! * ``sleep`` similar to ``std::thread::sleep`` but for sleeping in asynchronous environments. See [`sleep`](self::sleep) -//! for more information -//! //! * ``block_on`` polls future to finish. See [`block_on`](self::block_on) //! for more information //! @@ -157,7 +154,7 @@ //! //! # Note //! * Import ``StreamExt`` trait from ``futures_lite::StreamExt`` or ``futures::stream::StreamExt`` or ``async_std::stream::StreamExt`` to provide a variety of convenient combinator functions on the various spawn groups. -//! * To await all running child tasks to finish their execution, call ``wait_for_all`` method on the spawn group instance unless using the [`with_discarding_spawn_group`](self::with_discarding_spawn_group) function. +//! * To await all running child tasks to finish their execution, call ``wait_for_all`` or ``wait_non_async`` methods on the various group instances //! //! # Warning //! * This crate relies on atomics @@ -169,24 +166,18 @@ mod discarding_spawn_group; mod err_spawn_group; mod spawn_group; -mod async_runtime; mod async_stream; mod executors; mod meta_types; mod shared; -mod sleeper; mod threadpool_impl; -mod yield_now; pub use discarding_spawn_group::DiscardingSpawnGroup; pub use err_spawn_group::ErrSpawnGroup; pub use executors::block_on; pub use meta_types::GetType; -use shared::initializible::Initializible; pub use shared::priority::Priority; -pub use sleeper::sleep; pub use spawn_group::SpawnGroup; -pub use yield_now::yield_now; use std::future::Future; use std::marker::PhantomData; @@ -196,7 +187,7 @@ use std::marker::PhantomData; /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``SpawnGroup`` struct. /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`SpawnGroup`](spawn_group::SpawnGroup) /// for more. @@ -240,12 +231,12 @@ pub async fn with_type_spawn_group( body: Closure, ) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - ResultType: Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut, + Fut: Future + 'static, + ResultType: 'static, { _ = of_type; - let task_group = spawn_group::SpawnGroup::::init(); + let task_group = spawn_group::SpawnGroup::::default(); body(task_group).await } @@ -254,7 +245,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``SpawnGroup`` struct. /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`SpawnGroup`](spawn_group::SpawnGroup) /// for more. @@ -294,11 +285,11 @@ where /// ``` pub async fn with_spawn_group(body: Closure) -> ReturnType where - Closure: FnOnce(spawn_group::SpawnGroup) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - ResultType: Send + 'static, + Closure: FnOnce(spawn_group::SpawnGroup) -> Fut, + Fut: Future + 'static, + ResultType: 'static, { - let task_group = spawn_group::SpawnGroup::::init(); + let task_group = spawn_group::SpawnGroup::::default(); body(task_group).await } @@ -308,7 +299,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``ErrSpawnGroup`` struct /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`ErrSpawnGroup`](err_spawn_group::ErrSpawnGroup) /// for more. @@ -397,13 +388,13 @@ pub async fn with_err_type_spawn_group ReturnType where - ErrorType: Send + 'static, + ErrorType: 'static, Fut: Future, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, - ResultType: Send + 'static, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut, + ResultType: 'static, { _ = (of_type, error_type); - let task_group = err_spawn_group::ErrSpawnGroup::::init(); + let task_group = err_spawn_group::ErrSpawnGroup::::default(); body(task_group).await } @@ -413,7 +404,7 @@ where /// This closure ensures that before the function call ends, all spawned child tasks are implicitly waited for, or the programmer can explicitly wait by calling its ``wait_for_all()`` method /// of the ``ErrSpawnGroup`` struct /// -/// This function use a threadpoolof the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpoolof the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`ErrSpawnGroup`](err_spawn_group::ErrSpawnGroup) /// for more. @@ -498,12 +489,12 @@ pub async fn with_err_spawn_group ReturnType where - ErrorType: Send + 'static, + ErrorType: 'static, Fut: Future, - Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut + Send + 'static, - ResultType: Send + 'static, + Closure: FnOnce(err_spawn_group::ErrSpawnGroup) -> Fut, + ResultType: 'static, { - let task_group = err_spawn_group::ErrSpawnGroup::::init(); + let task_group = err_spawn_group::ErrSpawnGroup::::default(); body(task_group).await } @@ -511,7 +502,7 @@ where /// /// Ensures that before the function call ends, all spawned tasks are implicitly waited for /// -/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the futures +/// This function use a threadpool of the same number of threads as the number of active processor count that is default amount of parallelism a program can use on the system for polling the spawned tasks /// /// See [`DiscardingSpawnGroup`](discarding_spawn_group::DiscardingSpawnGroup) /// for more. @@ -547,8 +538,8 @@ where pub async fn with_discarding_spawn_group(body: Closure) -> ReturnType where Fut: Future, - Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut + Send + 'static, + Closure: FnOnce(discarding_spawn_group::DiscardingSpawnGroup) -> Fut, { - let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::init(); + let discarding_tg = discarding_spawn_group::DiscardingSpawnGroup::default(); body(discarding_tg).await } diff --git a/src/meta_types.rs b/src/meta_types.rs index be56ba7..85f3631 100755 --- a/src/meta_types.rs +++ b/src/meta_types.rs @@ -2,23 +2,22 @@ /// /// `GetType` provides a metatype that's a type of a type, /// it also enables a developer to pass a type as a value to specify a generic type of a parameter -/// +/// /// # Examples /// ``` /// use spawn_groups::GetType; /// use std::marker::PhantomData; -/// +/// /// fn closure_taker(with_value: T, returning_type: PhantomData, closure: FUNC) -> U /// where FUNC: Fn(T) -> U { /// closure(with_value) /// } /// /// let string_result = closure_taker(32, String::TYPE, |val| format!("{}", val) ); -/// +/// /// assert_eq!(string_result, String::from("32")); /// ``` -/// - +/// use std::marker::PhantomData; /// `GetType` trait implements asssociated constant for every type and this associated constant provides a metatype value that's a type's type value diff --git a/src/shared/initializible.rs b/src/shared/initializible.rs deleted file mode 100755 index 4dbc663..0000000 --- a/src/shared/initializible.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub trait Initializible { - fn init() -> Self; -} diff --git a/src/shared/mod.rs b/src/shared/mod.rs index 41384ef..32c911c 100755 --- a/src/shared/mod.rs +++ b/src/shared/mod.rs @@ -1,5 +1,14 @@ -pub(crate) mod initializible; +pub(crate) mod mutex; pub(crate) mod priority; +pub(crate) mod priority_task; pub(crate) mod runtime; -pub(crate) mod sharedfuncs; -pub(crate) mod wait; +mod suspender; +mod task; +mod waker; +mod waker_pair; +mod task_enum; + +pub(crate) use suspender::{pair, Suspender}; +pub(crate) use waker_pair::WAKER_PAIR; +pub(crate) use task_enum::TaskOrBarrier; +pub(crate) use task::Task; \ No newline at end of file diff --git a/src/shared/mutex.rs b/src/shared/mutex.rs new file mode 100644 index 0000000..44a9c94 --- /dev/null +++ b/src/shared/mutex.rs @@ -0,0 +1,15 @@ +use std::sync::{Mutex, MutexGuard}; + +#[derive(Default)] +pub(crate) struct StdMutex(Mutex); +pub(crate) type StdMutexGuard<'a, T> = MutexGuard<'a, T>; + +impl StdMutex { + pub(crate) fn new(t: T) -> Self { + Self(Mutex::new(t)) + } + + pub(crate) fn lock(&self) -> StdMutexGuard { + self.0.lock().unwrap_or_else(|e| e.into_inner()) + } +} diff --git a/src/shared/priority.rs b/src/shared/priority.rs index 3c66a34..e5cd08e 100755 --- a/src/shared/priority.rs +++ b/src/shared/priority.rs @@ -1,7 +1,8 @@ /// Task Priority /// -/// Spawn groups uses it to rank the importance of their spawned tasks and order of returned values only when waited for. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Default)] +/// Spawn groups uses it to rank the importance of their spawned tasks and order of returned values only when waited for +/// that is when the ``wait_for_all`` or ``wait_non_async`` method is called +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Default)] pub enum Priority { BACKGROUND = 0, LOW, diff --git a/src/shared/priority_task.rs b/src/shared/priority_task.rs new file mode 100644 index 0000000..2faf46d --- /dev/null +++ b/src/shared/priority_task.rs @@ -0,0 +1,50 @@ +use std::{ + cmp::Ordering, + future::Future, + sync::{Arc, Barrier}, +}; + +use crate::threadpool_impl::TaskPriority; + +use super::{task::Task, task_enum::TaskOrBarrier}; + +pub(crate) struct PrioritizedTask { + pub(crate) task: TaskOrBarrier, + priority: TaskPriority, +} + +impl PartialEq for PrioritizedTask { + fn eq(&self, other: &Self) -> bool { + self.priority == other.priority + } +} + +impl Eq for PrioritizedTask {} + +impl PrioritizedTask { + pub(crate) fn new + 'static>(priority: TaskPriority, future: F) -> Self { + Self { + task: TaskOrBarrier::Task(Task::new(future)), + priority, + } + } + + pub(crate) fn new_with(priority: TaskPriority, barrier: Arc) -> Self { + Self { + task: TaskOrBarrier::Barrier(barrier), + priority, + } + } +} + +impl Ord for PrioritizedTask { + fn cmp(&self, other: &Self) -> Ordering { + other.priority.cmp(&self.priority) + } +} + +impl PartialOrd for PrioritizedTask { + fn partial_cmp(&self, other: &Self) -> Option { + Some(other.cmp(self)) + } +} diff --git a/src/shared/runtime.rs b/src/shared/runtime.rs index 37598f6..3a14de5 100755 --- a/src/shared/runtime.rs +++ b/src/shared/runtime.rs @@ -1,56 +1,45 @@ -use crate::{ - async_runtime::{executor::Executor, task::Task}, - async_stream::AsyncStream, - executors::block_task, - shared::{initializible::Initializible, priority::Priority}, -}; -use parking_lot::Mutex; +use crate::{async_stream::AsyncStream, shared::priority::Priority, threadpool_impl::ThreadPool}; use std::{ future::Future, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, }; -type TaskQueue = Arc>>; +use super::priority_task::PrioritizedTask; -pub struct RuntimeEngine { - tasks: TaskQueue, - runtime: Executor, +pub(crate) struct RuntimeEngine { stream: AsyncStream, - wait_flag: Arc, + pool: ThreadPool, + task_count: Arc, } -impl Initializible for RuntimeEngine { - fn init() -> Self { +impl RuntimeEngine { + pub(crate) fn new(count: usize) -> Self { Self { - tasks: Arc::new(Mutex::new(vec![])), + pool: ThreadPool::new(count), stream: AsyncStream::new(), - runtime: Executor::default(), - wait_flag: Arc::new(AtomicBool::new(false)), + task_count: Arc::new(AtomicUsize::default()), } } } -impl RuntimeEngine { - pub(crate) fn new(count: usize) -> Self { +impl Default for RuntimeEngine { + fn default() -> Self { Self { - tasks: Arc::new(Mutex::new(vec![])), + pool: ThreadPool::default(), stream: AsyncStream::new(), - runtime: Executor::new(count), - wait_flag: Arc::new(AtomicBool::new(false)), + task_count: Arc::new(AtomicUsize::default()), } } } impl RuntimeEngine { - pub(crate) fn cancel(&mut self) { - self.store(true); - self.runtime.cancel(); - self.tasks.lock().clear(); - self.stream.cancel_tasks(); - self.poll(); + pub(crate) fn cancel(&self) { + self.pool.clear(); + self.pool.wait_for_all(); + self.task_count.store(0, Ordering::Relaxed); } } @@ -59,64 +48,44 @@ impl RuntimeEngine { self.stream.clone() } - pub(crate) fn end(&mut self) { - self.runtime.cancel(); - self.tasks.lock().clear(); + pub(crate) fn end(&self) { + self.pool.clear(); + self.pool.wait_for_all(); + self.task_count.store(0, Ordering::Relaxed); + self.pool.end() } } -impl RuntimeEngine { +impl RuntimeEngine { pub(crate) fn wait_for_all_tasks(&self) { self.poll(); - self.runtime.cancel(); - self.tasks.lock().sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0)); - self.store(true); - while let Some((_, handle)) = self.tasks.lock().pop() { - self.runtime.submit(move || { - block_task(handle); - }); - } - self.poll(); + self.task_count.store(0, Ordering::Relaxed); } } impl RuntimeEngine { - pub(crate) fn load(&self) -> bool { - self.wait_flag.load(Ordering::Acquire) - } - - pub(crate) fn store(&self, val: bool) { - self.wait_flag.store(val, Ordering::Release); - } -} - -impl RuntimeEngine { - pub(crate) fn write_task(&self, priority: Priority, task: F) + pub(crate) fn write_task(&mut self, priority: Priority, task: F) where - F: Future + Send + 'static, + F: Future + 'static, { - if self.load() { - self.runtime.start(); - self.store(false); - } - self.stream.increment(); - let mut stream: AsyncStream = self.stream(); - let runtime = self.runtime.clone(); - let tasks: Arc>> = self.tasks.clone(); - self.runtime.submit(move || { - tasks.lock().push(( - priority, - runtime.spawn(async move { - stream.insert_item(task.await).await; - stream.decrement_task_count(); - }), - )); - }); + let (stream, task_counter) = (self.stream(), self.task_count.clone()); + stream.increment(); + task_counter.fetch_add(1, Ordering::Relaxed); + self.pool + .submit(PrioritizedTask::new(priority.into(), async move { + let task_result = task.await; + stream.insert_item(task_result).await; + task_counter.fetch_sub(1, Ordering::Relaxed); + })); } } impl RuntimeEngine { - pub(crate) fn poll(&self) { - self.runtime.poll_all(); + fn poll(&self) { + self.pool.wait_for_all(); + } + + pub(crate) fn task_count(&self) -> usize { + self.task_count.load(Ordering::Acquire) } } diff --git a/src/shared/sharedfuncs.rs b/src/shared/sharedfuncs.rs deleted file mode 100755 index 5a96d40..0000000 --- a/src/shared/sharedfuncs.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::shared::priority::Priority; -use std::future::Future; - - -/// The basic functionalities between all kinds of spawn groups -pub trait Shared { - /// A value return when a task is being awaited for - type Result; - /// Add a new task into the engine - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static; - /// Cancels all running tasks in the engine - fn cancel_all_tasks(&mut self); - /// Add a new task only if the engine is not cancelled yet, - /// otherwise does nothing - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static; -} diff --git a/src/shared/suspender.rs b/src/shared/suspender.rs new file mode 100755 index 0000000..fe874f3 --- /dev/null +++ b/src/shared/suspender.rs @@ -0,0 +1,96 @@ +use std::{ + sync::{Arc, Condvar}, + task::Waker, +}; + +use crate::shared::mutex::StdMutex; + +use super::waker::waker_helper; + +pub(crate) fn pair() -> (Arc, Waker) { + let suspender = Arc::new(Suspender::new()); + let resumer = suspender.clone(); + (suspender, waker_helper(resumer)) +} + +pub(crate) struct Suspender { + inner: Inner, +} + +impl Suspender { + pub(crate) fn new() -> Suspender { + Suspender { + inner: Inner { + lock: StdMutex::new(State::Initial), + cvar: Condvar::new(), + }, + } + } + + pub(crate) fn suspend(&self) { + self.inner.suspend(); + } + + pub(crate) fn resume(&self) { + self.inner.resume(); + } +} + +#[derive(PartialEq)] +enum State { + Initial, + Notified, + Suspended, +} + +struct Inner { + lock: StdMutex, + cvar: Condvar, +} + +impl Inner { + fn suspend(&self) { + // Acquire the lock first + let mut lock = self.lock.lock(); + + // check the state the lock is in right now + match *lock { + // suspend the thread + State::Initial => { + *lock = State::Suspended; + // suspend this thread until we get a notification + while *lock == State::Suspended { + lock = self.cvar.wait(lock).unwrap(); + } + } + // already notified this thread so just revert state back to empty + // then return + State::Notified => { + *lock = State::Initial; + } + State::Suspended => { + panic!("cannot suspend a thread that is already in a suspended state") + } + } + } + + fn resume(&self) { + // Acquire the lock first + let mut lock = self.lock.lock(); + + // check if the state is empty or suspended + match *lock { + State::Initial => { + // send notification + *lock = State::Notified; + } + State::Suspended => { + // send notification + *lock = State::Notified; + // resume the suspended thread + self.cvar.notify_one(); + } + _ => {} + } + } +} diff --git a/src/shared/task.rs b/src/shared/task.rs new file mode 100755 index 0000000..15eeffb --- /dev/null +++ b/src/shared/task.rs @@ -0,0 +1,30 @@ +use std::{ + future::Future, + panic::{RefUnwindSafe, UnwindSafe}, + pin::Pin, + task::{Context, Poll}, +}; + +type LocalFuture = dyn Future; + +pub(crate) struct Task { + future: Pin>>, +} + +impl Task { + pub(crate) fn new + 'static>(future: Fut) -> Self { + Self { + future: unsafe { Pin::new_unchecked(Box::new(future)) }, + } + } + + #[inline] + pub(crate) fn poll_task(&mut self, cx: &mut Context<'_>) -> Poll { + self.future.as_mut().poll(cx) + } +} + +impl UnwindSafe for Task {} +impl RefUnwindSafe for Task {} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} diff --git a/src/shared/task_enum.rs b/src/shared/task_enum.rs new file mode 100644 index 0000000..aa1f580 --- /dev/null +++ b/src/shared/task_enum.rs @@ -0,0 +1,9 @@ +use std::sync::{Arc, Barrier}; + +use super::task::Task; + +// Naming is hard guys +pub(crate) enum TaskOrBarrier { + Task(Task), + Barrier(Arc), +} diff --git a/src/shared/wait.rs b/src/shared/wait.rs deleted file mode 100755 index cfa9613..0000000 --- a/src/shared/wait.rs +++ /dev/null @@ -1,6 +0,0 @@ -use async_trait::async_trait; - -#[async_trait] -pub trait Waitable { - async fn wait(&self); -} diff --git a/src/shared/waker.rs b/src/shared/waker.rs new file mode 100644 index 0000000..d9cb381 --- /dev/null +++ b/src/shared/waker.rs @@ -0,0 +1,56 @@ +use std::{ + mem::ManuallyDrop, + sync::Arc, + task::{RawWaker, RawWakerVTable, Waker}, +}; + +use super::suspender::Suspender; + +pub(crate) fn waker_helper(suspender: Arc) -> Waker { + let raw: *const () = Arc::into_raw(suspender) as *const (); + unsafe { + Waker::from_raw(RawWaker::new( + raw, + &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), + )) + } +} + +// clones the waker +pub(crate) unsafe fn clone_waker(ptr: *const ()) -> RawWaker { + let ptr = ptr as *mut Suspender; + let waker = ManuallyDrop::new(Arc::from_raw(ptr)); + let cloned = (*waker).clone(); + + debug_assert_eq!(Arc::into_raw(ManuallyDrop::into_inner(waker)), ptr); + + let cloned_ptr = Arc::into_raw(cloned) as *const (); + + RawWaker::new( + cloned_ptr, + &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), + ) +} + +// wakes up by consuming it +pub(crate) unsafe fn wake(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + let waker = Arc::from_raw(ptr); + waker.resume(); +} + +// wakes up by reference +pub(crate) unsafe fn wake_by_ref(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + + let waker = ManuallyDrop::new(Arc::from_raw(ptr)); + waker.resume(); + + debug_assert_eq!(Arc::into_raw(ManuallyDrop::into_inner(waker)), ptr); +} + +// drops the waker +pub(crate) unsafe fn drop_waker(ptr: *const ()) { + let ptr = ptr as *mut Suspender; + drop(Arc::from_raw(ptr)); +} diff --git a/src/shared/waker_pair.rs b/src/shared/waker_pair.rs new file mode 100644 index 0000000..e8af5ac --- /dev/null +++ b/src/shared/waker_pair.rs @@ -0,0 +1,8 @@ +use super::{pair, Suspender}; +use std::{sync::Arc, task::Waker}; + +thread_local! { + pub(crate) static WAKER_PAIR: (Arc, Waker) = { + pair() + }; +} diff --git a/src/sleeper/delay.rs b/src/sleeper/delay.rs deleted file mode 100755 index 53a651d..0000000 --- a/src/sleeper/delay.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - time::{Duration, Instant}, -}; - -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Delay { - duration: Duration, - now: Instant, -} - -impl Delay { - pub(crate) fn new(duration: Duration) -> Self { - Delay { - duration, - now: Instant::now(), - } - } -} - -impl Future for Delay { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.now.elapsed() >= self.duration { - true => Poll::Ready(()), - false => { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -} diff --git a/src/sleeper/mod.rs b/src/sleeper/mod.rs deleted file mode 100755 index ff019a8..0000000 --- a/src/sleeper/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod delay; - -use std::time::Duration; - -use self::delay::Delay; - -/// Sleeps for the specified amount of time. -/// -/// This function might sleep for slightly longer than the specified duration but never less. -/// -/// This function is an async version of ``std::thread::sleep``. -/// -/// Example -/// -/// ```rust -/// use spawn_groups::{block_on, sleep}; -/// use std::time::Duration; -/// -/// block_on(async { -/// sleep(Duration::from_secs(2)).await; -/// }); -/// ``` -pub fn sleep(duration: Duration) -> Delay { - Delay::new(duration) -} diff --git a/src/spawn_group.rs b/src/spawn_group.rs index 4c42947..682db00 100755 --- a/src/spawn_group.rs +++ b/src/spawn_group.rs @@ -1,8 +1,4 @@ -use crate::shared::{ - initializible::Initializible, priority::Priority, runtime::RuntimeEngine, sharedfuncs::Shared, - wait::Waitable, -}; -use async_trait::async_trait; +use crate::shared::{priority::Priority, runtime::RuntimeEngine}; use futures_lite::{Stream, StreamExt}; use std::{ future::Future, @@ -29,39 +25,50 @@ use std::{ /// /// It dereferences into a ``futures`` crate ``Stream`` type where the results of each finished child task is stored and it pops out the result in First-In First-Out /// FIFO order whenever it is being used - -pub struct SpawnGroup { +pub struct SpawnGroup { + runtime: RuntimeEngine, + count: Arc, /// A field that indicates if the spawn group had been cancelled pub is_cancelled: bool, wait_at_drop: bool, - count: Arc, - runtime: RuntimeEngine, } -impl SpawnGroup { +impl SpawnGroup { /// Instantiates `SpawnGroup` with a specific number of threads to use in the underlying threadpool when polling futures - /// + /// /// # Parameters /// /// * `num_of_threads`: number of threads to use pub fn new(num_of_threads: usize) -> Self { + Self { + runtime: RuntimeEngine::new(num_of_threads), + count: Arc::new(AtomicUsize::new(0)), + is_cancelled: false, + wait_at_drop: true, + } + } +} + +impl Default for SpawnGroup { + /// Instantiates `SpawnGroup` with the number of threads as the number of cores as the system to use in the underlying threadpool when polling futures + fn default() -> Self { Self { is_cancelled: false, count: Arc::new(AtomicUsize::new(0)), - runtime: RuntimeEngine::new(num_of_threads), - wait_at_drop: false, + runtime: RuntimeEngine::default(), + wait_at_drop: true, } } } -impl SpawnGroup { +impl SpawnGroup { /// Don't implicity wait for spawned child tasks to finish before being dropped pub fn dont_wait_at_drop(&mut self) { self.wait_at_drop = false; } } -impl SpawnGroup { +impl SpawnGroup { /// Spawns a new task into the spawn group /// # Parameters /// @@ -69,9 +76,10 @@ impl SpawnGroup { /// * `closure`: an async closure that return a value of type ``ValueType`` pub fn spawn_task(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task(priority, closure); + self.increment_count(); + self.runtime.write_task(priority, closure); } /// Spawn a new task only if the group is not cancelled yet, @@ -83,46 +91,56 @@ impl SpawnGroup { /// * `closure`: an async closure that return a value of type ``ValueType`` pub fn spawn_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) where - F: Future as Shared>::Result> + Send + 'static, + F: Future + Send + 'static, { - self.add_task_unlessed_cancelled(priority, closure); + if !self.is_cancelled { + self.runtime.write_task(priority, closure) + } } /// Cancels all running task in the spawn group pub fn cancel_all(&mut self) { - self.cancel_all_tasks(); + self.runtime.cancel(); + self.is_cancelled = true; + self.decrement_count_to_zero(); } } -impl SpawnGroup { +impl SpawnGroup { /// Returns the first element of the stream, or None if it is empty. pub async fn first(&self) -> Option { self.runtime.stream().first().await } } -impl SpawnGroup { +impl SpawnGroup { /// Waits for all remaining child tasks for finish. - pub async fn wait_for_all(&self) { - self.wait().await; + pub async fn wait_for_all(&mut self) { + self.wait_non_async() + } + + /// Waits for all remaining child tasks for finish in non async context. + pub fn wait_non_async(&mut self) { + self.runtime.wait_for_all_tasks(); + self.decrement_count_to_zero() } } -impl SpawnGroup { +impl SpawnGroup { fn increment_count(&self) { - self.count.fetch_add(1, Ordering::Acquire); + self.count.fetch_add(1, Ordering::Relaxed); } fn count(&self) -> usize { - self.count.load(Ordering::Acquire) + self.count.load(Ordering::Relaxed) } fn decrement_count_to_zero(&self) { - self.count.store(0, Ordering::Release); + self.count.store(0, Ordering::Relaxed); } } -impl SpawnGroup { +impl SpawnGroup { /// A Boolean value that indicates whether the group has any remaining tasks. /// /// At the start of the body of a ``with_spawn_group()`` call, , or before calling ``spawn_task`` or ``spawn_task_unless_cancelled`` methods @@ -132,21 +150,21 @@ impl SpawnGroup { /// - true: if there's no child task still running /// - false: if any child task is still running pub fn is_empty(&self) -> bool { - if self.count() == 0 || self.runtime.stream().task_count() == 0 { + if self.count() == 0 || self.runtime.task_count() == 0 { return true; } false } } -impl SpawnGroup { +impl SpawnGroup { /// Returns an instance of the `Stream` trait. pub fn stream(&self) -> impl Stream { self.runtime.stream() } } -impl SpawnGroup { +impl SpawnGroup { /// Waits for a specific number of spawned child tasks to finish and returns their respectively result as a vector /// /// # Panics @@ -191,66 +209,19 @@ impl SpawnGroup { } } -impl Drop for SpawnGroup { +impl Drop for SpawnGroup { fn drop(&mut self) { if self.wait_at_drop { self.runtime.wait_for_all_tasks(); - } else { - self.runtime.end() } + self.runtime.end() } } -impl Initializible for SpawnGroup { - fn init() -> Self { - SpawnGroup { - runtime: RuntimeEngine::init(), - is_cancelled: false, - count: Arc::new(AtomicUsize::new(0)), - wait_at_drop: true, - } - } -} - -impl Shared for SpawnGroup { - type Result = ValueType; - - fn add_task(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - self.increment_count(); - self.runtime.write_task(priority, closure); - } - - fn cancel_all_tasks(&mut self) { - self.runtime.cancel(); - self.is_cancelled = true; - self.decrement_count_to_zero(); - } - - fn add_task_unlessed_cancelled(&mut self, priority: Priority, closure: F) - where - F: Future + Send + 'static, - { - if !self.is_cancelled { - self.add_task(priority, closure) - } - } -} - -impl Stream for SpawnGroup { +impl Stream for SpawnGroup { type Item = ValueType; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.runtime.stream().poll_next(cx) } } - -#[async_trait] -impl Waitable for SpawnGroup { - async fn wait(&self) { - self.runtime.wait_for_all_tasks(); - self.decrement_count_to_zero(); - } -} diff --git a/src/threadpool_impl/channel.rs b/src/threadpool_impl/channel.rs new file mode 100644 index 0000000..f597d57 --- /dev/null +++ b/src/threadpool_impl/channel.rs @@ -0,0 +1,87 @@ +use std::{ + collections::BinaryHeap, + sync::{ + atomic::{AtomicBool, Ordering}, + Condvar, + }, +}; + +use crate::shared::mutex::StdMutex; + +pub(crate) struct Channel { + inner: Inner, +} + +impl Channel { + pub(crate) fn enqueue(&self, value: T) { + self.inner.enqueue(value) + } +} + +impl Channel { + pub(crate) fn new() -> Self { + Self { + inner: Inner::new(), + } + } +} + +impl Channel { + pub(crate) fn dequeue(&self) -> Option { + self.inner.dequeue() + } +} + +impl Channel { + pub(crate) fn clear(&self) { + self.inner.clear() + } + + pub(crate) fn end(&self) { + self.inner.end() + } +} + +struct Inner { + mtx: StdMutex>, + cvar: Condvar, + closed: AtomicBool, +} + +impl Inner { + fn new() -> Self { + Self { + mtx: StdMutex::new(BinaryHeap::new()), + cvar: Condvar::new(), + closed: AtomicBool::new(false), + } + } + + fn enqueue(&self, value: T) { + let mut lock = self.mtx.lock(); + lock.push(value); + self.cvar.notify_one(); + } + + fn dequeue(&self) -> Option { + let mut lock = self.mtx.lock(); + while lock.is_empty() { + if self.closed.load(Ordering::Relaxed) { + return None; + } + lock = self.cvar.wait(lock).unwrap(); + } + lock.pop() + } + + fn clear(&self) { + self.mtx.lock().clear(); + } + + fn end(&self) { + let mut lock = self.mtx.lock(); + self.closed.store(true, Ordering::Relaxed); + lock.clear(); + self.cvar.notify_all(); + } +} diff --git a/src/threadpool_impl/iteratorimpl.rs b/src/threadpool_impl/iteratorimpl.rs deleted file mode 100644 index b2ee6a4..0000000 --- a/src/threadpool_impl/iteratorimpl.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::{QueueOperation, ThreadSafeQueue, Func}; - -impl Iterator for ThreadSafeQueue> { - type Item = QueueOperation; - - fn next(&mut self) -> Option { - let Some(value) = self.dequeue() else { - return Some(QueueOperation::NotYet); - }; - Some(value) - } -} diff --git a/src/threadpool_impl/mod.rs b/src/threadpool_impl/mod.rs index de45d18..ae06898 100644 --- a/src/threadpool_impl/mod.rs +++ b/src/threadpool_impl/mod.rs @@ -1,11 +1,7 @@ -mod iteratorimpl; -mod queue; -mod queueops; -mod threadpool; +mod channel; +mod task_priority; mod thread; +mod threadpool; -pub(crate) type Func = dyn FnOnce() + Send; - -pub(crate) use queue::ThreadSafeQueue; -pub(crate) use queueops::QueueOperation; +pub(crate) use task_priority::TaskPriority; pub(crate) use threadpool::ThreadPool; diff --git a/src/threadpool_impl/queue.rs b/src/threadpool_impl/queue.rs deleted file mode 100644 index 28dd70a..0000000 --- a/src/threadpool_impl/queue.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, Mutex}, -}; - -#[derive(Default)] -pub(crate) struct ThreadSafeQueue { - buffer: Arc>>, -} - -impl ThreadSafeQueue { - pub fn enqueue(&self, value: ItemType) { - if let Ok(mut lock) = self.buffer.lock() { - lock.push_back(value); - } - } -} - -impl ThreadSafeQueue { - pub fn new() -> Self { - Self { - buffer: Arc::new(Mutex::new(VecDeque::new())), - } - } -} - -impl Clone for ThreadSafeQueue { - fn clone(&self) -> Self { - Self { - buffer: self.buffer.clone(), - } - } -} - -impl ThreadSafeQueue { - pub fn dequeue(&self) -> Option { - let Ok(mut buffer_lock) = self.buffer.lock() else { - return None; - }; - buffer_lock.pop_front() - } -} diff --git a/src/threadpool_impl/queueops.rs b/src/threadpool_impl/queueops.rs deleted file mode 100644 index d2a4ecd..0000000 --- a/src/threadpool_impl/queueops.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) enum QueueOperation { - Ready(Box), - NotYet, - Wait, -} \ No newline at end of file diff --git a/src/threadpool_impl/task_priority.rs b/src/threadpool_impl/task_priority.rs new file mode 100644 index 0000000..ca786d6 --- /dev/null +++ b/src/threadpool_impl/task_priority.rs @@ -0,0 +1,25 @@ +use crate::Priority; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) enum TaskPriority { + Wait, + Background, + Low, + Utility, + Medium, + High, + UserInitiated, +} + +impl From for TaskPriority { + fn from(value: Priority) -> Self { + match value { + Priority::BACKGROUND => TaskPriority::Background, + Priority::LOW => TaskPriority::Low, + Priority::UTILITY => TaskPriority::Utility, + Priority::MEDIUM => TaskPriority::Medium, + Priority::HIGH => TaskPriority::High, + Priority::USERINITIATED => TaskPriority::UserInitiated, + } + } +} diff --git a/src/threadpool_impl/thread.rs b/src/threadpool_impl/thread.rs index e972ec1..473eed3 100644 --- a/src/threadpool_impl/thread.rs +++ b/src/threadpool_impl/thread.rs @@ -1,23 +1,63 @@ -use std::thread; +use std::{ + panic::catch_unwind, + sync::Arc, + task::{Context, Poll}, + thread::spawn, +}; + +use crate::shared::{priority_task::PrioritizedTask, Suspender, Task, TaskOrBarrier, WAKER_PAIR}; + +use super::channel::Channel; pub(crate) struct UniqueThread { - handle: thread::JoinHandle<()>, + task_channel: Arc>>, } impl UniqueThread { - pub(crate) fn new(name: String, task: Task) -> Self { - let handle = thread::Builder::new() - .name(name) - .spawn(move || { - task(); - }) - .unwrap(); - UniqueThread { handle } + pub(crate) fn submit_task(&self, task: PrioritizedTask<()>) { + self.task_channel.enqueue(task); + } + + pub(crate) fn clear(&self) { + self.task_channel.clear() + } + + pub(crate) fn end(&self) { + self.task_channel.end() } } -impl UniqueThread { - pub(crate) fn join(self) { - _ = self.handle.join(); +impl Default for UniqueThread { + fn default() -> Self { + let task_channel: Arc>> = Arc::new(Channel::new()); + let chan = task_channel.clone(); + spawn(move || { + WAKER_PAIR.with(|waker_pair| loop { + let Some(task) = chan.dequeue() else { return }; + let mut context = Context::from_waker(&waker_pair.1); + match task.task { + TaskOrBarrier::Task(mut task) => { + drop(catch_unwind(move || { + poll_task(&mut task, &waker_pair.0, &mut context) + })); + } + TaskOrBarrier::Barrier(barrier) => { + barrier.wait(); + } + } + }); + }); + Self { task_channel } + } +} + +fn poll_task(task: &mut Task<()>, suspender: &Arc, context: &mut Context<'_>) { + loop { + match task.poll_task(context) { + Poll::Ready(()) => return, + Poll::Pending => { + suspender.suspend(); + } + } } } diff --git a/src/threadpool_impl/threadpool.rs b/src/threadpool_impl/threadpool.rs index dda17e9..c647233 100644 --- a/src/threadpool_impl/threadpool.rs +++ b/src/threadpool_impl/threadpool.rs @@ -1,133 +1,67 @@ use std::{ - backtrace, panic, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Barrier, - }, - thread, + sync::{Arc, Barrier}, + thread::available_parallelism, }; -use super::{queueops::QueueOperation, thread::UniqueThread, Func, ThreadSafeQueue}; +use crate::shared::priority_task::PrioritizedTask; -pub struct ThreadPool { +use super::{task_priority::TaskPriority, thread::UniqueThread}; + +pub(crate) struct ThreadPool { handles: Vec, - count: usize, - queue: ThreadSafeQueue>, - barrier: Arc, - stop_flag: Arc, + index: usize, } -impl Default for ThreadPool { - fn default() -> Self { - panic_hook(); - let queue = ThreadSafeQueue::new(); - let count: usize; - if let Ok(thread_count) = thread::available_parallelism() { - count = thread_count.get(); - } else { - count = 1; - } - let barrier = Arc::new(Barrier::new(count + 1)); - let stop_flag = Arc::new(AtomicBool::new(false)); - let handles = (0..count) - .map(|index| start(index, queue.clone(), barrier.clone(), stop_flag.clone())) - .collect(); +impl ThreadPool { + pub(crate) fn new(count: usize) -> Self { + assert!(count > 0); ThreadPool { - handles, - queue, - count, - barrier, - stop_flag, + index: 0, + handles: (1..=count).map(|_| UniqueThread::default()).collect(), } } } -impl ThreadPool { - pub(crate) fn new(count: usize) -> Self { - panic_hook(); - let queue = ThreadSafeQueue::new(); - let barrier = Arc::new(Barrier::new(count + 1)); - let stop_flag = Arc::new(AtomicBool::new(false)); - let handles = (0..count) - .map(|index| start(index, queue.clone(), barrier.clone(), stop_flag.clone())) - .collect(); +impl Default for ThreadPool { + fn default() -> Self { + let count: usize = available_parallelism() + .unwrap_or(unsafe { std::num::NonZeroUsize::new_unchecked(1) }) + .get(); + ThreadPool { - handles, - queue, - count, - barrier, - stop_flag, + handles: (1..=count).map(|_| UniqueThread::default()).collect(), + index: 0, } } } impl ThreadPool { - pub fn submit(&self, task: Task) - where - Task: FnOnce() + 'static + Send, - { - self.queue.enqueue(QueueOperation::Ready(Box::new(task))); + pub(crate) fn submit(&mut self, task: PrioritizedTask<()>) { + let old_index = self.index; + self.index = (self.index + 1) % self.handles.len(); + self.handles[old_index].submit_task(task); } -} -impl ThreadPool { - pub fn wait_for_all(&self) { - for _ in 0..self.count { - self.queue.enqueue(QueueOperation::Wait); - } - self.barrier.wait(); + pub(crate) fn wait_for_all(&self) { + let barrier = Arc::new(Barrier::new(self.handles.len() + 1)); + self.handles.iter().for_each(|channel| { + channel.submit_task(PrioritizedTask::new_with( + TaskPriority::Wait, + barrier.clone(), + )); + }); + barrier.wait(); } } impl ThreadPool { - fn cancel_all(&self) { - self.stop_flag - .store(true, std::sync::atomic::Ordering::Release) + pub(crate) fn end(&self) { + self.handles.iter().for_each(|channel| channel.end()); } } -impl Drop for ThreadPool { - fn drop(&mut self) { - _ = panic::take_hook(); - self.cancel_all(); - while let Some(handle) = self.handles.pop() { - handle.join(); - } +impl ThreadPool { + pub(crate) fn clear(&self) { + self.handles.iter().for_each(|channel| channel.clear()); } } - -fn start( - index: usize, - queue: ThreadSafeQueue>, - barrier: Arc, - stop_flag: Arc, -) -> UniqueThread { - UniqueThread::new(format!("ThreadPool #{}", index), move || { - for op in queue { - match (op, stop_flag.load(Ordering::Acquire)) { - (QueueOperation::NotYet, false) => continue, - (QueueOperation::Ready(work), false) => { - work(); - } - (QueueOperation::Wait, false) => _ = barrier.wait(), - _ => { - return; - } - } - } - }) -} - -fn panic_hook() { - panic::set_hook(Box::new(move |info: &panic::PanicInfo<'_>| { - let msg = format!( - "{} panicked at location {} with {} \nBacktrace:\n{}", - thread::current().name().unwrap(), - info.location().unwrap(), - info.to_string().split('\n').collect::>()[1], - backtrace::Backtrace::capture() - ); - eprintln!("{}", msg); - _ = panic::take_hook(); - })); -} diff --git a/src/yield_now/mod.rs b/src/yield_now/mod.rs deleted file mode 100755 index e85848a..0000000 --- a/src/yield_now/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::yield_now::yielder::Yielder; - -mod yielder; - -/// Wakes the current task and returns [`std::task::Poll::Pending`] once. -/// -/// This function is useful when we want to cooperatively give time to the task executor. It is -/// generally a good idea to yield inside loops because that way we make sure long-running tasks -/// don't prevent other tasks from running. -/// -/// # Examples -/// ``` -/// use spawn_groups::{block_on, yield_now}; -/// block_on(async { -/// yield_now().await; -/// }); -/// ``` -pub fn yield_now() -> Yielder { - Yielder::default() -} diff --git a/src/yield_now/yielder.rs b/src/yield_now/yielder.rs deleted file mode 100755 index 7436af5..0000000 --- a/src/yield_now/yielder.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -#[derive(Debug, Default)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Yielder { - yield_now: bool, -} - -impl Future for Yielder { - type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.yield_now { - true => Poll::Ready(()), - false => { - self.yield_now = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } -}