diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index d44c06f70..38fff1fde 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -5,15 +5,13 @@ use crate::{ }; use anyhow::anyhow; use anyhow::{Error, Result, bail}; -use futures::future::join_all; -use iroh::protocol::Router; use psyche_coordinator::{Commitment, CommitteeSelection, Coordinator, RunState}; use psyche_core::NodeIdentity; use psyche_metrics::{ClientMetrics, ClientRoleInRound, PeerConnection}; use psyche_network::{ AuthenticatableIdentity, BlobTicket, DownloadComplete, DownloadRetryInfo, DownloadType, MAX_DOWNLOAD_RETRIES, ModelRequestType, NetworkEvent, NetworkTUIState, NodeId, - PeerManagerHandle, RetriedDownloadsHandle, SharableModel, TransmittableDownload, allowlist, + ParameterDownloaderHandle, PeerManagerHandle, SharableModel, TransmittableDownload, allowlist, blob_ticket_param_request_task, raw_p2p_verify, }; use psyche_watcher::{Backend, BackendWatcher}; @@ -122,7 +120,10 @@ impl + 'sta tx_broadcast_finished, }); - let retried_downloads = RetriedDownloadsHandle::new(); + let parameter_downloader_handle = ParameterDownloaderHandle::new( + tx_params_download.clone(), + max_concurrent_parameter_requests, + ); let mut sharable_model = SharableModel::empty(); let peer_manager = Arc::new(PeerManagerHandle::new( MAX_ERRORS_PER_PEER, @@ -272,7 +273,8 @@ impl + 'sta }) => { let _ = trace_span!("NetworkEvent::DownloadComplete", hash = %hash).entered(); metrics.record_download_completed(hash, from); - if retried_downloads.remove(hash).await.is_some() { + parameter_downloader_handle.download_succeeded(); + if parameter_downloader_handle.remove(hash).await.is_some() { info!("Successfully downloaded previously failed blob {}", hex::encode(hash)); } match download_data { @@ -304,7 +306,7 @@ impl + 'sta NetworkEvent::DownloadFailed(dl) => { let _ = trace_span!("NetworkEvent::DownloadFailed", error=%dl.error).entered(); let hash = dl.blob_ticket.hash(); - let retries = retried_downloads.get(hash).await.map(|i| i.retries).unwrap_or(0); + let retries = parameter_downloader_handle.get(hash).await.map(|i| i.retries).unwrap_or(0); let download_type_clone = dl.download_type.clone(); match dl.download_type { @@ -324,18 +326,18 @@ impl + 'sta ); let router = p2p.router().clone(); let peer_manager = peer_manager.clone(); - let retried_downloads = retried_downloads.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); tokio::spawn(async move { - let blob_ticket_to_retry = if let Ok(new_blob_ticket) = get_blob_ticket_to_download(router.clone(), request_type, peer_manager.clone(), param_requests_cancel_token).await { + let blob_ticket_to_retry = if let Ok((new_blob_ticket, _)) = blob_ticket_param_request_task(request_type, router.clone(), peer_manager.clone(), param_requests_cancel_token).await { // We remove the old hash because we're getting the blob from a new peer that has its own version of the model parameter or config blob - retried_downloads.remove(hash).await; + parameter_downloader_handle.remove(hash).await; new_blob_ticket } else { dl.blob_ticket }; - retried_downloads.insert(DownloadRetryInfo { + parameter_downloader_handle.insert(DownloadRetryInfo { retries: retries + 1, retry_time, tag: dl.tag, @@ -348,7 +350,7 @@ impl + 'sta if retries >= MAX_DOWNLOAD_RETRIES { metrics.record_download_perma_failed(); warn!("Distro result download failed (not retrying): {}", dl.error); - retried_downloads.remove(hash).await; + parameter_downloader_handle.remove(hash).await; } else { metrics.record_download_failed(); let backoff_duration = DOWNLOAD_RETRY_BACKOFF_BASE.mul_f32(2_f32.powi(retries as i32)); @@ -359,7 +361,7 @@ impl + 'sta backoff_duration, dl.error ); - retried_downloads.insert(DownloadRetryInfo { + parameter_downloader_handle.insert(DownloadRetryInfo { retries: retries + 1, retry_time, tag: dl.tag, @@ -481,12 +483,12 @@ impl + 'sta let tx_params_download = tx_params_download.clone(); let tx_config_download = tx_config_download.clone(); let metrics = metrics.clone(); - let retried_downloads = retried_downloads.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); tokio::spawn(async move { - let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = retried_downloads.pending_retries().await; + let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = parameter_downloader_handle.pending_retries().await; for (hash, ticket, tag, download_type) in pending_retries { - let retries = retried_downloads.update_time(hash).await; + let retries = parameter_downloader_handle.update_time(hash).await; metrics.record_download_retry(hash); // We check the type of the failed download and send it to the appropriate channel to retry it @@ -499,7 +501,7 @@ impl + 'sta match inner { ModelRequestType::Parameter(parameter) => { info!("Retrying download for model parameter: {parameter}, (attempt {})", retries); - let _ = tx_params_download.send(vec![(ticket, ModelRequestType::Parameter(parameter.clone()))]); + let _ = tx_params_download.send((ticket, ModelRequestType::Parameter(parameter.clone()))); }, ModelRequestType::Config => { info!("Retrying download for model config, (attempt {})", retries); @@ -546,60 +548,33 @@ impl + 'sta total_parameters = Some(param_names.len()); sharable_model.initialize_parameters(¶m_names, tx_params_response); - let tx_params_download = tx_params_download.clone(); let router = p2p.router(); let peer_manager = peer_manager.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); - let handle: JoinHandle> = tokio::spawn(async move { // We use std mutex implementation here and call `.unwrap()` when acquiring the lock since there // is no chance of mutex poisoning; locks are acquired only to insert or remove items from them // and dropped immediately - let parameter_blob_tickets = Arc::new(std::sync::Mutex::new(Vec::new())); - let mut request_handles = Vec::new(); let peer_manager = peer_manager.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); + tokio::spawn(async move { for param_name in param_names { + info!("Limit of current parameters downloads reached, waiting for capacity"); + parameter_downloader_handle.wait_for_capacity().await; + let router = router.clone(); - let request_handle = tokio::spawn( - blob_ticket_param_request_task( + let result = blob_ticket_param_request_task( ModelRequestType::Parameter(param_name), router, - parameter_blob_tickets.clone(), peer_manager.clone(), param_requests_cancel_token.clone() - ) - ); + ).await.unwrap(); - // Check if we reached the max number of concurrent requests, and if that is the case, - // await for all of them to complete and start downloading the blobs - if request_handles.len() == max_concurrent_parameter_requests - 1 { - let mut max_concurrent_request_futures = std::mem::take(&mut request_handles); - max_concurrent_request_futures.push(request_handle); - // We don't care about the errors because we are already handling them inside the task - join_all(max_concurrent_request_futures).await; - let current_parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - parameter_blob_tickets_lock.drain(..).collect() - }; - tx_params_download.send(current_parameter_blob_tickets)?; - continue; - } - request_handles.push(request_handle); + parameter_downloader_handle.add_parameter(result.0, result.1); } - - // All parameters have been requested, wait all the remaining request futures to complete - // and download the blobs - join_all(request_handles).await; - let parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - parameter_blob_tickets_lock.drain(..).collect() - }; - tx_params_download.send(parameter_blob_tickets)?; - Ok(()) }); - drop(handle); }, Some(tx_model_config_response) = rx_request_model_config.recv() => { sharable_model.tx_model_config_response = Some(tx_model_config_response); @@ -620,21 +595,21 @@ impl + 'sta let tx_config_download = tx_config_download.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); tokio::spawn(async move { - if let Ok(config_blob_ticket) = get_blob_ticket_to_download(router.clone(), ModelRequestType::Config, peer_manager, param_requests_cancel_token).await { + if let Ok((config_blob_ticket, _)) = blob_ticket_param_request_task(ModelRequestType::Config, router.clone(), peer_manager, param_requests_cancel_token).await { tx_config_download.send(config_blob_ticket).expect("Failed to send config blob ticket"); } else { error!("Error getting the config blob ticket, we'll not proceed with the download"); } }); } + // Modify the params download handler: Some(param_blob_tickets) = rx_params_download.recv() => { - for (ticket, request_type) in param_blob_tickets { + let (ticket, request_type) = param_blob_tickets; let kind = DownloadType::ModelSharing(request_type.clone()); metrics.record_download_started(ticket.hash(), kind.kind()); if let ModelRequestType::Parameter(parameter_name) = request_type { p2p.start_download(ticket, Tag::from(format!("model-{}", parameter_name)), kind); } - } } Some(config_blob_ticket) = rx_config_download.recv() => { let kind = DownloadType::ModelSharing(ModelRequestType::Config); @@ -795,33 +770,3 @@ fn all_node_ids_shuffled(state: &Coordinator) -> Vec addrs.shuffle(&mut rand::rng()); addrs } - -async fn get_blob_ticket_to_download( - router: Arc, - request_type: ModelRequestType, - peer_manager: Arc, - cancellation_token: CancellationToken, -) -> Result { - let blob_ticket = Arc::new(std::sync::Mutex::new(Vec::with_capacity(1))); - - blob_ticket_param_request_task( - request_type.clone(), - router, - blob_ticket.clone(), - peer_manager, - cancellation_token.clone(), - ) - .await; - - let ticket_result = { - let blob_ticket_lock = blob_ticket.lock().unwrap(); - blob_ticket_lock - .first() - .map(|a| a.0.clone()) - .ok_or(anyhow::anyhow!( - "No blob ticket found trying to download {request_type:?}" - ))? - }; - - Ok(ticket_result) -} diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index 0f797bc9d..2461c1709 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -34,8 +34,8 @@ pub struct DownloadRetryInfo { } #[derive(Debug)] -pub enum RetriedDownloadsMessage { - Insert { +pub enum ParameterDownloaderMessage { + InsertRetry { info: DownloadRetryInfo, }, Remove { @@ -53,33 +53,63 @@ pub enum RetriedDownloadsMessage { hash: Hash, response: oneshot::Sender, }, + AddParameter { + blob_ticket: BlobTicket, + request_type: ModelRequestType, + }, + DownloadSucceeded, + WaitForCapacity { + response: oneshot::Sender<()>, + }, } -/// Handler to interact with the retried downloads actor +/// Handler to interact with the parameter downloader actor #[derive(Clone)] -pub struct RetriedDownloadsHandle { - tx: mpsc::UnboundedSender, -} - -impl Default for RetriedDownloadsHandle { - fn default() -> Self { - Self::new() - } +pub struct ParameterDownloaderHandle { + tx: mpsc::UnboundedSender, } -impl RetriedDownloadsHandle { - pub fn new() -> Self { +impl ParameterDownloaderHandle { + pub fn new( + download_tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, + ) -> Self { let (tx, rx) = mpsc::unbounded_channel(); // Spawn the actor - tokio::spawn(retried_downloads_actor(rx)); + tokio::spawn(parameter_downloader_actor( + rx, + download_tx.clone(), + max_concurrent_parameter_requests, + )); Self { tx } } /// Insert a new download to retry pub fn insert(&self, info: DownloadRetryInfo) { - let _ = self.tx.send(RetriedDownloadsMessage::Insert { info }); + let _ = self + .tx + .send(ParameterDownloaderMessage::InsertRetry { info }); + } + + pub fn add_parameter(&self, blob_ticket: BlobTicket, request_type: ModelRequestType) { + let _ = self.tx.send(ParameterDownloaderMessage::AddParameter { + blob_ticket, + request_type, + }); + } + + pub async fn wait_for_capacity(&self) { + let (tx, rx) = oneshot::channel(); + let _ = self + .tx + .send(ParameterDownloaderMessage::WaitForCapacity { response: tx }); + let _ = rx.await; + } + + pub fn download_succeeded(&self) { + let _ = self.tx.send(ParameterDownloaderMessage::DownloadSucceeded); } /// Remove a download from the retry list @@ -88,7 +118,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::Remove { + .send(ParameterDownloaderMessage::Remove { hash, response: response_tx, }) @@ -106,7 +136,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::Get { + .send(ParameterDownloaderMessage::Get { hash, response: response_tx, }) @@ -124,7 +154,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::PendingRetries { + .send(ParameterDownloaderMessage::PendingRetries { response: response_tx, }) .is_err() @@ -141,7 +171,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::UpdateTime { + .send(ParameterDownloaderMessage::UpdateTime { hash, response: response_tx, }) @@ -154,38 +184,82 @@ impl RetriedDownloadsHandle { } } -struct RetriedDownloadsActor { - downloads: HashMap, +struct ParameterDownloaderActor { + retry_downloads: HashMap, + tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + current_downloads: usize, + waiting_requesters: Vec>, + max_concurrent_parameter_requests: usize, } -impl RetriedDownloadsActor { - fn new() -> Self { +impl ParameterDownloaderActor { + fn new( + tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, + ) -> Self { Self { - downloads: HashMap::new(), + retry_downloads: HashMap::new(), + tx_start_download, + current_downloads: 0, + waiting_requesters: Vec::new(), + max_concurrent_parameter_requests, } } - fn handle_message(&mut self, message: RetriedDownloadsMessage) { + fn handle_message(&mut self, message: ParameterDownloaderMessage) { match message { - RetriedDownloadsMessage::Insert { info } => { + ParameterDownloaderMessage::InsertRetry { info } => { let hash = info.ticket.hash(); - self.downloads.insert(hash, info); + self.retry_downloads.insert(hash, info); + } + + ParameterDownloaderMessage::WaitForCapacity { response } => { + if self.current_downloads < self.max_concurrent_parameter_requests { + // Can proceed immediately + let _ = response.send(()); + } else { + // Queue the process to wait + self.waiting_requesters.push(response); + } + } + + ParameterDownloaderMessage::AddParameter { + blob_ticket, + request_type, + } => { + info!("Starting parameter download for ticket {:?}", blob_ticket); + self.tx_start_download + .send((blob_ticket, request_type)) + .unwrap_or_else(|err| { + error!("Failed to send start download message: {}", err); + }); + self.current_downloads += 1; + } + + ParameterDownloaderMessage::DownloadSucceeded => { + self.current_downloads = self.current_downloads.saturating_sub(1); + if !self.waiting_requesters.is_empty() { + if let Some(waiter) = self.waiting_requesters.pop() { + info!("Notifying waiting requester that capacity is available"); + let _ = waiter.send(()); + } + } } - RetriedDownloadsMessage::Remove { hash, response } => { - let removed = self.downloads.remove(&hash); + ParameterDownloaderMessage::Remove { hash, response } => { + let removed = self.retry_downloads.remove(&hash); let _ = response.send(removed); } - RetriedDownloadsMessage::Get { hash, response } => { - let info = self.downloads.get(&hash).cloned(); + ParameterDownloaderMessage::Get { hash, response } => { + let info = self.retry_downloads.get(&hash).cloned(); let _ = response.send(info); } - RetriedDownloadsMessage::PendingRetries { response } => { + ParameterDownloaderMessage::PendingRetries { response } => { let now = Instant::now(); let pending: Vec<_> = self - .downloads + .retry_downloads .iter() .filter(|(_, info)| { info.retry_time @@ -205,8 +279,8 @@ impl RetriedDownloadsActor { let _ = response.send(pending); } - RetriedDownloadsMessage::UpdateTime { hash, response } => { - let retries = if let Some(info) = self.downloads.get_mut(&hash) { + ParameterDownloaderMessage::UpdateTime { hash, response } => { + let retries = if let Some(info) = self.retry_downloads.get_mut(&hash) { info.retry_time = None; // Mark as being retried now info.retries } else { @@ -219,8 +293,12 @@ impl RetriedDownloadsActor { } } -async fn retried_downloads_actor(mut rx: mpsc::UnboundedReceiver) { - let mut actor = RetriedDownloadsActor::new(); +async fn parameter_downloader_actor( + mut rx: mpsc::UnboundedReceiver, + tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, +) { + let mut actor = ParameterDownloaderActor::new(tx, max_concurrent_parameter_requests); while let Some(message) = rx.recv().await { actor.handle_message(message); diff --git a/shared/network/src/lib.rs b/shared/network/src/lib.rs index ae8bc7f31..53010353a 100644 --- a/shared/network/src/lib.rs +++ b/shared/network/src/lib.rs @@ -75,7 +75,7 @@ mod test; pub use authenticable_identity::{AuthenticatableIdentity, FromSignedBytesError, raw_p2p_verify}; pub use download_manager::{ DownloadComplete, DownloadFailed, DownloadRetryInfo, DownloadType, MAX_DOWNLOAD_RETRIES, - RetriedDownloadsHandle, TransmittableDownload, + ParameterDownloaderHandle, TransmittableDownload, }; pub use iroh::{Endpoint, PublicKey, SecretKey}; use iroh_relay::{RelayMap, RelayNode, RelayQuicConfig}; @@ -783,10 +783,9 @@ fn hash_bytes(bytes: &Bytes) -> u64 { pub async fn blob_ticket_param_request_task( model_request_type: ModelRequestType, router: Arc, - model_blob_tickets: Arc>>, peer_manager: Arc, cancellation_token: CancellationToken, -) { +) -> Result<(BlobTicket, ModelRequestType)> { let max_attempts = 500u16; let mut attempts = 0u16; @@ -808,13 +807,8 @@ pub async fn blob_ticket_param_request_task( match result { Ok(Ok(blob_ticket)) => { - model_blob_tickets - .lock() - .unwrap() - .push((blob_ticket, model_request_type)); - peer_manager.report_success(peer_id); - return; + return Ok((blob_ticket, model_request_type)); } Ok(Err(e)) | Err(e) => { // Failed - report error and potentially try next peer @@ -831,4 +825,7 @@ pub async fn blob_ticket_param_request_task( error!("No peers available to give us a model parameter after {max_attempts} attempts"); cancellation_token.cancel(); + Err(anyhow!( + "Failed to get model parameter blob ticket after {max_attempts} attempts" + )) }