From 10e81f3407e6943a83377c397fec414bf6a6ece4 Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Fri, 9 Jan 2026 18:45:10 -0300
Subject: [PATCH 1/6] Add GCS checkpoint variant

---
 architectures/centralized/client/src/app.rs   |  23 +--
 architectures/centralized/server/src/app.rs   |   2 +-
 .../centralized/shared/src/protocol.rs        |   2 +-
 .../solana-client/src/backend.rs              |   5 +-
 .../solana-client/src/command/checkpoint.rs   |   3 +-
 .../solana-client/src/instructions.rs         |   2 +-
 .../solana-coordinator/src/instance_state.rs  |   8 +-
 .../programs/solana-coordinator/src/lib.rs    |   3 +-
 python/default.nix                            |   2 +-
 shared/client/src/cli.rs                      |  60 +++++++-
 shared/client/src/lib.rs                      |   3 +-
 shared/client/src/state/cooldown.rs           | 138 +++++++++++-------
 shared/client/src/state/init.rs               |   2 +-
 shared/client/src/state/mod.rs                |   5 +-
 shared/client/src/state/types.rs              |  14 +-
 shared/coordinator/src/coordinator.rs         |  15 +-
 shared/data-provider/src/hub.rs               |  54 +++++++
 shared/data-provider/src/lib.rs               |   1 +
 shared/watcher/src/traits.rs                  |   2 +-
 19 files changed, 258 insertions(+), 86 deletions(-)

diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index 3560e0a7a..e3eb19019 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -2,8 +2,10 @@ use anyhow::{Error, Result};
 use bytemuck::Zeroable;
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
+use psyche_client::UploadInfo;
 use psyche_client::{
-    Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
+    Client, ClientTUI, ClientTUIState, HubUploadInfo, NC, RunInitConfig, TrainArgs,
+    read_identity_secret_key,
 };
 use psyche_coordinator::{Coordinator, HealthChecks, model};
 use psyche_metrics::ClientMetrics;
@@ -29,7 +31,7 @@ pub type TabsData = ::Data;
 pub enum ToSend {
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }

 struct Backend {
@@ -67,7 +69,7 @@ impl WatcherBackend for Backend {
         Ok(())
     }

-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.tx.send(ToSend::Checkpoint(checkpoint))?;
         Ok(())
     }
@@ -173,18 +175,19 @@ impl App {
     ) -> Result<()> {
         // sanity checks
         if let Some(checkpoint_config) = &state_options.checkpoint_config {
-            if let Some(hub_upload) = &checkpoint_config.hub_upload {
+            if let Some(UploadInfo::Hub(HubUploadInfo {
+                hub_repo,
+                hub_token,
+            })) = &checkpoint_config.hub_upload
+            {
                 let api = hf_hub::api::tokio::ApiBuilder::new()
-                    .with_token(Some(hub_upload.hub_token.clone()))
+                    .with_token(Some(hub_token.clone()))
                     .build()?;
-                let repo_api = api.repo(Repo::new(
-                    hub_upload.hub_repo.clone(),
-                    hf_hub::RepoType::Model,
-                ));
+                let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model));
                 if !repo_api.is_writable().await {
                     anyhow::bail!(
                         "Checkpoint upload repo {} is not writable with the passed API key.",
-                        hub_upload.hub_repo
+                        hub_repo
                     )
                 }
             }
diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index 6707b49bb..402146817 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -81,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend {
         bail!("Server does not send health checks");
     }

-    async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> {
         bail!("Server does not send checkpoints");
     }
 }
diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs
index c71c7524f..a96c64b8f 100644
--- a/architectures/centralized/shared/src/protocol.rs
+++ b/architectures/centralized/shared/src/protocol.rs
@@ -15,7 +15,7 @@ pub enum ClientToServerMessage {
     Join { run_id: String },
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }

 #[derive(Serialize, Deserialize, Debug, Clone)]
diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs
index 90f94e528..f948cde94 100644
--- a/architectures/decentralized/solana-client/src/backend.rs
+++ b/architectures/decentralized/solana-client/src/backend.rs
@@ -20,6 +20,7 @@ use anchor_client::{
 use anyhow::{Context, Result, anyhow};
 use futures_util::StreamExt;
 use psyche_client::IntegrationTestLogMarker;
+use psyche_coordinator::model::{self, Checkpoint};
 use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
 use psyche_watcher::{Backend as WatcherBackend, OpportunisticData};
 use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding};
@@ -333,7 +334,7 @@ impl SolanaBackend {
         &self,
         coordinator_instance: Pubkey,
         coordinator_account: Pubkey,
-        repo: HubRepo,
+        repo: Checkpoint,
     ) {
         let user = self.get_payer();
         let instruction = instructions::coordinator_checkpoint(
@@ -603,7 +604,7 @@ impl WatcherBackend for SolanaBackendRunner
         Ok(())
     }

-    async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.backend
             .send_checkpoint(self.instance, self.account, checkpoint);
         Ok(())
diff --git a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs
index 73adc2105..3f8cc9f2b 100644
--- a/architectures/decentralized/solana-client/src/command/checkpoint.rs
+++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use clap::Args;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::HubRepo;
 use psyche_core::FixedString;

@@ -45,7 +46,7 @@ pub async fn command_checkpoint_execute(
         &coordinator_instance,
         &coordinator_account,
         &user,
-        repo,
+        Checkpoint::Hub(repo.clone()),
     );
     let signature = backend
         .send_and_retry("Checkpoint", &[instruction], &[])
diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs
index 68e169008..b535d54f4 100644
--- a/architectures/decentralized/solana-client/src/instructions.rs
+++ b/architectures/decentralized/solana-client/src/instructions.rs
@@ -206,7 +206,7 @@ pub fn coordinator_checkpoint(
     coordinator_instance: &Pubkey,
     coordinator_account: &Pubkey,
     user: &Pubkey,
-    repo: psyche_coordinator::model::HubRepo,
+    repo: psyche_coordinator::model::Checkpoint,
 ) -> Instruction {
     anchor_instruction(
         psyche_solana_coordinator::ID,
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
index 029b40fad..7ea079bd0 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
@@ -10,8 +10,8 @@ use psyche_coordinator::RunState;
 use psyche_coordinator::SOLANA_MAX_STRING_LEN;
 use psyche_coordinator::TickResult;
 use psyche_coordinator::Witness;
-use psyche_coordinator::model::HubRepo;
 use psyche_coordinator::model::Model;
+use psyche_coordinator::model::{Checkpoint, HubRepo};
 use psyche_core::FixedString;
 use psyche_core::SmallBoolean;
 use psyche_core::sha256v;
@@ -389,7 +389,11 @@ impl CoordinatorInstanceState {
         self.tick()
     }

-    pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> {
+    pub fn checkpoint(
+        &mut self,
+        payer: &Pubkey,
+        repo: Checkpoint,
+    ) -> Result<()> {
         // O(n) on clients, reconsider
         let id = self.clients_state.find_signer(payer)?;
         let index = self
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
index 657424195..f7feb7981 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
@@ -165,6 +165,7 @@ impl CoordinatorInstance {
 pub mod psyche_solana_coordinator {
     use super::*;
+    use psyche_coordinator::model::Checkpoint;
     use psyche_core::FixedString;

     pub fn init_coordinator(
@@ -313,7 +314,7 @@ pub mod psyche_solana_coordinator {

     pub fn checkpoint(
         ctx: Context,
-        repo: HubRepo,
+        repo: psyche_coordinator::model::Checkpoint,
     ) -> Result<()> {
         let mut account = ctx.accounts.coordinator_account.load_mut()?;
         account.increment_nonce();
diff --git a/python/default.nix b/python/default.nix
index 0ae9670bd..196b07867 100644
--- a/python/default.nix
+++ b/python/default.nix
@@ -32,7 +32,6 @@ let
   # packages that we provide to the venv via nix derivations
   topLevelNixPkgs = [
     "torch"
-    "vllm"
   ]
   ++ lib.optionals stdenvNoCC.hostPlatform.isLinux [
     "flash-attn"
@@ -40,6 +39,7 @@ let
     # i'm really not a fan of providing torchtitan like this. i'd much rather have it be built as a git dep via uv2nix.
     # i think there's room to figure out how to provide setuptools for it.
     "torchtitan"
+    "vllm"
   ];

   nixProvidedPythonPkgs = getAllTransitiveDeps topLevelNixPkgs;
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 268ea753a..2f6f2b080 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -1,5 +1,7 @@
 use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};

+use crate::GcsUploadInfo;
+use crate::UploadInfo;
 use anyhow::{Result, anyhow, bail};
 use clap::Args;
 use psyche_eval::tasktype_from_name;
@@ -146,6 +148,14 @@ pub struct TrainArgs {
     #[clap(long, env)]
     pub hub_repo: Option,

+    /// Path to the Hugging Face repository containing model data and configuration.
+    #[clap(long, env)]
+    pub gcs_bucket: Option,
+
+    /// Path to the Hugging Face repository containing model data and configuration.
+    #[clap(long, env)]
+    pub gcs_prefix: Option,
+
     #[clap(long, env, default_value_t = 3)]
     pub hub_max_concurrent_downloads: usize,

@@ -227,37 +237,73 @@ impl TrainArgs {
         let checkpoint_upload_info = match (
             &hub_read_token,
             self.hub_repo.clone(),
+            self.gcs_bucket.clone(),
+            self.gcs_prefix.clone(),
             self.checkpoint_dir.clone(),
             self.delete_old_steps,
             self.keep_steps,
         ) {
-            (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => {
+            (_, Some(_), Some(_), _, _, _, _) => {
+                bail!("Use either GCS or HF hub for checkpoint uploads, not both.")
+            }
+            (Some(token), Some(repo), None, _, Some(dir), delete_old_steps, keep_steps) => {
                 if keep_steps == 0 {
                     bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")
                 }
                 Some(CheckpointConfig {
                     checkpoint_dir: dir,
-                    hub_upload: Some(HubUploadInfo {
+                    hub_upload: Some(UploadInfo::Hub(HubUploadInfo {
                         hub_repo: repo,
                         hub_token: token.to_string(),
-                    }),
+                    })),
+                    delete_old_steps,
+                    keep_steps,
+                })
+            }
+            (_, _, Some(gcp_bucket), Some(gcs_prefix), Some(dir), delete_old_steps, keep_steps) => {
+                if keep_steps == 0 {
+                    bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})")
+                }
+                Some(CheckpointConfig {
+                    checkpoint_dir: dir,
+                    hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo {
+                        gcs_bucket: gcp_bucket,
+                        gcs_prefix: Some(gcs_prefix),
+                    })),
+                    delete_old_steps,
+                    keep_steps,
+                })
+            }
+            (_, _, Some(gcp_bucket), None, Some(dir), delete_old_steps, keep_steps) => {
+                if keep_steps == 0 {
+                    bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})")
+                }
+                Some(CheckpointConfig {
+                    checkpoint_dir: dir,
+                    hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo {
+                        gcs_bucket: gcp_bucket,
+                        gcs_prefix: None,
+                    })),
                     delete_old_steps,
                     keep_steps,
                 })
             }
-            (None, Some(_), Some(_), _, _) => {
+            (None, Some(_), None, _, _, _, _) => {
                 bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
             }
-            (_, Some(_), None, _, _) => {
+            (_, None, Some(_), _, None, _, _) => {
+                bail!("gcs-bucket and checkpoint-dir set, but no GCS_TOKEN env variable.")
+            }
+            (_, Some(_), None, _, None, _, _) => {
                 bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
             }
-            (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
+            (_, None, None, _, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
                 checkpoint_dir: dir,
                 hub_upload: None,
                 delete_old_steps,
                 keep_steps,
             }),
-            (_, None, _, _, _) => None,
+            (_, None, None, _, _, _, _) => None,
         };

         Ok(checkpoint_upload_info)
diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs
index f1c545ed7..8a0086932 100644
--- a/shared/client/src/lib.rs
+++ b/shared/client/src/lib.rs
@@ -10,7 +10,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity
 pub use client::Client;
 pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult};
 pub use state::{
-    CheckpointConfig, HubUploadInfo, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO,
+    CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig,
+    RunInitConfigAndIO, UploadInfo,
 };
 pub use testing::IntegrationTestLogMarker;
 pub use tui::{ClientTUI, ClientTUIState};
diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs
index 6e2efb6ef..5b0a567e2 100644
--- a/shared/client/src/state/cooldown.rs
+++ b/shared/client/src/state/cooldown.rs
@@ -1,11 +1,12 @@
-use crate::HubUploadInfo;
+use crate::{HubUploadInfo, state::types::GcsUploadInfo};
+
+use crate::UploadInfo;
 use psyche_coordinator::{
     Coordinator,
-    model::{self, HubRepo},
+    model::{self, GcsRepo, HubRepo},
 };
 use psyche_core::{FixedString, NodeIdentity};
-use psyche_data_provider::{UploadModelError, upload_model_repo_async};
+use psyche_data_provider::{UploadModelError, upload_model_repo_async, upload_to_gcs_async};
 use psyche_modeling::{
     CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError,
     save_tensors_into_safetensors,
@@ -42,7 +43,7 @@ pub enum CooldownError {
 }

 pub struct CooldownStepMetadata {
-    tx_checkpoint: mpsc::UnboundedSender,
+    tx_checkpoint: mpsc::UnboundedSender,
     tx_model: mpsc::UnboundedSender>,
     checkpoint_info: Option,
     checkpoint_extra_files: Vec,
@@ -59,7 +60,7 @@ pub struct CooldownStepMetadata {

 impl CooldownStepMetadata {
     pub fn new(
-        tx_checkpoint: mpsc::UnboundedSender,
+        tx_checkpoint: mpsc::UnboundedSender,
         tx_model: mpsc::UnboundedSender>,
         checkpoint_info: Option,
         checkpoint_extra_files: Vec,
@@ -207,53 +208,90 @@ impl CooldownStepMetadata {
                     local.push(to);
                 }

-                let Some(HubUploadInfo {
-                    hub_repo,
-                    hub_token,
-                }) = hub_upload
-                else {
-                    cleanup_dirs(
-                        delete_queue,
-                        keep_steps,
-                        run_id,
-                        delete_old_steps,
-                        step,
-                        checkpoint_dir,
-                    )
-                    .await;
-                    return Ok::<(), CheckpointError>(());
-                };
-
-                info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
-                let revision = match upload_model_repo_async(
-                    hub_repo.clone(),
-                    local,
-                    hub_token.clone(),
-                    Some(format!("step {step}")),
-                    None,
-                )
-                .await
-                {
-                    Ok(revision) => {
-                        info!(
-                            repo = hub_repo,
-                            revision = revision,
-                            "Upload to HuggingFace complete"
-                        );
-                        revision
+                match hub_upload {
+                    Some(UploadInfo::Gcs(GcsUploadInfo {
+                        gcs_bucket,
+                        gcs_prefix,
+                    })) => {
+                        info!(bucket = gcs_bucket, "Uploading checkpoint to GCS");
+                        match upload_to_gcs_async(gcs_bucket.clone(), local, gcs_prefix.clone())
+                            .await
+                        {
+                            Ok(path) => {
+                                info!(
+                                    "Upload to GCS complete at gs://{}/{}",
+                                    gcs_bucket,
+                                    gcs_prefix.clone().unwrap_or_default()
+                                );
+                                path
+                            }
+                            Err(err) => {
+                                error!(bucket = gcs_bucket, "Error uploading to GCS: {err:#}");
+                                return Err(err.into());
+                            }
+                        };
+                        tx_checkpoint
+                            .send(model::Checkpoint::Gcs(GcsRepo {
+                                bucket: FixedString::from_str_truncated(&format!(
+                                    "gs://{}/{:?}",
+                                    gcs_bucket, gcs_prefix
+                                )),
+                                prefix: Some(FixedString::from_str_truncated(&format!(
+                                    "step-{step}"
+                                ))),
+                            }))
+                            .map_err(|_| CheckpointError::SendCheckpoint)?;
                     }
-                    Err(err) => {
-                        error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}");
-                        return Err(err.into());
+                    Some(UploadInfo::Hub(HubUploadInfo {
+                        hub_repo,
+                        hub_token,
+                    })) => {
+                        info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
+                        let revision = match upload_model_repo_async(
+                            hub_repo.clone(),
+                            local,
+                            hub_token.clone(),
+                            Some(format!("step {step}")),
+                            None,
+                        )
+                        .await
+                        {
+                            Ok(revision) => {
+                                info!(
+                                    repo = hub_repo,
+                                    revision = revision,
+                                    "Upload to HuggingFace complete"
+                                );
+                                revision
+                            }
+                            Err(err) => {
+                                error!(
+                                    repo = hub_repo,
+                                    "Error uploading to HuggingFace: {err:#}"
+                                );
+                                return Err(err.into());
+                            }
+                        };
+                        tx_checkpoint
+                            .send(model::Checkpoint::Hub(HubRepo {
+                                repo_id: FixedString::from_str_truncated(&hub_repo),
+                                revision: Some(FixedString::from_str_truncated(&revision)),
+                            }))
+                            .map_err(|_| CheckpointError::SendCheckpoint)?;
                     }
-                };
-
-                tx_checkpoint
-                    .send(HubRepo {
-                        repo_id: FixedString::from_str_truncated(&hub_repo),
-                        revision: Some(FixedString::from_str_truncated(&revision)),
-                    })
-                    .map_err(|_| CheckpointError::SendCheckpoint)?;
+                    None => {
+                        cleanup_dirs(
+                            delete_queue,
+                            keep_steps,
+                            run_id,
+                            delete_old_steps,
+                            step,
+                            checkpoint_dir,
+                        )
+                        .await;
+                        return Ok::<(), CheckpointError>(());
+                    }
+                }

                 // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work
                 // we'll just delete the dir after we've uploaded it
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 54073d3b2..3d396d904 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -154,7 +154,7 @@ pub struct RunInitConfigAndIO {
     pub tx_health_check: UnboundedSender>,
     pub tx_witness: UnboundedSender,
-    pub tx_checkpoint: UnboundedSender,
+    pub tx_checkpoint: UnboundedSender,
     pub tx_model: UnboundedSender>,
     pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>,
     pub tx_config: UnboundedSender<(String, String)>,
diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs
index e4d73d5b9..a148eac59 100644
--- a/shared/client/src/state/mod.rs
+++ b/shared/client/src/state/mod.rs
@@ -16,4 +16,7 @@ mod witness;
 pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO};
 pub use round_state::RoundState;
 pub use steps::{ApplyMessageOutcome, RunManager};
-pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, HubUploadInfo};
+pub use types::{
+    CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, GcsUploadInfo, HubUploadInfo,
+    UploadInfo,
+};
diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs
index 2edf22760..68e5e09f8 100644
--- a/shared/client/src/state/types.rs
+++ b/shared/client/src/state/types.rs
@@ -14,9 +14,21 @@ pub struct HubUploadInfo {
     pub hub_token: String,
 }

+#[derive(Debug, Clone)]
+pub struct GcsUploadInfo {
+    pub gcs_bucket: String,
+    pub gcs_prefix: Option,
+}
+
+#[derive(Debug, Clone)]
+pub enum UploadInfo {
+    Hub(HubUploadInfo),
+    Gcs(GcsUploadInfo),
+}
+
 #[derive(Debug, Clone)]
 pub struct CheckpointConfig {
-    pub hub_upload: Option,
+    pub hub_upload: Option,
     pub checkpoint_dir: PathBuf,
     pub delete_old_steps: bool,
     pub keep_steps: u32,
diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index 0a3a9975d..8d18f199e 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -1,6 +1,6 @@
 use crate::{
     Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof,
-    model::{Checkpoint, HubRepo, Model},
+    model::{Checkpoint, GcsRepo, HubRepo, Model},
 };

 use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh};
@@ -596,7 +596,7 @@ impl Coordinator {
         &mut self,
         from: &T,
         index: u64,
-        hub_repo: HubRepo,
+        hub_repo: Checkpoint,
     ) -> std::result::Result<(), CoordinatorError> {
         let index = index as usize;
         if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from {
             return Err(CoordinatorError::InvalidCommitteeProof);
         }
         // TODO: In the case of more than one checkpointer, this will overwrite the hub repo
         // with the last checkpointed one. We could instead have a vector of hub repos to have
         // more download options.
         match &mut self.model {
             Model::LLM(llm) => match llm.checkpoint {
-                Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo),
-                Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo),
+                Checkpoint::P2P(HubRepo { repo_id, revision }) => {
+                    llm.checkpoint = Checkpoint::P2P(HubRepo { repo_id, revision })
+                }
+                Checkpoint::Hub(HubRepo { repo_id, revision }) => {
+                    llm.checkpoint = Checkpoint::Hub(HubRepo { repo_id, revision })
+                }
+                Checkpoint::Gcs(GcsRepo { bucket, prefix }) => {
+                    llm.checkpoint = Checkpoint::Gcs(GcsRepo { bucket, prefix })
+                }
                 _ => {}
             },
         }
diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs
index ca090a759..df6938f74 100644
--- a/shared/data-provider/src/hub.rs
+++ b/shared/data-provider/src/hub.rs
@@ -1,3 +1,5 @@
+use google_cloud_storage::client::{Client, ClientConfig};
+use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType};
 use hf_hub::{
     Cache, Repo, RepoType,
     api::{
         Siblings,
         tokio::{ApiError, CommitError, UploadSource},
     },
 };
@@ -262,3 +264,55 @@ pub async fn upload_model_repo_async(
     };
     Ok(commit_info.oid)
 }
+
+pub async fn upload_to_gcs_async(
+    bucket: String,
+    files: Vec,
+    prefix: Option,
+) -> Result<(), UploadModelError> {
+    let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
+        info!("Using authenticated GCS client");
+        ClientConfig::default().with_auth().await?
+    } else {
+        info!("Using anonymous GCS client");
+        ClientConfig::default().anonymous()
+    };
+    let client = Client::new(config);
+
+    for path in files {
+        let file_name = path
+            .file_name()
+            .ok_or_else(|| UploadModelError::NotAFile(path.clone()))?
+            .to_str()
+            .ok_or_else(|| UploadModelError::InvalidFilename(path.clone()))?;
+
+        let object_name = match &prefix {
+            Some(p) => format!("{}/{}", p, file_name),
+            None => file_name.to_string(),
+        };
+
+        let data = tokio::fs::read(&path).await.unwrap();
+
+        let upload_type = UploadType::Simple(Media::new(object_name.clone()));
+        let uploaded = client
+            .upload_object(
+                &UploadObjectRequest {
+                    bucket: bucket.clone(),
+                    ..Default::default()
+                },
+                data,
+                &upload_type,
+            )
+            .await
+            .unwrap();
+
+        info!(
+            bucket = bucket,
+            object = object_name,
+            size = uploaded.size,
+            "Successfully uploaded file to GCS"
+        );
+    }
+
+    Ok(())
+}
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 7b97fe101..0c884bea3 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -19,6 +19,7 @@ pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_s
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
+    upload_to_gcs_async,
 };
 pub use local::LocalDataProvider;
 pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor};
diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs
index 1cd96a7ed..50bf317bb 100644
--- a/shared/watcher/src/traits.rs
+++ b/shared/watcher/src/traits.rs
@@ -27,5 +27,5 @@ pub trait Backend: Send + Sync {
     async fn wait_for_new_state(&mut self) -> Result>;
     async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>;
     async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>;
-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>;
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>;
 }

From e893c7c01481cf99cfff0ca74e16986f3ee69117 Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Mon, 12 Jan 2026 09:30:53 -0800
Subject: [PATCH 2/6] Fix convert call

---
 shared/client/src/cli.rs            |  4 +---
 shared/client/src/state/cooldown.rs | 17 ++++++++++-------
 shared/data-provider/src/hub.rs     |  2 +-
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 2f6f2b080..73e4f8a4c 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -291,9 +291,6 @@ impl TrainArgs {
             (None, Some(_), None, _, _, _, _) => {
                 bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
             }
-            (_, None, Some(_), _, None, _, _) => {
-                bail!("gcs-bucket and checkpoint-dir set, but no GCS_TOKEN env variable.")
-            }
             (_, Some(_), None, _, None, _, _) => {
                 bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
             }
@@ -304,6 +301,7 @@ impl TrainArgs {
                 keep_steps,
             }),
             (_, None, None, _, _, _, _) => None,
+            _ => todo!(),
         };

         Ok(checkpoint_upload_info)
diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs
index 5b0a567e2..94c5ac685 100644
--- a/shared/client/src/state/cooldown.rs
+++ b/shared/client/src/state/cooldown.rs
@@ -166,13 +166,16 @@ impl CooldownStepMetadata {
             .send(variables_clone)
             .map_err(|_| CheckpointError::SendCheckpoint)?;

-        let (variables, trainer) = if checkpoint_info.is_some() {
-            // convert from internal shape to serialized shape (e.g. torchtitan to hf)
-            tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
-                .await
-                .map_err(|_| CheckpointError::ExtractThreadCrashed)?
-        } else {
-            (variables, trainer)
+        // convert from internal shape to serialized shape (e.g. torchtitan to hf)
+        let (variables, trainer) = match trainer {
+            #[cfg(feature = "python")]
+            Trainer::PythonDistributed(_) => {
+                info!("Converting distributed trainer variables for checkpointing...");
+                tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
+                    .await
+                    .map_err(|_| CheckpointError::ExtractThreadCrashed)?
+            }
+            _ => (variables, trainer),
         };

         trainers.push(trainer);
diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs
index df6938f74..d112e661e 100644
--- a/shared/data-provider/src/hub.rs
+++ b/shared/data-provider/src/hub.rs
@@ -272,7 +272,7 @@ pub async fn upload_to_gcs_async(
 ) -> Result<(), UploadModelError> {
     let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
         info!("Using authenticated GCS client");
-        ClientConfig::default().with_auth().await?
+        ClientConfig::default().with_auth().await.unwrap()
     } else {
         info!("Using anonymous GCS client");
         ClientConfig::default().anonymous()

From df99802ac1e7de6602ab3737c182fa0b15967fd4 Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Mon, 12 Jan 2026 17:07:55 -0300
Subject: [PATCH 3/6] Refactor model download and upload

---
 architectures/centralized/client/src/app.rs   |   6 +-
 .../solana-client/src/backend.rs              |   2 +-
 .../solana-client/src/command/checkpoint.rs   |   2 +-
 .../solana-coordinator/src/instance_state.rs  |   2 +-
 .../programs/solana-coordinator/src/lib.rs    |   3 +-
 shared/client/src/cli.rs                      | 137 +++++++-------
 shared/client/src/state/cooldown.rs           | 170 ++++++------------
 shared/client/src/state/init.rs               |   8 +-
 shared/client/src/state/mod.rs                |   6 +-
 shared/client/src/state/types.rs              |  15 +-
 shared/coordinator/src/coordinator.rs         |  48 +++--
 shared/data-provider/src/errors.rs            |  47 +++++
 shared/data-provider/src/gcs.rs               | 104 +++++++++--
 shared/data-provider/src/hub.rs               | 145 +++++----------
 shared/data-provider/src/lib.rs               |  11 +-
 15 files changed, 355 insertions(+), 351 deletions(-)
 create mode 100644 shared/data-provider/src/errors.rs

diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index e3eb19019..8f6a9bcca 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -2,10 +2,10 @@ use anyhow::{Error, Result};
 use bytemuck::Zeroable;
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
+use psyche_client::HubUploadInfo;
 use psyche_client::UploadInfo;
 use psyche_client::{
-    Client, ClientTUI, ClientTUIState, HubUploadInfo, NC, RunInitConfig, TrainArgs,
-    read_identity_secret_key,
+    Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
 };
 use psyche_coordinator::{Coordinator, HealthChecks, model};
 use psyche_metrics::ClientMetrics;
@@ -178,7 +178,7 @@ impl App {
         if let Some(UploadInfo::Hub(HubUploadInfo {
             hub_repo,
             hub_token,
-        })) = &checkpoint_config.hub_upload
+        })) = &checkpoint_config.upload_info
         {
             let api = hf_hub::api::tokio::ApiBuilder::new()
                 .with_token(Some(hub_token.clone()))
diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs
index f948cde94..629ead188 100644
--- a/architectures/decentralized/solana-client/src/backend.rs
+++ b/architectures/decentralized/solana-client/src/backend.rs
@@ -21,7 +21,7 @@ use anyhow::{Context, Result, anyhow};
 use futures_util::StreamExt;
 use psyche_client::IntegrationTestLogMarker;
 use psyche_coordinator::model::{self, Checkpoint};
-use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
+use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks};
 use psyche_watcher::{Backend as WatcherBackend, OpportunisticData};
 use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding};
 use solana_transaction_status_client_types::UiTransactionEncoding;
diff --git a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs
index 3f8cc9f2b..3ee098640 100644
--- a/architectures/decentralized/solana-client/src/command/checkpoint.rs
+++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs
@@ -46,7 +46,7 @@ pub async fn command_checkpoint_execute(
         &coordinator_instance,
         &coordinator_account,
         &user,
-        Checkpoint::Hub(repo.clone()),
+        Checkpoint::Hub(repo),
     );
     let signature = backend
         .send_and_retry("Checkpoint", &[instruction], &[])
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
index 7ea079bd0..8d9ddb977 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
@@ -10,8 +10,8 @@ use psyche_coordinator::RunState;
 use psyche_coordinator::SOLANA_MAX_STRING_LEN;
 use psyche_coordinator::TickResult;
 use psyche_coordinator::Witness;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::Model;
-use psyche_coordinator::model::{Checkpoint, HubRepo};
 use psyche_core::FixedString;
 use psyche_core::SmallBoolean;
 use psyche_core::sha256v;
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
index f7feb7981..0a041a6e9 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
@@ -21,7 +21,7 @@ use psyche_coordinator::Witness;
 use psyche_coordinator::WitnessBloom;
 use psyche_coordinator::WitnessMetadata;
 use psyche_coordinator::WitnessProof;
-use psyche_coordinator::model::{HubRepo, Model};
+use psyche_coordinator::model::Model;
 use psyche_core::MerkleRoot;
 use serde::Deserialize;
 use serde::Serialize;
@@ -165,7 +165,6 @@ impl CoordinatorInstance {
 pub mod psyche_solana_coordinator {
     use super::*;
-    use psyche_coordinator::model::Checkpoint;
     use psyche_core::FixedString;

     pub fn init_coordinator(
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 73e4f8a4c..a4ef145f0 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -1,9 +1,9 @@
-use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};
+use crate::{CheckpointConfig, WandBInfo};

-use crate::GcsUploadInfo;
 use crate::UploadInfo;
 use anyhow::{Result, anyhow, bail};
 use clap::Args;
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_eval::tasktype_from_name;
 use psyche_modeling::Devices;
 use psyche_network::{DiscoveryMode, RelayKind, SecretKey};
@@ -148,11 +148,11 @@ pub struct TrainArgs {
     #[clap(long, env)]
     pub hub_repo: Option,

-    /// Path to the Hugging Face repository containing model data and configuration.
+    /// Name of the GCS bucket containing model data and configuration.
     #[clap(long, env)]
     pub gcs_bucket: Option,

-    /// Path to the Hugging Face repository containing model data and configuration.
+    /// Prefix within the GCS bucket for model data and configuration.
     #[clap(long, env)]
     pub gcs_prefix: Option,

@@ -234,77 +234,72 @@ impl TrainArgs {
     pub fn checkpoint_config(&self) -> Result> {
         let hub_read_token = std::env::var("HF_TOKEN").ok();

-        let checkpoint_upload_info = match (
-            &hub_read_token,
-            self.hub_repo.clone(),
-            self.gcs_bucket.clone(),
-            self.gcs_prefix.clone(),
-            self.checkpoint_dir.clone(),
-            self.delete_old_steps,
-            self.keep_steps,
-        ) {
-            (_, Some(_), Some(_), _, _, _, _) => {
-                bail!("Use either GCS or HF hub for checkpoint uploads, not both.")
-            }
-            (Some(token), Some(repo), None, _, Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")
-                }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(UploadInfo::Hub(HubUploadInfo {
-                        hub_repo: repo,
-                        hub_token: token.to_string(),
-                    })),
-                    delete_old_steps,
-                    keep_steps,
-                })
-            }
-            (_, _, Some(gcp_bucket), Some(gcs_prefix), Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})")
-                }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo {
-                        gcs_bucket: gcp_bucket,
-                        gcs_prefix: Some(gcs_prefix),
-                    })),
-                    delete_old_steps,
-                    keep_steps,
-                })
-            }
-            (_, _, Some(gcp_bucket), None, Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})")
+
+        if self.hub_repo.is_some() && self.gcs_bucket.is_some() {
+            bail!("Use either GCS or HF hub for checkpoint uploads, not both.");
+        }
+
+        let checkpoint_dir = match &self.checkpoint_dir {
+            Some(dir) => dir,
+            None => {
+                if self.hub_repo.is_some() || self.gcs_bucket.is_some() {
+                    bail!(
+                        "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!"
+                    );
                 }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo {
-                        gcs_bucket: gcp_bucket,
-                        gcs_prefix: None,
-                    })),
-                    delete_old_steps,
-                    keep_steps,
-                })
+                return Ok(None);
             }
-            (None, Some(_), None, _, _, _, _) => {
-                bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
-            }
-            (_, Some(_), None, _, None, _, _) => {
-                bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
-            }
-            (_, None, None, _, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
-                checkpoint_dir: dir,
-                hub_upload: None,
-                delete_old_steps,
-                keep_steps,
-            }),
-            (_, None, None, _, _, _, _) => None,
-            _ => todo!(),
         };

-        Ok(checkpoint_upload_info)
+        let upload_info = self.build_upload_info(&hub_read_token)?;
+
+        if upload_info.is_some() && self.keep_steps == 0 {
+            bail!(
+                "keep_steps must be >= 1 for checkpoint uploads (got {})",
+                self.keep_steps
+            );
+        }
+
+        Ok(Some(CheckpointConfig {
+            checkpoint_dir: checkpoint_dir.clone(),
+            upload_info,
+            delete_old_steps: self.delete_old_steps,
+            keep_steps: self.keep_steps,
+        }))
+    }
+
+    fn build_upload_info(&self, hub_token: &Option) -> Result> {
+        if let Some(repo) = &self.hub_repo {
+            return self.build_hub_upload_info(repo, hub_token);
+        }
+
+        if let Some(bucket) = &self.gcs_bucket {
+            return self.build_gcs_upload_info(bucket);
+        }
+
+        Ok(None)
+    }
+
+    fn build_hub_upload_info(
+        &self,
+        repo: &str,
+        token: &Option,
+    ) -> Result> {
+        let token = token.as_ref().ok_or_else(|| {
+            anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
+        })?;
+
+        Ok(Some(UploadInfo::Hub(HubUploadInfo {
+            hub_repo: repo.to_string(),
+            hub_token: token.to_string(),
+        })))
+    }
+
+    fn build_gcs_upload_info(&self, bucket: &str) -> Result> {
+        Ok(Some(UploadInfo::Gcs(GcsUploadInfo {
+            gcs_bucket: bucket.to_string(),
+            gcs_prefix: self.gcs_prefix.clone(),
+        })))
     }

     pub fn eval_tasks(&self) -> Result> {
diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs
index 94c5ac685..488c86d62 100644
--- a/shared/client/src/state/cooldown.rs
+++ b/shared/client/src/state/cooldown.rs
@@ -1,15 +1,12 @@
-use crate::{HubUploadInfo, state::types::GcsUploadInfo};
-
 use crate::UploadInfo;
 use psyche_coordinator::{
     Coordinator,
-    model::{self, GcsRepo, HubRepo},
+    model::{self},
 };
-use psyche_core::{FixedString, NodeIdentity};
-use psyche_data_provider::{UploadModelError, upload_model_repo_async, upload_to_gcs_async};
+use psyche_core::NodeIdentity;
+use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub};
 use psyche_modeling::{
-    CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError,
-    save_tensors_into_safetensors,
+    SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors,
 };
 use std::{
     cmp::Reverse,
@@ -94,8 +91,8 @@ pub enum CheckpointError {
     #[error("Writing extra file to disk failed: {0}")]
     WriteExtraFile(#[from] tokio::io::Error),

-    #[error("Couldn't upload model to huggingface: {0}")]
-    UploadError(#[from] UploadModelError),
+    #[error("Couldn't upload model to huggingface or GCS: {0}")]
+    UploadError(#[from] UploadError),

     #[error("Couldn't send checkpoint - channel closed")]
     SendCheckpoint,
@@ -182,127 +179,25 @@ impl CooldownStepMetadata {
         let evals = model_task_runner.start(trainers);

         let Some(CheckpointConfig {
-            hub_upload,
+            upload_info,
             checkpoint_dir,
             delete_old_steps,
             keep_steps,
         }) = checkpoint_info
         else {
-            // If there was no HF checkpointing configuration, return immediately
             return Ok((evals, None));
         };

-        // Start the upload process of the updated model parameters in a separate task
         let upload_handle = tokio::task::spawn(async move {
             let path = checkpoint_dir.join(format!("{run_id}-step{step}"));
-            info!("Saving to {}", path.display());
-            let mut local = tokio::task::spawn_blocking({
-                let path = path.clone();
-                move || save_tensors_into_safetensors(variables, path)
-            })
-            .await
-            .map_err(|_| CheckpointError::WriteThreadCrashed)??;
-
-            for extra in checkpoint_extra_files {
-                let to = path.join(extra.file_name().unwrap());
-                tokio::fs::copy(extra.clone(), to.clone())
-                    .await
-                    .map_err(CheckpointError::WriteExtraFile)?;
-                local.push(to);
-            }
+            let local =
+                save_checkpoint_locally(path, variables, checkpoint_extra_files).await?;

-            match hub_upload {
-                Some(UploadInfo::Gcs(GcsUploadInfo {
-                    gcs_bucket,
-                    gcs_prefix,
-                })) => {
-                    info!(bucket = gcs_bucket, "Uploading checkpoint to GCS");
-                    match upload_to_gcs_async(gcs_bucket.clone(), local, gcs_prefix.clone())
-                        .await
-                    {
-                        Ok(path) => {
-                            info!(
-                                "Upload to GCS complete at gs://{}/{}",
-                                gcs_bucket,
-                                gcs_prefix.clone().unwrap_or_default()
-                            );
-                            path
-                        }
-                        Err(err) => {
-                            error!(bucket = gcs_bucket, "Error uploading to GCS: {err:#}");
-                            return Err(err.into());
-                        }
-                    };
-                    tx_checkpoint
-                        .send(model::Checkpoint::Gcs(GcsRepo {
-                            bucket: FixedString::from_str_truncated(&format!(
-                                "gs://{}/{:?}",
-                                gcs_bucket, gcs_prefix
-                            )),
-                            prefix: Some(FixedString::from_str_truncated(&format!(
-                                "step-{step}"
-                            ))),
-                        }))
-                        .map_err(|_| CheckpointError::SendCheckpoint)?;
-                }
-                Some(UploadInfo::Hub(HubUploadInfo {
-                    hub_repo,
-                    hub_token,
-                })) => {
-                    info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
-                    let revision = match upload_model_repo_async(
-                        hub_repo.clone(),
-                        local,
-                        hub_token.clone(),
-                        Some(format!("step {step}")),
-                        None,
-                    )
-                    .await
-                    {
-                        Ok(revision) => {
-                            info!(
-                                repo = hub_repo,
-                                revision = revision,
-                                "Upload to HuggingFace complete"
-                            );
-                            revision
-                        }
-                        Err(err) => {
-                            error!(
-                                repo = hub_repo,
-                                "Error uploading to HuggingFace: {err:#}"
-                            );
-                            return Err(err.into());
-                        }
-                    };
-                    tx_checkpoint
-                        .send(model::Checkpoint::Hub(HubRepo {
-                            repo_id: FixedString::from_str_truncated(&hub_repo),
-                            revision: Some(FixedString::from_str_truncated(&revision)),
-                        }))
-                        .map_err(|_| CheckpointError::SendCheckpoint)?;
-                }
-                None => {
-                    cleanup_dirs(
-                        delete_queue,
-                        keep_steps,
-                        run_id,
-                        delete_old_steps,
-                        step,
-                        checkpoint_dir,
-                    )
-                    .await;
-                    return Ok::<(), CheckpointError>(());
-                }
-            }
+            if let Some(upload_info) = upload_info {
+                upload_checkpoint(upload_info, local.clone(), step as u64, tx_checkpoint)
+                    .await?;
+            }

             cleanup_dirs(
                 delete_queue,
                 keep_steps,
                 run_id,
                 delete_old_steps,
                 step,
                 checkpoint_dir,
             )
             .await;
             }
             .instrument(info_span!("checkpointing")),
         );
+
         Ok(CooldownStep {
             checkpointing_and_evals,
         })
     }
 }

+async fn save_checkpoint_locally(
+    path: PathBuf,
+    variables: HashMap,
+    checkpoint_extra_files: Vec,
+) -> Result, CheckpointError> {
+    info!("Saving to {}", path.display());
+    let mut local = tokio::task::spawn_blocking({
+        let path = path.clone();
+        move || save_tensors_into_safetensors(variables, path)
+    })
+    .await
+    .map_err(|_| CheckpointError::WriteThreadCrashed)??;
+
+    for extra in checkpoint_extra_files {
+        let to = path.join(extra.file_name().unwrap());
+        tokio::fs::copy(extra.clone(), to.clone())
+            .await
+            .map_err(CheckpointError::WriteExtraFile)?;
+        local.push(to);
+    }
+
+    Ok(local)
+}
+
+async fn upload_checkpoint(
+    upload_info: UploadInfo,
+    local: Vec,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender,
+) -> Result<(), CheckpointError> {
+    match upload_info {
+        UploadInfo::Gcs(gcs_info) => upload_to_gcs(gcs_info, local, step, tx_checkpoint)
+            .await
+            .map_err(CheckpointError::UploadError),
+        UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint)
+            .await
+            .map_err(CheckpointError::UploadError),
+    }
+}
+
 type CheckpointAndEvalsHandle = JoinHandle<
     Result<
         (
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 3d396d904..36e19db90 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -5,9 +5,9 @@ use psyche_coordinator::{
 };
 use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize};
 use psyche_data_provider::{
-    DataProvider, DataProviderTcpClient, DummyDataProvider, GcsError, PreprocessedDataProvider,
-    Split, WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async,
-    download_model_repo_async,
+    DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider,
+    PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async,
+    download_model_from_gcs_async, download_model_repo_async,
     http::{FileURLs, HttpDataProvider},
 };
 use psyche_metrics::ClientMetrics;
@@ -91,7 +91,7 @@ pub enum InitRunError {
     HfModelLoad(#[from] hf_hub::api::tokio::ApiError),

     #[error("failed to download model from GCS: {0}")]
-    GcsModelLoad(#[from] GcsError),
+    GcsModelLoad(#[from] DownloadError),

     #[error("model loading thread crashed")]
     ModelLoadingThreadCrashed(JoinError),
diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs
index a148eac59..78e6cd1eb 100644
--- a/shared/client/src/state/mod.rs
+++ b/shared/client/src/state/mod.rs
@@ -14,9 +14,7 @@ mod warmup;
 mod witness;

 pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO};
+pub use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 pub use round_state::RoundState;
 pub use steps::{ApplyMessageOutcome, RunManager};
-pub use types::{
-    CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, GcsUploadInfo, HubUploadInfo,
-    UploadInfo,
-};
+pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, UploadInfo};
diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs
index 68e5e09f8..29734f1a0 100644
--- a/shared/client/src/state/types.rs
+++ b/shared/client/src/state/types.rs
@@ -2,24 +2,13 @@ use std::path::PathBuf;

 use psyche_coordinator::CommitteeProof;
 use psyche_core::{BatchId, MerkleRoot, NodeIdentity};
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_modeling::DistroResult;
 use psyche_network::{BlobTicket, TransmittableDistroResult};
 use tch::TchError;
 use thiserror::Error;
 use tokio::task::JoinHandle;

-#[derive(Debug, Clone)]
-pub struct HubUploadInfo {
-    pub hub_repo: String,
-    pub hub_token: String,
-}
-
-#[derive(Debug, Clone)]
-pub struct GcsUploadInfo {
-    pub gcs_bucket: String,
-    pub gcs_prefix: Option,
-}
-
 #[derive(Debug, Clone)]
 pub enum UploadInfo {
     Hub(HubUploadInfo),
     Gcs(GcsUploadInfo),
 }

 #[derive(Debug, Clone)]
 pub struct CheckpointConfig {
-    pub hub_upload: Option,
+    pub upload_info: Option,
     pub checkpoint_dir: PathBuf,
     pub delete_old_steps: bool,
     pub keep_steps: u32,
diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index 8d18f199e..5b10bfcbd 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -1,6 +1,6 @@
 use crate::{
     Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof,
-    model::{Checkpoint, GcsRepo, HubRepo, Model},
+    model::{Checkpoint, Model},
 };

 use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh};
@@ -596,29 +596,43 @@ impl Coordinator {
         &mut self,
         from: &T,
         index: u64,
-        hub_repo: Checkpoint,
+        checkpoint_repo: Checkpoint,
     ) -> std::result::Result<(), CoordinatorError> {
         let index = index as usize;
         if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from {
             return Err(CoordinatorError::InvalidCommitteeProof);
         }
-        // TODO: In the case of more than one checkpointer, this will overwrite the hub repo
-        // with the last checkpointed one. We could instead have a vector of hub repos to have
+
+        // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint
+        // with the last checkpointed one. We could instead have a vector of checkpoints to have
         // more download options.
-        match &mut self.model {
-            Model::LLM(llm) => match llm.checkpoint {
-                Checkpoint::P2P(HubRepo { repo_id, revision }) => {
-                    llm.checkpoint = Checkpoint::P2P(HubRepo { repo_id, revision })
-                }
-                Checkpoint::Hub(HubRepo { repo_id, revision }) => {
-                    llm.checkpoint = Checkpoint::Hub(HubRepo { repo_id, revision })
-                }
-                Checkpoint::Gcs(GcsRepo { bucket, prefix }) => {
-                    llm.checkpoint = Checkpoint::Gcs(GcsRepo { bucket, prefix })
-                }
-                _ => {}
-            },
+        let Model::LLM(llm) = &mut self.model;
+        match (&llm.checkpoint, checkpoint_repo) {
+            // If current is a P2P variant, keep it P2P, tracking the backing
+            // store (Hub or GCS) of the new checkpoint
+            (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // If current is Hub, only accept Hub updates
+            (Checkpoint::Hub(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::Hub(hub_repo);
+            }
+            // If current is Gcs, only accept Gcs updates
+            (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::Gcs(gcs_repo);
+            }
+            // P2P variants may also switch backing stores to follow the update
+            (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // Ignore other combinations
+            _ => {}
         }
+
         Ok(())
     }
diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs
new file mode 100644
index 000000000..20e601b25
--- /dev/null
+++ b/shared/data-provider/src/errors.rs
@@ -0,0 +1,47 @@
+use std::path::PathBuf;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum UploadError {
+    #[error("path {0} is not a file")]
+    NotAFile(PathBuf),
+
+    #[error("file {0} doesn't have a valid utf-8 representation")]
+    InvalidFilename(PathBuf),
+
+    #[error("failed to send checkpoint notification")]
+    SendCheckpoint,
+
+    // Hub-specific errors
+    #[error("failed to connect to HF hub: {0}")]
+    HfHub(#[from] hf_hub::api::tokio::ApiError),
+
+    #[error("failed to commit files: {0}")]
Commit(#[from] hf_hub::api::tokio::CommitError), + + // GCS-specific errors + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + // Common errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +#[derive(Error, Debug)] +pub enum DownloadError { + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index cb6848c1d..adf67a5fb 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,22 +1,22 @@ +use crate::errors::{DownloadError, UploadError}; use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::Media; +use google_cloud_storage::http::objects::upload::UploadObjectRequest; +use google_cloud_storage::http::objects::upload::UploadType; use google_cloud_storage::http::objects::{ download::Range, get::GetObjectRequest, list::ListObjectsRequest, }; +use psyche_coordinator::model::{self, GcsRepo}; +use psyche_core::FixedString; use std::path::PathBuf; -use thiserror::Error; use tokio::runtime::Runtime; +use tokio::sync::mpsc; use tracing::info; -#[derive(Debug, Error)] -pub enum GcsError { - #[error("GCS authentication failed: {0}")] - Auth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), - - #[error("GCS operation failed: {0}")] - Storage(#[from] google_cloud_storage::http::Error), - - #[error("IO error: {0}")] - Io(#[from] std::io::Error), +#[derive(Debug, Clone)] +pub struct GcsUploadInfo { + pub gcs_bucket: String, + pub gcs_prefix: Option, } const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -42,7 +42,7 @@ fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf { pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { +) -> Result, DownloadError> { // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { info!("Using authenticated GCS client"); @@ -132,7 +132,83 @@ pub async fn download_model_from_gcs_async( pub fn download_model_from_gcs_sync( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { - let rt = Runtime::new().map_err(GcsError::Io)?; +) -> Result, DownloadError> { + let rt = Runtime::new().map_err(DownloadError::Io)?; rt.block_on(download_model_from_gcs_async(bucket, prefix)) } + +pub async fn upload_to_gcs( + gcs_info: GcsUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let GcsUploadInfo { + gcs_bucket, + gcs_prefix, + } = gcs_info; + + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); + + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client"); + ClientConfig::default().with_auth().await? 
+ } else { + info!("Using anonymous GCS client"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + for path in local { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + + let object_name = match &gcs_prefix { + Some(p) => format!("{}/{}", p, file_name), + None => file_name.to_string(), + }; + + let data = tokio::fs::read(&path).await?; + + let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let uploaded = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + data, + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = object_name, + size = uploaded.size, + "Successfully uploaded file to GCS" + ); + } + + info!( + "Upload to GCS complete at gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + ); + + tx_checkpoint + .send(model::Checkpoint::Gcs(GcsRepo { + bucket: FixedString::from_str_truncated(&format!( + "gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + )), + prefix: Some(FixedString::from_str_truncated(&format!("step-{step}"))), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) +} diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index d112e661e..13a575b84 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,14 +1,16 @@ -use google_cloud_storage::client::{Client, ClientConfig}; -use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType}; +use crate::errors::UploadError; +use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, api::{ Siblings, - tokio::{ApiError, CommitError, UploadSource}, + tokio::{ApiError, UploadSource}, }, }; +use psyche_coordinator::model; +use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use thiserror::Error; +use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -191,42 +193,39 @@ pub fn download_dataset_repo_sync( ) } -#[derive(Error, Debug)] -pub enum UploadModelError { - #[error("path {0} is not a file")] - NotAFile(PathBuf), - - #[error("file {0} doesn't have a valid utf-8 representation")] - InvalidFilename(PathBuf), +#[derive(Debug, Clone)] +pub struct HubUploadInfo { + pub hub_repo: String, + pub hub_token: String, +} - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] ApiError), +pub async fn upload_to_hub( + hub_info: HubUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let HubUploadInfo { + hub_repo, + hub_token, + } = hub_info; - #[error("failed to commit files: {0}")] - Commit(#[from] CommitError), -} + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); -pub async fn upload_model_repo_async( - repo_id: String, - files: Vec, - token: String, - commit_message: Option, - commit_description: Option, -) -> Result { let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(token)) + .with_token(Some(hub_token.clone())) .build()?; - let repo = Repo::model(repo_id.clone()); + let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result, _> = files + let files: Result, _> = local .into_iter() .map(|path| { path.file_name() - .ok_or(UploadModelError::NotAFile(path.clone())) + .ok_or(UploadError::NotAFile(path.clone())) .and_then(|name| { 
name.to_str() - .ok_or(UploadModelError::InvalidFilename(path.clone())) + .ok_or(UploadError::InvalidFilename(path.clone())) .map(|s| s.to_string()) }) .map(|name| (path.into(), name)) @@ -235,84 +234,32 @@ pub async fn upload_model_repo_async( let files = files?; - let commit_info = match api_repo - .upload_files( - files, - commit_message.clone(), - commit_description.clone(), - false, - ) + let commit_info = api_repo + .upload_files(files, Some(format!("step {step}")), None, false) .await - { - Ok(info) => { - info!( - repo = repo_id, - oid = info.oid, - "Successfully uploaded files to HuggingFace" - ); - info - } - Err(e) => { + .map_err(|e| { error!( - repo = repo_id, + repo = hub_repo, error = ?e, - "Failed to upload files to HuggingFace. Full error details: {:#?}", - e + "Failed to upload files to HuggingFace" ); - return Err(e.into()); - } - }; - Ok(commit_info.oid) -} - -pub async fn upload_to_gcs_async( - bucket: String, - files: Vec, - prefix: Option, -) -> Result<(), UploadModelError> { - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await.unwrap() - } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); + e + })?; - for path in files { - let file_name = path - .file_name() - .ok_or_else(|| UploadModelError::NotAFile(path.clone()))? - .to_str() - .ok_or_else(|| UploadModelError::InvalidFilename(path.clone()))?; + let revision = commit_info.oid; - let object_name = match &prefix { - Some(p) => format!("{}/{}", p, file_name), - None => file_name.to_string(), - }; + info!( + repo = hub_repo, + revision = revision, + "Upload to HuggingFace complete" + ); - let data = tokio::fs::read(&path).await.unwrap(); - - let upload_type = UploadType::Simple(Media::new(object_name.clone())); - let uploaded = client - .upload_object( - &UploadObjectRequest { - bucket: bucket.clone(), - ..Default::default() - }, - data, - &upload_type, - ) - .await - .unwrap(); - - info!( - bucket = bucket, - object = object_name, - size = uploaded.size, - "Successfully uploaded file to GCS" - ); - } + tx_checkpoint + .send(model::Checkpoint::Hub(HubRepo { + repo_id: FixedString::from_str_truncated(&hub_repo), + revision: Some(FixedString::from_str_truncated(&revision)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; Ok(()) } diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 0c884bea3..a35b73dae 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -1,6 +1,7 @@ mod data_provider; mod dataset; mod dummy; +mod errors; mod file_extensions; mod gcs; pub mod http; @@ -14,12 +15,14 @@ mod weighted; pub use data_provider::DataProvider; pub use dataset::{Dataset, Field, Row, Split}; pub use dummy::DummyDataProvider; +pub use errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; -pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_sync}; +pub use gcs::{ + GcsUploadInfo, download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, +}; pub use hub::{ - UploadModelError, download_dataset_repo_async, download_dataset_repo_sync, - download_model_repo_async, download_model_repo_sync, upload_model_repo_async, - upload_to_gcs_async, + HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, + download_model_repo_async, download_model_repo_sync, upload_to_hub, }; pub use 
local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; From a41c4572194db31f6201775dbfca15fbdad4cc81 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 12:20:52 -0800 Subject: [PATCH 4/6] Fix import with python feature --- shared/client/src/state/cooldown.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 488c86d62..28b6294be 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -5,6 +5,8 @@ use psyche_coordinator::{ }; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; +#[cfg(feature = "python")] +use psyche_modeling::CausalLM; use psyche_modeling::{ SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; From b19bc90e57b4add11fe352ce1dd583c03c08d1a9 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 12:24:16 -0800 Subject: [PATCH 5/6] Remove vllm from nix --- nix/lib.nix | 1 + python/default.nix | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/nix/lib.nix b/nix/lib.nix index 882f17040..37ab6d844 100644 --- a/nix/lib.nix +++ b/nix/lib.nix @@ -36,6 +36,7 @@ let python312 pkg-config perl + cargo-nextest ]; buildInputs = diff --git a/python/default.nix b/python/default.nix index 196b07867..ea9faa77b 100644 --- a/python/default.nix +++ b/python/default.nix @@ -39,7 +39,6 @@ let # i'm really not a fan of providing torchtitan like this. i'd much rather have it be built as a git dep via uv2nix. # i think there's room to figure out how to provide setuptools for it. "torchtitan" - "vllm" ]; nixProvidedPythonPkgs = getAllTransitiveDeps topLevelNixPkgs; From 0247ff33ff3a877ecf6671e78e4dba0db8b1ad29 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Wed, 14 Jan 2026 15:18:20 -0300 Subject: [PATCH 6/6] GCS Model: manifest.json (#485) --- Cargo.lock | 1 + shared/client/src/state/cooldown.rs | 26 ++- shared/data-provider/Cargo.toml | 3 +- shared/data-provider/src/errors.rs | 6 + shared/data-provider/src/gcs.rs | 292 +++++++++++++++++++++++++--- shared/data-provider/src/lib.rs | 3 +- 6 files changed, 293 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 554aaaad8..940f5fa02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7015,6 +7015,7 @@ dependencies = [ "anyhow", "async-trait", "bytemuck", + "chrono", "clap", "futures", "google-cloud-storage", diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 28b6294be..ff322e298 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -4,7 +4,7 @@ use psyche_coordinator::{ model::{self}, }; use psyche_core::NodeIdentity; -use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; +use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub}; #[cfg(feature = "python")] use psyche_modeling::CausalLM; use psyche_modeling::{ @@ -137,6 +137,7 @@ impl CooldownStepMetadata { let step = state.progress.step - 1; let run_id = String::from(&state.run_id); + let epoch = state.progress.epoch as u32; let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); let tx_checkpoint = self.tx_checkpoint.clone(); @@ -196,8 +197,18 @@ impl CooldownStepMetadata { save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; if let Some(upload_info) = upload_info { - 
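+                    // epoch and run_id are captured into GcsManifestMetadata so the GCS
+                    // upload can stamp them into manifest.json; the Hub path ignores them.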
upload_checkpoint(upload_info, local.clone(), step as u64, tx_checkpoint) - .await?; + let manifest_metadata = GcsManifestMetadata { + epoch, + run_id: run_id.clone(), + }; + upload_checkpoint( + upload_info, + manifest_metadata, + local.clone(), + step as u64, + tx_checkpoint, + ) + .await?; } cleanup_dirs( @@ -250,14 +261,17 @@ async fn save_checkpoint_locally( async fn upload_checkpoint( upload_info: UploadInfo, + manifest_metadata: GcsManifestMetadata, local: Vec, step: u64, tx_checkpoint: mpsc::UnboundedSender, ) -> Result<(), CheckpointError> { match upload_info { - UploadInfo::Gcs(gcs_info) => upload_to_gcs(gcs_info, local, step, tx_checkpoint) - .await - .map_err(CheckpointError::UploadError), + UploadInfo::Gcs(gcs_info) => { + upload_to_gcs(gcs_info, manifest_metadata, local, step, tx_checkpoint) + .await + .map_err(CheckpointError::UploadError) + } UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint) .await .map_err(CheckpointError::UploadError), diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index ea0e59791..7ad69e616 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -27,6 +27,8 @@ postcard.workspace = true bytemuck.workspace = true reqwest = "0.12.12" google-cloud-storage = "0.24.0" +chrono = { version = "0.4", features = ["serde"] } +serde_json.workspace = true ts-rs.workspace = true rayon.workspace = true @@ -38,4 +40,3 @@ test-log.workspace = true clap.workspace = true tempfile = "3.15.0" static-web-server = { git = "https://github.com/arilotter/static-web-server", rev = "c91445427b56c5ddff0365d8ec116e3b567377ac" } # forked to add a channel for getting the port -serde_json.workspace = true diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index 20e601b25..b84bc5f9a 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -29,6 +29,9 @@ pub enum UploadError { // Common errors #[error("IO error: {0}")] Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), } #[derive(Error, Debug)] @@ -44,4 +47,7 @@ pub enum DownloadError { #[error("IO error: {0}")] Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), } diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index adf67a5fb..b79a850b1 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,4 +1,5 @@ use crate::errors::{DownloadError, UploadError}; +use chrono::{DateTime, Utc}; use google_cloud_storage::client::{Client, ClientConfig}; use google_cloud_storage::http::objects::upload::Media; use google_cloud_storage::http::objects::upload::UploadObjectRequest; @@ -8,37 +9,103 @@ use google_cloud_storage::http::objects::{ }; use psyche_coordinator::model::{self, GcsRepo}; use psyche_core::FixedString; -use std::path::PathBuf; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; use tokio::runtime::Runtime; use tokio::sync::mpsc; use tracing::info; +/// Checkpoint manifest.json uploaded to GCS alongside safetensors files. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GcsCheckpointManifest { + pub metadata: ManifestMetadata, + pub files: Vec, +} + +/// Checkpoint metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManifestMetadata { + pub timestamp: DateTime, + pub epoch: u32, + pub step: u32, + pub run_id: String, +} + +/// Single file entry in the manifest. 
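+/// `generation` records the GCS object generation returned at upload time, so a
+/// reader can pin the exact blob version; `size_bytes` is the local file size
+/// captured before upload. The enclosing manifest serializes roughly as
+/// (illustrative values, field names per the serde derives above):
+///
+/// ```json
+/// { "metadata": { "timestamp": "2026-01-14T18:00:00Z", "epoch": 3,
+///                 "step": 1200, "run_id": "example-run" },
+///   "files": [ { "filename": "model.safetensors",
+///                "generation": 1736890000000000, "size_bytes": 4096 } ] }
+/// ```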
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManifestFileEntry { + pub filename: String, + pub generation: i64, + pub size_bytes: u64, +} + #[derive(Debug, Clone)] pub struct GcsUploadInfo { pub gcs_bucket: String, pub gcs_prefix: Option, } -const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; - -fn check_model_extension(filename: &str) -> bool { - MODEL_EXTENSIONS.iter().any(|ext| filename.ends_with(ext)) +#[derive(Debug, Clone)] +pub struct GcsManifestMetadata { + pub epoch: u32, + pub run_id: String, } -fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf { - let base = std::env::var("HOME") - .map(|h| PathBuf::from(h).join(".cache")) - .unwrap_or_else(|_| PathBuf::from(".cache")) +const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; + +fn get_cache_base(bucket: &str) -> PathBuf { + // Use HF_HOME if set, otherwise fall back to ~/.cache + std::env::var("HF_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| { + std::env::var("HOME") + .map(|h| PathBuf::from(h).join(".cache")) + .unwrap_or_else(|_| PathBuf::from(".cache")) + }) .join("psyche") .join("gcs") - .join(bucket); + .join(bucket) +} + +fn get_cache_dir( + bucket: &str, + prefix: Option<&str>, + step: u32, + manifest_generation: i64, +) -> PathBuf { + let base = get_cache_base(bucket); + let versioned_folder = format!("step-{}-{}", step, manifest_generation); + + match prefix { + Some(p) => base.join(p.trim_end_matches('/')).join(versioned_folder), + None => base.join(versioned_folder), + } +} + +fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf { + let base = get_cache_base(bucket); match prefix { - Some(p) => base.join(p.trim_end_matches('/')), - None => base, + Some(p) => base.join(p.trim_end_matches('/')).join("no_manifest"), + None => base.join("no_manifest"), } } +fn collect_cached_files( + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Option> { + let mut files = Vec::new(); + for file_entry in &manifest.files { + let path = cache_dir.join(&file_entry.filename); + if !path.exists() { + return None; + } + files.push(path); + } + Some(files) +} + pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, @@ -53,7 +120,131 @@ pub async fn download_model_from_gcs_async( }; let client = Client::new(config); - // List all objects in the bucket with optional prefix + let manifest_object_path = match prefix { + Some(p) => format!("{}/manifest.json", p), + None => "manifest.json".to_string(), + }; + + // Get manifest metadata to obtain generation number + let manifest_metadata = client + .get_object(&GetObjectRequest { + bucket: bucket.to_owned(), + object: manifest_object_path.clone(), + ..Default::default() + }) + .await; + + match manifest_metadata { + Ok(object_meta) => { + let manifest_generation = object_meta.generation; + + // Download manifest content + let manifest_data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: manifest_object_path, + ..Default::default() + }, + &Range::default(), + ) + .await?; + + let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?; + + info!( + "Found manifest: step {}, epoch {}, generation {}", + manifest.metadata.step, manifest.metadata.epoch, manifest_generation + ); + + // Build versioned cache path + let cache_dir = + get_cache_dir(bucket, prefix, manifest.metadata.step, manifest_generation); + + // Check if all manifest files exist in cache + let mut files = if let Some(cached) = 
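+            // collect_cached_files returns Some only when *every* file named in the
+            // manifest already exists under the generation-versioned cache dir.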
collect_cached_files(&cache_dir, &manifest) { + info!("Using cached checkpoint at {:?}", cache_dir); + cached + } else { + info!( + "Model not found in cache, downloading checkpoint to {:?}", + cache_dir + ); + std::fs::create_dir_all(&cache_dir)?; + download_files_from_manifest(&client, bucket, prefix, &cache_dir, &manifest).await? + }; + // Download config files (json, py) - skips if already cached + let config_files = + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &[".json", ".py"]) + .await?; + files.extend(config_files); + Ok(files) + } + Err(_) => { + // Fallback for old checkpoints without manifest + info!("No manifest found, downloading model without manifest"); + let cache_dir = get_cache_dir_no_manifest(bucket, prefix); + std::fs::create_dir_all(&cache_dir)?; + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &MODEL_EXTENSIONS).await + } + } +} + +async fn download_files_from_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Result, DownloadError> { + let mut downloaded_files = Vec::new(); + + for file_entry in &manifest.files { + let object_name = match prefix { + Some(p) => format!("{}/{}", p, file_entry.filename), + None => file_entry.filename.clone(), + }; + let local_path = cache_dir.join(&file_entry.filename); + + if local_path.exists() { + info!("Using cached: {}", file_entry.filename); + downloaded_files.push(local_path); + continue; + } + + info!( + "Downloading: gs://{}/{} (generation {})", + bucket, object_name, file_entry.generation + ); + + let data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: object_name, + ..Default::default() + }, + &Range::default(), + ) + .await?; + + std::fs::write(&local_path, &data)?; + info!("Downloaded: {} ({} bytes)", file_entry.filename, data.len()); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} + +/// Download model files by listing the bucket. Skips files that already exist in cache. +/// Used for initial model download (no manifest) and to fetch config files (json, py) after manifest download. 
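+/// Note: extension matching is a plain `ends_with`, so a `.json` filter also
+/// matches `manifest.json` itself, and the latest object version is taken since
+/// no generation is pinned here.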
+async fn download_files_no_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + extensions: &[&str], +) -> Result, DownloadError> { let mut all_objects = vec![]; let mut page_token: Option = None; @@ -68,7 +259,7 @@ pub async fn download_model_from_gcs_async( .await?; for obj in results.items.iter().flatten() { - if check_model_extension(&obj.name) { + if extensions.iter().any(|ext| obj.name.ends_with(ext)) { all_objects.push(obj.name.clone()); } } @@ -80,33 +271,27 @@ pub async fn download_model_from_gcs_async( } info!( - "Found {} model files in gs://{}/{}", + "Found {} files ({}) in gs://{}/{}", all_objects.len(), + extensions.join(", "), bucket, prefix.unwrap_or("") ); - let cache_dir = get_cache_dir(bucket, prefix); - std::fs::create_dir_all(&cache_dir)?; - let mut downloaded_files = Vec::new(); for object_name in all_objects { - // Get just the filename (strip prefix if present) let filename = object_name.rsplit('/').next().unwrap_or(&object_name); - let local_path = cache_dir.join(filename); - // Skip if already cached if local_path.exists() { info!("Using cached: {}", filename); downloaded_files.push(local_path); continue; } - info!("Downloading: {}", object_name); + info!("Downloading: gs://{}/{}", bucket, object_name); - // Download the object let data = client .download_object( &GetObjectRequest { @@ -139,6 +324,7 @@ pub fn download_model_from_gcs_sync( pub async fn upload_to_gcs( gcs_info: GcsUploadInfo, + manifest_metadata: GcsManifestMetadata, local: Vec, step: u64, tx_checkpoint: mpsc::UnboundedSender, @@ -148,6 +334,8 @@ pub async fn upload_to_gcs( gcs_prefix, } = gcs_info; + let GcsManifestMetadata { epoch, run_id } = manifest_metadata; + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { @@ -159,6 +347,16 @@ pub async fn upload_to_gcs( }; let client = Client::new(config); + let mut manifest = GcsCheckpointManifest { + metadata: ManifestMetadata { + timestamp: Utc::now(), + epoch, + step: step as u32, + run_id, + }, + files: Vec::new(), + }; + for path in local { let file_name = path .file_name() @@ -166,11 +364,17 @@ pub async fn upload_to_gcs( .to_str() .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + // Only upload safetensors files + if !file_name.ends_with(".safetensors") { + continue; + } + let object_name = match &gcs_prefix { Some(p) => format!("{}/{}", p, file_name), None => file_name.to_string(), }; + let size = std::fs::metadata(&path)?.len(); let data = tokio::fs::read(&path).await?; let upload_type = UploadType::Simple(Media::new(object_name.clone())); @@ -189,10 +393,42 @@ pub async fn upload_to_gcs( bucket = gcs_bucket, object = object_name, size = uploaded.size, - "Successfully uploaded file to GCS" + generation = uploaded.generation, + "Uploaded file to GCS" ); + + manifest.files.push(ManifestFileEntry { + filename: file_name.to_string(), + generation: uploaded.generation, + size_bytes: size, + }); } + // Upload the manifest file + let manifest_path = match &gcs_prefix { + Some(p) => format!("{}/manifest.json", p), + None => "manifest.json".to_string(), + }; + let manifest_json = serde_json::to_string_pretty(&manifest)?; + + let upload_type = UploadType::Simple(Media::new(manifest_path.clone())); + client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + manifest_json.into_bytes(), + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = manifest_path, 
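+        // manifest.json is written last: once it is visible, every file it lists is
+        // already in the bucket, and its generation versions the whole checkpoint.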
+ "Uploaded manifest to GCS" + ); + info!( "Upload to GCS complete at gs://{}/{}", gcs_bucket, @@ -201,12 +437,8 @@ pub async fn upload_to_gcs( tx_checkpoint .send(model::Checkpoint::Gcs(GcsRepo { - bucket: FixedString::from_str_truncated(&format!( - "gs://{}/{}", - gcs_bucket, - gcs_prefix.as_deref().unwrap_or("") - )), - prefix: Some(FixedString::from_str_truncated(&format!("step-{step}"))), + bucket: FixedString::from_str_truncated(&gcs_bucket), + prefix: gcs_prefix.map(|p| FixedString::from_str_truncated(&p)), })) .map_err(|_| UploadError::SendCheckpoint)?; diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index a35b73dae..0044d77d2 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -18,7 +18,8 @@ pub use dummy::DummyDataProvider; pub use errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; pub use gcs::{ - GcsUploadInfo, download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, + GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata, + download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, }; pub use hub::{ HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync,