Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 29 additions & 84 deletions shared/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ use crate::{
};
use anyhow::anyhow;
use anyhow::{Error, Result, bail};
use futures::future::join_all;
use iroh::protocol::Router;
use psyche_coordinator::{Commitment, CommitteeSelection, Coordinator, RunState};
use psyche_core::NodeIdentity;
use psyche_metrics::{ClientMetrics, ClientRoleInRound, PeerConnection};
use psyche_network::{
AuthenticatableIdentity, BlobTicket, DownloadComplete, DownloadRetryInfo, DownloadType,
MAX_DOWNLOAD_RETRIES, ModelRequestType, NetworkEvent, NetworkTUIState, NodeId,
PeerManagerHandle, RetriedDownloadsHandle, SharableModel, TransmittableDownload, allowlist,
ParameterDownloaderHandle, PeerManagerHandle, SharableModel, TransmittableDownload, allowlist,
blob_ticket_param_request_task, raw_p2p_verify,
};
use psyche_watcher::{Backend, BackendWatcher};
Expand Down Expand Up @@ -122,7 +120,10 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
tx_broadcast_finished,
});

let retried_downloads = RetriedDownloadsHandle::new();
let parameter_downloader_handle = ParameterDownloaderHandle::new(
tx_params_download.clone(),
max_concurrent_parameter_requests,
);
let mut sharable_model = SharableModel::empty();
let peer_manager = Arc::new(PeerManagerHandle::new(
MAX_ERRORS_PER_PEER,
Expand Down Expand Up @@ -272,7 +273,8 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
}) => {
let _ = trace_span!("NetworkEvent::DownloadComplete", hash = %hash).entered();
metrics.record_download_completed(hash, from);
if retried_downloads.remove(hash).await.is_some() {
parameter_downloader_handle.download_succeeded();
if parameter_downloader_handle.remove(hash).await.is_some() {
info!("Successfully downloaded previously failed blob {}", hex::encode(hash));
}
match download_data {
Expand Down Expand Up @@ -304,7 +306,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
NetworkEvent::DownloadFailed(dl) => {
let _ = trace_span!("NetworkEvent::DownloadFailed", error=%dl.error).entered();
let hash = dl.blob_ticket.hash();
let retries = retried_downloads.get(hash).await.map(|i| i.retries).unwrap_or(0);
let retries = parameter_downloader_handle.get(hash).await.map(|i| i.retries).unwrap_or(0);
let download_type_clone = dl.download_type.clone();

match dl.download_type {
Expand All @@ -324,18 +326,18 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
);
let router = p2p.router().clone();
let peer_manager = peer_manager.clone();
let retried_downloads = retried_downloads.clone();
let parameter_downloader_handle = parameter_downloader_handle.clone();
let param_requests_cancel_token = param_requests_cancel_token.clone();
tokio::spawn(async move {
let blob_ticket_to_retry = if let Ok(new_blob_ticket) = get_blob_ticket_to_download(router.clone(), request_type, peer_manager.clone(), param_requests_cancel_token).await {
let blob_ticket_to_retry = if let Ok((new_blob_ticket, _)) = blob_ticket_param_request_task(request_type, router.clone(), peer_manager.clone(), param_requests_cancel_token).await {
// We remove the old hash because we're getting the blob from a new peer that has its own version of the model parameter or config blob
retried_downloads.remove(hash).await;
parameter_downloader_handle.remove(hash).await;
new_blob_ticket
} else {
dl.blob_ticket
};

retried_downloads.insert(DownloadRetryInfo {
parameter_downloader_handle.insert(DownloadRetryInfo {
retries: retries + 1,
retry_time,
tag: dl.tag,
Expand All @@ -348,7 +350,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
if retries >= MAX_DOWNLOAD_RETRIES {
metrics.record_download_perma_failed();
warn!("Distro result download failed (not retrying): {}", dl.error);
retried_downloads.remove(hash).await;
parameter_downloader_handle.remove(hash).await;
} else {
metrics.record_download_failed();
let backoff_duration = DOWNLOAD_RETRY_BACKOFF_BASE.mul_f32(2_f32.powi(retries as i32));
Expand All @@ -359,7 +361,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
backoff_duration,
dl.error
);
retried_downloads.insert(DownloadRetryInfo {
parameter_downloader_handle.insert(DownloadRetryInfo {
retries: retries + 1,
retry_time,
tag: dl.tag,
Expand Down Expand Up @@ -481,12 +483,12 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
let tx_params_download = tx_params_download.clone();
let tx_config_download = tx_config_download.clone();
let metrics = metrics.clone();
let retried_downloads = retried_downloads.clone();
let parameter_downloader_handle = parameter_downloader_handle.clone();
tokio::spawn(async move {
let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = retried_downloads.pending_retries().await;
let pending_retries: Vec<(psyche_network::Hash, BlobTicket, Tag, DownloadType)> = parameter_downloader_handle.pending_retries().await;

for (hash, ticket, tag, download_type) in pending_retries {
let retries = retried_downloads.update_time(hash).await;
let retries = parameter_downloader_handle.update_time(hash).await;

metrics.record_download_retry(hash);
// We check the type of the failed download and send it to the appropriate channel to retry it
Expand All @@ -499,7 +501,7 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
match inner {
ModelRequestType::Parameter(parameter) => {
info!("Retrying download for model parameter: {parameter}, (attempt {})", retries);
let _ = tx_params_download.send(vec![(ticket, ModelRequestType::Parameter(parameter.clone()))]);
let _ = tx_params_download.send((ticket, ModelRequestType::Parameter(parameter.clone())));
},
ModelRequestType::Config => {
info!("Retrying download for model config, (attempt {})", retries);
Expand Down Expand Up @@ -546,60 +548,33 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
total_parameters = Some(param_names.len());
sharable_model.initialize_parameters(&param_names, tx_params_response);

let tx_params_download = tx_params_download.clone();
let router = p2p.router();

let peer_manager = peer_manager.clone();
let param_requests_cancel_token = param_requests_cancel_token.clone();
let handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
// We use std mutex implementation here and call `.unwrap()` when acquiring the lock since there
// is no chance of mutex poisoning; locks are acquired only to insert or remove items from them
// and dropped immediately
let parameter_blob_tickets = Arc::new(std::sync::Mutex::new(Vec::new()));
let mut request_handles = Vec::new();
let peer_manager = peer_manager.clone();
let parameter_downloader_handle = parameter_downloader_handle.clone();

tokio::spawn(async move {
for param_name in param_names {
info!("Limit of current parameters downloads reached, waiting for capacity");
parameter_downloader_handle.wait_for_capacity().await;

let router = router.clone();

let request_handle = tokio::spawn(
blob_ticket_param_request_task(
let result = blob_ticket_param_request_task(
ModelRequestType::Parameter(param_name),
router,
parameter_blob_tickets.clone(),
peer_manager.clone(),
param_requests_cancel_token.clone()
)
);
).await.unwrap();

// Check whether we have reached the maximum number of concurrent requests; if so,
// wait for all of them to complete and then start downloading the blobs
if request_handles.len() == max_concurrent_parameter_requests - 1 {
let mut max_concurrent_request_futures = std::mem::take(&mut request_handles);
max_concurrent_request_futures.push(request_handle);
// We don't care about the errors because we are already handling them inside the task
join_all(max_concurrent_request_futures).await;
let current_parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = {
let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap();
parameter_blob_tickets_lock.drain(..).collect()
};
tx_params_download.send(current_parameter_blob_tickets)?;
continue;
}
request_handles.push(request_handle);
parameter_downloader_handle.add_parameter(result.0, result.1);
}

// All parameters have been requested; wait for all the remaining request futures to complete
// and then download the blobs
join_all(request_handles).await;
let parameter_blob_tickets: Vec<(BlobTicket, ModelRequestType)> = {
let mut parameter_blob_tickets_lock = parameter_blob_tickets.lock().unwrap();
parameter_blob_tickets_lock.drain(..).collect()
};
tx_params_download.send(parameter_blob_tickets)?;
Ok(())
});
drop(handle);
},
Some(tx_model_config_response) = rx_request_model_config.recv() => {
sharable_model.tx_model_config_response = Some(tx_model_config_response);
Expand All @@ -620,21 +595,21 @@ impl<T: NodeIdentity, A: AuthenticatableIdentity + 'static, B: Backend<T> + 'sta
let tx_config_download = tx_config_download.clone();
let param_requests_cancel_token = param_requests_cancel_token.clone();
tokio::spawn(async move {
if let Ok(config_blob_ticket) = get_blob_ticket_to_download(router.clone(), ModelRequestType::Config, peer_manager, param_requests_cancel_token).await {
if let Ok((config_blob_ticket, _)) = blob_ticket_param_request_task(ModelRequestType::Config, router.clone(), peer_manager, param_requests_cancel_token).await {
tx_config_download.send(config_blob_ticket).expect("Failed to send config blob ticket");
} else {
error!("Error getting the config blob ticket, we'll not proceed with the download");
}
});
}
// Modify the params download handler:
Some(param_blob_tickets) = rx_params_download.recv() => {
for (ticket, request_type) in param_blob_tickets {
let (ticket, request_type) = param_blob_tickets;
let kind = DownloadType::ModelSharing(request_type.clone());
metrics.record_download_started(ticket.hash(), kind.kind());
if let ModelRequestType::Parameter(parameter_name) = request_type {
p2p.start_download(ticket, Tag::from(format!("model-{}", parameter_name)), kind);
}
}
}
Some(config_blob_ticket) = rx_config_download.recv() => {
let kind = DownloadType::ModelSharing(ModelRequestType::Config);
Expand Down Expand Up @@ -795,33 +770,3 @@ fn all_node_ids_shuffled<T: NodeIdentity>(state: &Coordinator<T>) -> Vec<NodeId>
addrs.shuffle(&mut rand::rng());
addrs
}

/// Obtains a single blob ticket for `request_type`.
///
/// Awaits `blob_ticket_param_request_task` with a freshly created, shared,
/// single-slot vector and then returns a clone of the ticket held by the
/// first entry of that vector.
///
/// # Errors
///
/// Returns an error when the vector is still empty after the request task
/// has finished, i.e. no blob ticket was produced for `request_type`.
///
/// # Panics
///
/// Panics if the mutex guarding the shared vector is poisoned.
async fn get_blob_ticket_to_download(
    router: Arc<Router>,
    request_type: ModelRequestType,
    peer_manager: Arc<PeerManagerHandle>,
    cancellation_token: CancellationToken,
) -> Result<BlobTicket, anyhow::Error> {
    // Out-slot for the request task; capacity 1 because one request is issued here.
    let blob_ticket = Arc::new(std::sync::Mutex::new(Vec::with_capacity(1)));

    blob_ticket_param_request_task(
        // Cloned because `request_type` is still needed for the error message below.
        request_type.clone(),
        router,
        blob_ticket.clone(),
        peer_manager,
        // Last use of the token in this function, so it is moved instead of cloned.
        cancellation_token,
    )
    .await;

    // Keep the guard in its own scope so the lock is released before returning.
    // NOTE(review): `unwrap` assumes the request task never panics while holding
    // this lock — confirm against `blob_ticket_param_request_task`.
    let first_ticket = {
        let blob_ticket_lock = blob_ticket.lock().unwrap();
        blob_ticket_lock.first().map(|entry| entry.0.clone())
    };

    // `ok_or_else` builds (and formats) the error only when no ticket was found.
    first_ticket.ok_or_else(|| {
        anyhow::anyhow!("No blob ticket found trying to download {request_type:?}")
    })
}
Loading