From d6f3e42229e6c4eae68f57b2c2afd28552620e7f Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 10 Nov 2025 10:24:07 -0800 Subject: [PATCH 1/5] Limit the amount of concurrent downloads for the parameters --- shared/client/src/client.rs | 103 ++++++++++++------------ shared/network/src/download_manager.rs | 106 ++++++++++++++++++++----- shared/network/src/lib.rs | 13 ++- 3 files changed, 139 insertions(+), 83 deletions(-) diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index d44c06f70..851eb2aa0 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -103,6 +103,7 @@ impl + 'sta let max_concurrent_parameter_requests = init_config.max_concurrent_parameter_requests; + let mut concurrent_downloads = 0_usize; let mut current_downloaded_parameters = 0_u64; let mut total_parameters = None; @@ -122,7 +123,7 @@ impl + 'sta tx_broadcast_finished, }); - let retried_downloads = RetriedDownloadsHandle::new(); + let retried_downloads = RetriedDownloadsHandle::new(tx_params_download.clone()); let mut sharable_model = SharableModel::empty(); let peer_manager = Arc::new(PeerManagerHandle::new( MAX_ERRORS_PER_PEER, @@ -272,6 +273,7 @@ impl + 'sta }) => { let _ = trace_span!("NetworkEvent::DownloadComplete", hash = %hash).entered(); metrics.record_download_completed(hash, from); + retried_downloads.download_succeeded(hash); if retried_downloads.remove(hash).await.is_some() { info!("Successfully downloaded previously failed blob {}", hex::encode(hash)); } @@ -499,7 +501,7 @@ impl + 'sta match inner { ModelRequestType::Parameter(parameter) => { info!("Retrying download for model parameter: {parameter}, (attempt {})", retries); - let _ = tx_params_download.send(vec![(ticket, ModelRequestType::Parameter(parameter.clone()))]); + let _ = tx_params_download.send((ticket, ModelRequestType::Parameter(parameter.clone()))); }, ModelRequestType::Config => { info!("Retrying download for model config, (attempt {})", retries); @@ -551,55 +553,59 @@ impl + 'sta let peer_manager = peer_manager.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); - let handle: JoinHandle> = tokio::spawn(async move { // We use std mutex implementation here and call `.unwrap()` when acquiring the lock since there // is no chance of mutex poisoning; locks are acquired only to insert or remove items from them // and dropped immediately - let parameter_blob_tickets = Arc::new(std::sync::Mutex::new(Vec::new())); - let mut request_handles = Vec::new(); let peer_manager = peer_manager.clone(); + let mut max_concurrent_parameter_requests = 0; + let retried_downloads = retried_downloads.clone(); + tokio::spawn(async move { for param_name in param_names { let router = router.clone(); - let request_handle = tokio::spawn( - blob_ticket_param_request_task( + let result = blob_ticket_param_request_task( ModelRequestType::Parameter(param_name), router, - parameter_blob_tickets.clone(), peer_manager.clone(), param_requests_cancel_token.clone() - ) - ); - - // Check if we reached the max number of concurrent requests, and if that is the case, - // await for all of them to complete and start downloading the blobs - if request_handles.len() == max_concurrent_parameter_requests - 1 { - let mut max_concurrent_request_futures = std::mem::take(&mut request_handles); - max_concurrent_request_futures.push(request_handle); - // We don't care about the errors because we are already handling them inside the task - join_all(max_concurrent_request_futures).await; - let current_parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - parameter_blob_tickets_lock.drain(..).collect() - }; - tx_params_download.send(current_parameter_blob_tickets)?; - continue; - } - request_handles.push(request_handle); + ).await.unwrap(); + + // let send_result = tx_params_download.send(result.unwrap()); + retried_downloads.add_parameter(result.0, result.1); + // max_concurrent_parameter_requests += 1; + // if max_concurrent_parameter_requests >= 5 { + // println!("Reached max concurrent parameter requests, waiting for one to complete"); + // let a = rx_parameter_download_confirm.recv().await; + // println!("Download completed for one parameter, continuing with requests"); + // max_concurrent_parameter_requests -= 1; + // } + // // Check if we reached the max number of concurrent requests, and if that is the case, + // // await for all of them to complete and start downloading the blobs + // if request_handles.len() == max_concurrent_parameter_requests - 1 { + // let mut max_concurrent_request_futures = std::mem::take(&mut request_handles); + // max_concurrent_request_futures.push(request_handle); + // // We don't care about the errors because we are already handling them inside the task + // join_all(max_concurrent_request_futures).await; + // let current_parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { + // let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); + // parameter_blob_tickets_lock.drain(..).collect() + // }; + // tx_params_download.send(current_parameter_blob_tickets)?; + // continue; + // } + // request_handles.push(request_handle); } - - // All parameters have been requested, wait all the remaining request futures to complete - // and download the blobs - join_all(request_handles).await; - let parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - parameter_blob_tickets_lock.drain(..).collect() - }; - tx_params_download.send(parameter_blob_tickets)?; - Ok(()) }); - drop(handle); + + // // All parameters have been requested, wait all the remaining request futures to complete + // // and download the blobs + // join_all(request_handles).await; + // let parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { + // let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); + // parameter_blob_tickets_lock.drain(..).collect() + // }; + // tx_params_download.send(parameter_blob_tickets)?; }, Some(tx_model_config_response) = rx_request_model_config.recv() => { sharable_model.tx_model_config_response = Some(tx_model_config_response); @@ -627,14 +633,14 @@ impl + 'sta } }); } + // Modify the params download handler: Some(param_blob_tickets) = rx_params_download.recv() => { - for (ticket, request_type) in param_blob_tickets { + let (ticket, request_type) = param_blob_tickets; let kind = DownloadType::ModelSharing(request_type.clone()); metrics.record_download_started(ticket.hash(), kind.kind()); if let ModelRequestType::Parameter(parameter_name) = request_type { p2p.start_download(ticket, Tag::from(format!("model-{}", parameter_name)), kind); } - } } Some(config_blob_ticket) = rx_config_download.recv() => { let kind = DownloadType::ModelSharing(ModelRequestType::Config); @@ -802,26 +808,15 @@ async fn get_blob_ticket_to_download( peer_manager: Arc, cancellation_token: CancellationToken, ) -> Result { - let blob_ticket = Arc::new(std::sync::Mutex::new(Vec::with_capacity(1))); - - blob_ticket_param_request_task( + let result = blob_ticket_param_request_task( request_type.clone(), router, - blob_ticket.clone(), peer_manager, cancellation_token.clone(), ) .await; - let ticket_result = { - let blob_ticket_lock = blob_ticket.lock().unwrap(); - blob_ticket_lock - .first() - .map(|a| a.0.clone()) - .ok_or(anyhow::anyhow!( - "No blob ticket found trying to download {request_type:?}" - ))? - }; - - Ok(ticket_result) + let (blob_ticket_lock, model_type) = result.unwrap(); + + Ok(blob_ticket_lock) } diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index 0f797bc9d..8e56189e3 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -23,6 +23,7 @@ use tokio::{ use tracing::{error, info, trace, warn}; pub const MAX_DOWNLOAD_RETRIES: usize = 3; +pub const MAX_CONCURRENT_PARAMETER_REQUESTS: usize = 5; #[derive(Debug, Clone)] pub struct DownloadRetryInfo { @@ -35,7 +36,7 @@ pub struct DownloadRetryInfo { #[derive(Debug)] pub enum RetriedDownloadsMessage { - Insert { + InsertRetry { info: DownloadRetryInfo, }, Remove { @@ -53,6 +54,13 @@ pub enum RetriedDownloadsMessage { hash: Hash, response: oneshot::Sender, }, + AddParameter { + blob_ticket: BlobTicket, + request_type: ModelRequestType, + }, + DownloadSucceeded { + hash: Hash, + }, } /// Handler to interact with the retried downloads actor @@ -61,25 +69,32 @@ pub struct RetriedDownloadsHandle { tx: mpsc::UnboundedSender, } -impl Default for RetriedDownloadsHandle { - fn default() -> Self { - Self::new() - } -} - impl RetriedDownloadsHandle { - pub fn new() -> Self { + pub fn new(download_tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>) -> Self { let (tx, rx) = mpsc::unbounded_channel(); // Spawn the actor - tokio::spawn(retried_downloads_actor(rx)); + tokio::spawn(retried_downloads_actor(rx, download_tx.clone())); Self { tx } } /// Insert a new download to retry pub fn insert(&self, info: DownloadRetryInfo) { - let _ = self.tx.send(RetriedDownloadsMessage::Insert { info }); + let _ = self.tx.send(RetriedDownloadsMessage::InsertRetry { info }); + } + + pub fn add_parameter(&self, blob_ticket: BlobTicket, request_type: ModelRequestType) { + let _ = self.tx.send(RetriedDownloadsMessage::AddParameter { + blob_ticket, + request_type, + }); + } + + pub fn download_succeeded(&self, hash: Hash) { + let _ = self + .tx + .send(RetriedDownloadsMessage::DownloadSucceeded { hash }); } /// Remove a download from the retry list @@ -155,37 +170,83 @@ impl RetriedDownloadsHandle { } struct RetriedDownloadsActor { - downloads: HashMap, + retry_downloads: HashMap, + current_parameter_tickets: Vec<(BlobTicket, ModelRequestType)>, + tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + current_downloads: usize, } impl RetriedDownloadsActor { - fn new() -> Self { + fn new(tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>) -> Self { Self { - downloads: HashMap::new(), + retry_downloads: HashMap::new(), + current_parameter_tickets: Vec::new(), + tx_start_download, + current_downloads: 0, } } fn handle_message(&mut self, message: RetriedDownloadsMessage) { match message { - RetriedDownloadsMessage::Insert { info } => { + RetriedDownloadsMessage::InsertRetry { info } => { let hash = info.ticket.hash(); - self.downloads.insert(hash, info); + self.retry_downloads.insert(hash, info); + } + + RetriedDownloadsMessage::AddParameter { + blob_ticket, + request_type, + } => { + info!( + "Adding parameter download ticket {:?} for request type {:?}", + blob_ticket, request_type + ); + if self.current_downloads >= MAX_CONCURRENT_PARAMETER_REQUESTS { + self.current_parameter_tickets + .push((blob_ticket, request_type)); + info!("Max concurrent parameter downloads reached, queuing ticket"); + return; + } + info!("Starting parameter download for ticket {:?}", blob_ticket); + self.tx_start_download + .send((blob_ticket, request_type)) + .unwrap_or_else(|err| { + error!("Failed to send start download message: {}", err); + }); + self.current_downloads += 1; + } + + RetriedDownloadsMessage::DownloadSucceeded { hash } => { + self.current_downloads = self.current_downloads.saturating_sub(1); + if let Some((next_ticket, request_type)) = self.current_parameter_tickets.pop() { + info!( + "Starting next queued parameter download for ticket {:?} since a download completed", + next_ticket + ); + // Start the next download + self.tx_start_download + .send((next_ticket.clone(), request_type)) + .unwrap_or_else(|err| { + error!("Failed to send start download message: {}", err); + }); + self.current_downloads += 1; + } } RetriedDownloadsMessage::Remove { hash, response } => { - let removed = self.downloads.remove(&hash); + let removed = self.retry_downloads.remove(&hash); let _ = response.send(removed); } RetriedDownloadsMessage::Get { hash, response } => { - let info = self.downloads.get(&hash).cloned(); + let info = self.retry_downloads.get(&hash).cloned(); let _ = response.send(info); } RetriedDownloadsMessage::PendingRetries { response } => { let now = Instant::now(); let pending: Vec<_> = self - .downloads + .retry_downloads .iter() .filter(|(_, info)| { info.retry_time @@ -206,7 +267,7 @@ impl RetriedDownloadsActor { } RetriedDownloadsMessage::UpdateTime { hash, response } => { - let retries = if let Some(info) = self.downloads.get_mut(&hash) { + let retries = if let Some(info) = self.retry_downloads.get_mut(&hash) { info.retry_time = None; // Mark as being retried now info.retries } else { @@ -219,8 +280,11 @@ impl RetriedDownloadsActor { } } -async fn retried_downloads_actor(mut rx: mpsc::UnboundedReceiver) { - let mut actor = RetriedDownloadsActor::new(); +async fn retried_downloads_actor( + mut rx: mpsc::UnboundedReceiver, + tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, +) { + let mut actor = RetriedDownloadsActor::new(tx); while let Some(message) = rx.recv().await { actor.handle_message(message); diff --git a/shared/network/src/lib.rs b/shared/network/src/lib.rs index c8d043f4e..40cad1268 100644 --- a/shared/network/src/lib.rs +++ b/shared/network/src/lib.rs @@ -771,10 +771,9 @@ fn hash_bytes(bytes: &Bytes) -> u64 { pub async fn blob_ticket_param_request_task( model_request_type: ModelRequestType, router: Arc, - model_blob_tickets: Arc>>, peer_manager: Arc, cancellation_token: CancellationToken, -) { +) -> Result<(BlobTicket, ModelRequestType)> { let max_attempts = 500u16; let mut attempts = 0u16; @@ -796,13 +795,8 @@ pub async fn blob_ticket_param_request_task( match result { Ok(Ok(blob_ticket)) => { - model_blob_tickets - .lock() - .unwrap() - .push((blob_ticket, model_request_type)); - peer_manager.report_success(peer_id); - return; + return Ok((blob_ticket, model_request_type)); } Ok(Err(e)) | Err(e) => { // Failed - report error and potentially try next peer @@ -819,4 +813,7 @@ pub async fn blob_ticket_param_request_task( error!("No peers available to give us a model parameter after {max_attempts} attempts"); cancellation_token.cancel(); + Err(anyhow!( + "Failed to get model parameter blob ticket after {max_attempts} attempts" + )) } From e593f9baebe1a8a4e76697b5735ed15722f23613 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 11 Nov 2025 15:46:31 -0300 Subject: [PATCH 2/5] Limit the blob ticket requests --- shared/client/src/client.rs | 2 ++ shared/network/src/download_manager.rs | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index 851eb2aa0..b40741cc7 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -562,6 +562,8 @@ impl + 'sta tokio::spawn(async move { for param_name in param_names { + retried_downloads.wait_for_capacity().await; + let router = router.clone(); let result = blob_ticket_param_request_task( diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index 8e56189e3..d8efbfb17 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -61,6 +61,9 @@ pub enum RetriedDownloadsMessage { DownloadSucceeded { hash: Hash, }, + WaitForCapacity { + response: oneshot::Sender<()>, + }, } /// Handler to interact with the retried downloads actor @@ -91,6 +94,14 @@ impl RetriedDownloadsHandle { }); } + pub async fn wait_for_capacity(&self) { + let (tx, rx) = oneshot::channel(); + self.tx + .send(RetriedDownloadsMessage::WaitForCapacity { response: tx }) + .unwrap(); + let _ = rx.await; + } + pub fn download_succeeded(&self, hash: Hash) { let _ = self .tx @@ -174,6 +185,7 @@ struct RetriedDownloadsActor { current_parameter_tickets: Vec<(BlobTicket, ModelRequestType)>, tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, current_downloads: usize, + waiting_requesters: Vec>, } impl RetriedDownloadsActor { @@ -183,6 +195,7 @@ impl RetriedDownloadsActor { current_parameter_tickets: Vec::new(), tx_start_download, current_downloads: 0, + waiting_requesters: Vec::new(), } } @@ -193,6 +206,16 @@ impl RetriedDownloadsActor { self.retry_downloads.insert(hash, info); } + RetriedDownloadsMessage::WaitForCapacity { response } => { + if self.current_downloads < MAX_CONCURRENT_PARAMETER_REQUESTS { + // Can proceed immediately + let _ = response.send(()); + } else { + // Queue the waiter + self.waiting_requesters.push(response); + } + } + RetriedDownloadsMessage::AddParameter { blob_ticket, request_type, From 1b16b3ac94b22ad95e91b96483a7eb7d8112f1aa Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 11 Nov 2025 12:40:03 -0800 Subject: [PATCH 3/5] Fix limit blob ticket requests --- shared/client/src/client.rs | 33 +------------------------- shared/network/src/download_manager.rs | 32 ++++++------------------- 2 files changed, 8 insertions(+), 57 deletions(-) diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index b40741cc7..8a8433ca7 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -562,6 +562,7 @@ impl + 'sta tokio::spawn(async move { for param_name in param_names { + info!("Waiting for capacity"); retried_downloads.wait_for_capacity().await; let router = router.clone(); @@ -573,41 +574,9 @@ impl + 'sta param_requests_cancel_token.clone() ).await.unwrap(); - // let send_result = tx_params_download.send(result.unwrap()); retried_downloads.add_parameter(result.0, result.1); - // max_concurrent_parameter_requests += 1; - // if max_concurrent_parameter_requests >= 5 { - // println!("Reached max concurrent parameter requests, waiting for one to complete"); - // let a = rx_parameter_download_confirm.recv().await; - // println!("Download completed for one parameter, continuing with requests"); - // max_concurrent_parameter_requests -= 1; - // } - // // Check if we reached the max number of concurrent requests, and if that is the case, - // // await for all of them to complete and start downloading the blobs - // if request_handles.len() == max_concurrent_parameter_requests - 1 { - // let mut max_concurrent_request_futures = std::mem::take(&mut request_handles); - // max_concurrent_request_futures.push(request_handle); - // // We don't care about the errors because we are already handling them inside the task - // join_all(max_concurrent_request_futures).await; - // let current_parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - // let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - // parameter_blob_tickets_lock.drain(..).collect() - // }; - // tx_params_download.send(current_parameter_blob_tickets)?; - // continue; - // } - // request_handles.push(request_handle); } }); - - // // All parameters have been requested, wait all the remaining request futures to complete - // // and download the blobs - // join_all(request_handles).await; - // let parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = { - // let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap(); - // parameter_blob_tickets_lock.drain(..).collect() - // }; - // tx_params_download.send(parameter_blob_tickets)?; }, Some(tx_model_config_response) = rx_request_model_config.recv() => { sharable_model.tx_model_config_response = Some(tx_model_config_response); diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index d8efbfb17..bcc3fce12 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -182,7 +182,6 @@ impl RetriedDownloadsHandle { struct RetriedDownloadsActor { retry_downloads: HashMap, - current_parameter_tickets: Vec<(BlobTicket, ModelRequestType)>, tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, current_downloads: usize, waiting_requesters: Vec>, @@ -192,7 +191,6 @@ impl RetriedDownloadsActor { fn new(tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>) -> Self { Self { retry_downloads: HashMap::new(), - current_parameter_tickets: Vec::new(), tx_start_download, current_downloads: 0, waiting_requesters: Vec::new(), @@ -211,7 +209,7 @@ impl RetriedDownloadsActor { // Can proceed immediately let _ = response.send(()); } else { - // Queue the waiter + // Queue the process to wait self.waiting_requesters.push(response); } } @@ -220,16 +218,6 @@ impl RetriedDownloadsActor { blob_ticket, request_type, } => { - info!( - "Adding parameter download ticket {:?} for request type {:?}", - blob_ticket, request_type - ); - if self.current_downloads >= MAX_CONCURRENT_PARAMETER_REQUESTS { - self.current_parameter_tickets - .push((blob_ticket, request_type)); - info!("Max concurrent parameter downloads reached, queuing ticket"); - return; - } info!("Starting parameter download for ticket {:?}", blob_ticket); self.tx_start_download .send((blob_ticket, request_type)) @@ -237,22 +225,16 @@ impl RetriedDownloadsActor { error!("Failed to send start download message: {}", err); }); self.current_downloads += 1; + info!("CURRENT PARAMETER DOWNLOADS: {}", self.current_downloads); } RetriedDownloadsMessage::DownloadSucceeded { hash } => { self.current_downloads = self.current_downloads.saturating_sub(1); - if let Some((next_ticket, request_type)) = self.current_parameter_tickets.pop() { - info!( - "Starting next queued parameter download for ticket {:?} since a download completed", - next_ticket - ); - // Start the next download - self.tx_start_download - .send((next_ticket.clone(), request_type)) - .unwrap_or_else(|err| { - error!("Failed to send start download message: {}", err); - }); - self.current_downloads += 1; + if !self.waiting_requesters.is_empty() { + if let Some(waiter) = self.waiting_requesters.pop() { + info!("Notifying waiting requester that capacity is available"); + let _ = waiter.send(()); + } } } From 130d9322c2615f46115c0d150a637a8cca05469e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 18 Nov 2025 12:09:09 -0300 Subject: [PATCH 4/5] Rename actor for downloading parameters --- shared/client/src/client.rs | 19 +++--- shared/network/src/download_manager.rs | 87 ++++++++++++++------------ shared/network/src/lib.rs | 2 +- 3 files changed, 57 insertions(+), 51 deletions(-) diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index 8a8433ca7..f54c96be1 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -5,7 +5,6 @@ use crate::{ }; use anyhow::anyhow; use anyhow::{Error, Result, bail}; -use futures::future::join_all; use iroh::protocol::Router; use psyche_coordinator::{Commitment, CommitteeSelection, Coordinator, RunState}; use psyche_core::NodeIdentity; @@ -13,7 +12,7 @@ use psyche_metrics::{ClientMetrics, ClientRoleInRound, PeerConnection}; use psyche_network::{ AuthenticatableIdentity, BlobTicket, DownloadComplete, DownloadRetryInfo, DownloadType, MAX_DOWNLOAD_RETRIES, ModelRequestType, NetworkEvent, NetworkTUIState, NodeId, - PeerManagerHandle, RetriedDownloadsHandle, SharableModel, TransmittableDownload, allowlist, + ParameterDownloaderHandle, PeerManagerHandle, SharableModel, TransmittableDownload, allowlist, blob_ticket_param_request_task, raw_p2p_verify, }; use psyche_watcher::{Backend, BackendWatcher}; @@ -103,7 +102,6 @@ impl + 'sta let max_concurrent_parameter_requests = init_config.max_concurrent_parameter_requests; - let mut concurrent_downloads = 0_usize; let mut current_downloaded_parameters = 0_u64; let mut total_parameters = None; @@ -123,7 +121,10 @@ impl + 'sta tx_broadcast_finished, }); - let retried_downloads = RetriedDownloadsHandle::new(tx_params_download.clone()); + let retried_downloads = ParameterDownloaderHandle::new( + tx_params_download.clone(), + max_concurrent_parameter_requests, + ); let mut sharable_model = SharableModel::empty(); let peer_manager = Arc::new(PeerManagerHandle::new( MAX_ERRORS_PER_PEER, @@ -273,7 +274,7 @@ impl + 'sta }) => { let _ = trace_span!("NetworkEvent::DownloadComplete", hash = %hash).entered(); metrics.record_download_completed(hash, from); - retried_downloads.download_succeeded(hash); + retried_downloads.download_succeeded(); if retried_downloads.remove(hash).await.is_some() { info!("Successfully downloaded previously failed blob {}", hex::encode(hash)); } @@ -548,7 +549,6 @@ impl + 'sta total_parameters = Some(param_names.len()); sharable_model.initialize_parameters(¶m_names, tx_params_response); - let tx_params_download = tx_params_download.clone(); let router = p2p.router(); let peer_manager = peer_manager.clone(); @@ -557,7 +557,6 @@ impl + 'sta // is no chance of mutex poisoning; locks are acquired only to insert or remove items from them // and dropped immediately let peer_manager = peer_manager.clone(); - let mut max_concurrent_parameter_requests = 0; let retried_downloads = retried_downloads.clone(); tokio::spawn(async move { @@ -785,9 +784,7 @@ async fn get_blob_ticket_to_download( peer_manager, cancellation_token.clone(), ) - .await; - - let (blob_ticket_lock, model_type) = result.unwrap(); + .await?; - Ok(blob_ticket_lock) + Ok(result.0) } diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index bcc3fce12..2fe7fb54d 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -23,7 +23,6 @@ use tokio::{ use tracing::{error, info, trace, warn}; pub const MAX_DOWNLOAD_RETRIES: usize = 3; -pub const MAX_CONCURRENT_PARAMETER_REQUESTS: usize = 5; #[derive(Debug, Clone)] pub struct DownloadRetryInfo { @@ -35,7 +34,7 @@ pub struct DownloadRetryInfo { } #[derive(Debug)] -pub enum RetriedDownloadsMessage { +pub enum ParameterDownloaderMessage { InsertRetry { info: DownloadRetryInfo, }, @@ -58,37 +57,44 @@ pub enum RetriedDownloadsMessage { blob_ticket: BlobTicket, request_type: ModelRequestType, }, - DownloadSucceeded { - hash: Hash, - }, + DownloadSucceeded, WaitForCapacity { response: oneshot::Sender<()>, }, } -/// Handler to interact with the retried downloads actor +/// Handler to interact with the parameter downloader actor #[derive(Clone)] -pub struct RetriedDownloadsHandle { - tx: mpsc::UnboundedSender, +pub struct ParameterDownloaderHandle { + tx: mpsc::UnboundedSender, } -impl RetriedDownloadsHandle { - pub fn new(download_tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>) -> Self { +impl ParameterDownloaderHandle { + pub fn new( + download_tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, + ) -> Self { let (tx, rx) = mpsc::unbounded_channel(); // Spawn the actor - tokio::spawn(retried_downloads_actor(rx, download_tx.clone())); + tokio::spawn(parameter_downloader_actor( + rx, + download_tx.clone(), + max_concurrent_parameter_requests, + )); Self { tx } } /// Insert a new download to retry pub fn insert(&self, info: DownloadRetryInfo) { - let _ = self.tx.send(RetriedDownloadsMessage::InsertRetry { info }); + let _ = self + .tx + .send(ParameterDownloaderMessage::InsertRetry { info }); } pub fn add_parameter(&self, blob_ticket: BlobTicket, request_type: ModelRequestType) { - let _ = self.tx.send(RetriedDownloadsMessage::AddParameter { + let _ = self.tx.send(ParameterDownloaderMessage::AddParameter { blob_ticket, request_type, }); @@ -97,15 +103,13 @@ impl RetriedDownloadsHandle { pub async fn wait_for_capacity(&self) { let (tx, rx) = oneshot::channel(); self.tx - .send(RetriedDownloadsMessage::WaitForCapacity { response: tx }) + .send(ParameterDownloaderMessage::WaitForCapacity { response: tx }) .unwrap(); let _ = rx.await; } - pub fn download_succeeded(&self, hash: Hash) { - let _ = self - .tx - .send(RetriedDownloadsMessage::DownloadSucceeded { hash }); + pub fn download_succeeded(&self) { + let _ = self.tx.send(ParameterDownloaderMessage::DownloadSucceeded); } /// Remove a download from the retry list @@ -114,7 +118,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::Remove { + .send(ParameterDownloaderMessage::Remove { hash, response: response_tx, }) @@ -132,7 +136,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::Get { + .send(ParameterDownloaderMessage::Get { hash, response: response_tx, }) @@ -150,7 +154,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::PendingRetries { + .send(ParameterDownloaderMessage::PendingRetries { response: response_tx, }) .is_err() @@ -167,7 +171,7 @@ impl RetriedDownloadsHandle { if self .tx - .send(RetriedDownloadsMessage::UpdateTime { + .send(ParameterDownloaderMessage::UpdateTime { hash, response: response_tx, }) @@ -180,32 +184,37 @@ impl RetriedDownloadsHandle { } } -struct RetriedDownloadsActor { +struct ParameterDownloaderActor { retry_downloads: HashMap, tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, current_downloads: usize, waiting_requesters: Vec>, + max_concurrent_parameter_requests: usize, } -impl RetriedDownloadsActor { - fn new(tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>) -> Self { +impl ParameterDownloaderActor { + fn new( + tx_start_download: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, + ) -> Self { Self { retry_downloads: HashMap::new(), tx_start_download, current_downloads: 0, waiting_requesters: Vec::new(), + max_concurrent_parameter_requests, } } - fn handle_message(&mut self, message: RetriedDownloadsMessage) { + fn handle_message(&mut self, message: ParameterDownloaderMessage) { match message { - RetriedDownloadsMessage::InsertRetry { info } => { + ParameterDownloaderMessage::InsertRetry { info } => { let hash = info.ticket.hash(); self.retry_downloads.insert(hash, info); } - RetriedDownloadsMessage::WaitForCapacity { response } => { - if self.current_downloads < MAX_CONCURRENT_PARAMETER_REQUESTS { + ParameterDownloaderMessage::WaitForCapacity { response } => { + if self.current_downloads < self.max_concurrent_parameter_requests { // Can proceed immediately let _ = response.send(()); } else { @@ -214,7 +223,7 @@ impl RetriedDownloadsActor { } } - RetriedDownloadsMessage::AddParameter { + ParameterDownloaderMessage::AddParameter { blob_ticket, request_type, } => { @@ -225,10 +234,9 @@ impl RetriedDownloadsActor { error!("Failed to send start download message: {}", err); }); self.current_downloads += 1; - info!("CURRENT PARAMETER DOWNLOADS: {}", self.current_downloads); } - RetriedDownloadsMessage::DownloadSucceeded { hash } => { + ParameterDownloaderMessage::DownloadSucceeded => { self.current_downloads = self.current_downloads.saturating_sub(1); if !self.waiting_requesters.is_empty() { if let Some(waiter) = self.waiting_requesters.pop() { @@ -238,17 +246,17 @@ impl RetriedDownloadsActor { } } - RetriedDownloadsMessage::Remove { hash, response } => { + ParameterDownloaderMessage::Remove { hash, response } => { let removed = self.retry_downloads.remove(&hash); let _ = response.send(removed); } - RetriedDownloadsMessage::Get { hash, response } => { + ParameterDownloaderMessage::Get { hash, response } => { let info = self.retry_downloads.get(&hash).cloned(); let _ = response.send(info); } - RetriedDownloadsMessage::PendingRetries { response } => { + ParameterDownloaderMessage::PendingRetries { response } => { let now = Instant::now(); let pending: Vec<_> = self .retry_downloads @@ -271,7 +279,7 @@ impl RetriedDownloadsActor { let _ = response.send(pending); } - RetriedDownloadsMessage::UpdateTime { hash, response } => { + ParameterDownloaderMessage::UpdateTime { hash, response } => { let retries = if let Some(info) = self.retry_downloads.get_mut(&hash) { info.retry_time = None; // Mark as being retried now info.retries @@ -285,11 +293,12 @@ impl RetriedDownloadsActor { } } -async fn retried_downloads_actor( - mut rx: mpsc::UnboundedReceiver, +async fn parameter_downloader_actor( + mut rx: mpsc::UnboundedReceiver, tx: mpsc::UnboundedSender<(BlobTicket, ModelRequestType)>, + max_concurrent_parameter_requests: usize, ) { - let mut actor = RetriedDownloadsActor::new(tx); + let mut actor = ParameterDownloaderActor::new(tx, max_concurrent_parameter_requests); while let Some(message) = rx.recv().await { actor.handle_message(message); diff --git a/shared/network/src/lib.rs b/shared/network/src/lib.rs index 40cad1268..b3b6f295e 100644 --- a/shared/network/src/lib.rs +++ b/shared/network/src/lib.rs @@ -75,7 +75,7 @@ mod test; pub use authenticable_identity::{AuthenticatableIdentity, FromSignedBytesError, raw_p2p_verify}; pub use download_manager::{ DownloadComplete, DownloadFailed, DownloadRetryInfo, DownloadType, MAX_DOWNLOAD_RETRIES, - RetriedDownloadsHandle, TransmittableDownload, + ParameterDownloaderHandle, TransmittableDownload, }; pub use iroh::{Endpoint, PublicKey, SecretKey}; use iroh_relay::{RelayMap, RelayNode, RelayQuicConfig}; From 72e36fa481e6b6f63e571a3d8a4e07dbe55bdc8e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 18 Nov 2025 14:05:38 -0300 Subject: [PATCH 5/5] Remove unused function and update names --- shared/client/src/client.rs | 54 +++++++++----------------- shared/network/src/download_manager.rs | 6 +-- 2 files changed, 21 insertions(+), 39 deletions(-) diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index f54c96be1..38fff1fde 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -5,7 +5,6 @@ use crate::{ }; use anyhow::anyhow; use anyhow::{Error, Result, bail}; -use iroh::protocol::Router; use psyche_coordinator::{Commitment, CommitteeSelection, Coordinator, RunState}; use psyche_core::NodeIdentity; use psyche_metrics::{ClientMetrics, ClientRoleInRound, PeerConnection}; @@ -121,7 +120,7 @@ impl + 'sta tx_broadcast_finished, }); - let retried_downloads = ParameterDownloaderHandle::new( + let parameter_downloader_handle = ParameterDownloaderHandle::new( tx_params_download.clone(), max_concurrent_parameter_requests, ); @@ -274,8 +273,8 @@ impl + 'sta }) => { let _ = trace_span!("NetworkEvent::DownloadComplete", hash = %hash).entered(); metrics.record_download_completed(hash, from); - retried_downloads.download_succeeded(); - if retried_downloads.remove(hash).await.is_some() { + parameter_downloader_handle.download_succeeded(); + if parameter_downloader_handle.remove(hash).await.is_some() { info!("Successfully downloaded previously failed blob {}", hex::encode(hash)); } match download_data { @@ -307,7 +306,7 @@ impl + 'sta NetworkEvent::DownloadFailed(dl) => { let _ = trace_span!("NetworkEvent::DownloadFailed", error=%dl.error).entered(); let hash = dl.blob_ticket.hash(); - let retries = retried_downloads.get(hash).await.map(|i| i.retries).unwrap_or(0); + let retries = parameter_downloader_handle.get(hash).await.map(|i| i.retries).unwrap_or(0); let download_type_clone = dl.download_type.clone(); match dl.download_type { @@ -327,18 +326,18 @@ impl + 'sta ); let router = p2p.router().clone(); let peer_manager = peer_manager.clone(); - let retried_downloads = retried_downloads.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); tokio::spawn(async move { - let blob_ticket_to_retry = if let Ok(new_blob_ticket) = get_blob_ticket_to_download(router.clone(), request_type, peer_manager.clone(), param_requests_cancel_token).await { + let blob_ticket_to_retry = if let Ok((new_blob_ticket, _)) = blob_ticket_param_request_task(request_type, router.clone(), peer_manager.clone(), param_requests_cancel_token).await { // We remove the old hash because we're getting the blob from a new peer that has its own version of the model parameter or config blob - retried_downloads.remove(hash).await; + parameter_downloader_handle.remove(hash).await; new_blob_ticket } else { dl.blob_ticket }; - retried_downloads.insert(DownloadRetryInfo { + parameter_downloader_handle.insert(DownloadRetryInfo { retries: retries + 1, retry_time, tag: dl.tag, @@ -351,7 +350,7 @@ impl + 'sta if retries >= MAX_DOWNLOAD_RETRIES { metrics.record_download_perma_failed(); warn!("Distro result download failed (not retrying): {}", dl.error); - retried_downloads.remove(hash).await; + parameter_downloader_handle.remove(hash).await; } else { metrics.record_download_failed(); let backoff_duration = DOWNLOAD_RETRY_BACKOFF_BASE.mul_f32(2_f32.powi(retries as i32)); @@ -362,7 +361,7 @@ impl + 'sta backoff_duration, dl.error ); - retried_downloads.insert(DownloadRetryInfo { + parameter_downloader_handle.insert(DownloadRetryInfo { retries: retries + 1, retry_time, tag: dl.tag, @@ -484,12 +483,12 @@ impl + 'sta let tx_params_download = tx_params_download.clone(); let tx_config_download = tx_config_download.clone(); let metrics = metrics.clone(); - let retried_downloads = retried_downloads.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); tokio::spawn(async move { - let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = retried_downloads.pending_retries().await; + let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = parameter_downloader_handle.pending_retries().await; for (hash, ticket, tag, download_type) in pending_retries { - let retries = retried_downloads.update_time(hash).await; + let retries = parameter_downloader_handle.update_time(hash).await; metrics.record_download_retry(hash); // We check the type of the failed download and send it to the appropriate channel to retry it @@ -557,12 +556,12 @@ impl + 'sta // is no chance of mutex poisoning; locks are acquired only to insert or remove items from them // and dropped immediately let peer_manager = peer_manager.clone(); - let retried_downloads = retried_downloads.clone(); + let parameter_downloader_handle = parameter_downloader_handle.clone(); tokio::spawn(async move { for param_name in param_names { - info!("Waiting for capacity"); - retried_downloads.wait_for_capacity().await; + info!("Limit of current parameters downloads reached, waiting for capacity"); + parameter_downloader_handle.wait_for_capacity().await; let router = router.clone(); @@ -573,7 +572,7 @@ impl + 'sta param_requests_cancel_token.clone() ).await.unwrap(); - retried_downloads.add_parameter(result.0, result.1); + parameter_downloader_handle.add_parameter(result.0, result.1); } }); }, @@ -596,7 +595,7 @@ impl + 'sta let tx_config_download = tx_config_download.clone(); let param_requests_cancel_token = param_requests_cancel_token.clone(); tokio::spawn(async move { - if let Ok(config_blob_ticket) = get_blob_ticket_to_download(router.clone(), ModelRequestType::Config, peer_manager, param_requests_cancel_token).await { + if let Ok((config_blob_ticket, _)) = blob_ticket_param_request_task(ModelRequestType::Config, router.clone(), peer_manager, param_requests_cancel_token).await { tx_config_download.send(config_blob_ticket).expect("Failed to send config blob ticket"); } else { error!("Error getting the config blob ticket, we'll not proceed with the download"); @@ -771,20 +770,3 @@ fn all_node_ids_shuffled(state: &Coordinator) -> Vec addrs.shuffle(&mut rand::rng()); addrs } - -async fn get_blob_ticket_to_download( - router: Arc, - request_type: ModelRequestType, - peer_manager: Arc, - cancellation_token: CancellationToken, -) -> Result { - let result = blob_ticket_param_request_task( - request_type.clone(), - router, - peer_manager, - cancellation_token.clone(), - ) - .await?; - - Ok(result.0) -} diff --git a/shared/network/src/download_manager.rs b/shared/network/src/download_manager.rs index 2fe7fb54d..2461c1709 100644 --- a/shared/network/src/download_manager.rs +++ b/shared/network/src/download_manager.rs @@ -102,9 +102,9 @@ impl ParameterDownloaderHandle { pub async fn wait_for_capacity(&self) { let (tx, rx) = oneshot::channel(); - self.tx - .send(ParameterDownloaderMessage::WaitForCapacity { response: tx }) - .unwrap(); + let _ = self + .tx + .send(ParameterDownloaderMessage::WaitForCapacity { response: tx }); let _ = rx.await; }