diff --git a/Cargo.lock b/Cargo.lock
index 554aaaad8..940f5fa02 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7015,6 +7015,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "bytemuck",
+ "chrono",
  "clap",
  "futures",
  "google-cloud-storage",
diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index 3560e0a7a..8f6a9bcca 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -2,6 +2,8 @@ use anyhow::{Error, Result};
 use bytemuck::Zeroable;
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
+use psyche_client::HubUploadInfo;
+use psyche_client::UploadInfo;
 use psyche_client::{
     Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
 };
@@ -29,7 +31,7 @@ pub type TabsData = ::Data;
 pub enum ToSend {
     Witness(Box<OpportunisticData>),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }
 
 struct Backend {
@@ -67,7 +69,7 @@ impl WatcherBackend for Backend {
         Ok(())
     }
 
-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.tx.send(ToSend::Checkpoint(checkpoint))?;
         Ok(())
     }
@@ -173,18 +175,19 @@ impl App {
     ) -> Result<()> {
         // sanity checks
         if let Some(checkpoint_config) = &state_options.checkpoint_config {
-            if let Some(hub_upload) = &checkpoint_config.hub_upload {
+            if let Some(UploadInfo::Hub(HubUploadInfo {
+                hub_repo,
+                hub_token,
+            })) = &checkpoint_config.upload_info
+            {
                 let api = hf_hub::api::tokio::ApiBuilder::new()
-                    .with_token(Some(hub_upload.hub_token.clone()))
+                    .with_token(Some(hub_token.clone()))
                     .build()?;
-                let repo_api = api.repo(Repo::new(
-                    hub_upload.hub_repo.clone(),
-                    hf_hub::RepoType::Model,
-                ));
+                let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model));
                 if !repo_api.is_writable().await {
                     anyhow::bail!(
                         "Checkpoint upload repo {} is not writable with the passed API key.",
-                        hub_upload.hub_repo
+                        hub_repo
                     )
                 }
             }
diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index 6707b49bb..402146817 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -81,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend {
         bail!("Server does not send health checks");
     }
 
-    async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> {
         bail!("Server does not send checkpoints");
     }
 }
diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs
index c71c7524f..a96c64b8f 100644
--- a/architectures/centralized/shared/src/protocol.rs
+++ b/architectures/centralized/shared/src/protocol.rs
@@ -15,7 +15,7 @@ pub enum ClientToServerMessage {
     Join { run_id: String },
     Witness(Box<OpportunisticData>),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
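Note: the `model::Checkpoint` enum that replaces `model::HubRepo` in all of these signatures is not itself part of this diff. A minimal sketch of its apparent shape, inferred from the match arms in shared/coordinator/src/coordinator.rs and the constructors used in gcs.rs and hub.rs below — the variant payloads are assumptions, and the real enum may carry additional variants:

// Assumed shape of psyche_coordinator::model::Checkpoint, reconstructed from usage in this diff.
pub enum Checkpoint {
    Hub(HubRepo),    // weights published to a HuggingFace Hub repo
    Gcs(GcsRepo),    // weights published to a GCS bucket/prefix
    P2P(HubRepo),    // peer-to-peer distribution, backed by a Hub repo
    P2PGcs(GcsRepo), // peer-to-peer distribution, backed by a GCS location
}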
diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs
index 90f94e528..629ead188 100644
--- a/architectures/decentralized/solana-client/src/backend.rs
+++ b/architectures/decentralized/solana-client/src/backend.rs
@@ -20,7 +20,8 @@ use anchor_client::{
 use anyhow::{Context, Result, anyhow};
 use futures_util::StreamExt;
 use psyche_client::IntegrationTestLogMarker;
-use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
+use psyche_coordinator::model::{self, Checkpoint};
+use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks};
 use psyche_watcher::{Backend as WatcherBackend, OpportunisticData};
 use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding};
 use solana_transaction_status_client_types::UiTransactionEncoding;
@@ -333,7 +334,7 @@ impl SolanaBackend {
         &self,
         coordinator_instance: Pubkey,
         coordinator_account: Pubkey,
-        repo: HubRepo,
+        repo: Checkpoint,
     ) {
         let user = self.get_payer();
         let instruction = instructions::coordinator_checkpoint(
@@ -603,7 +604,7 @@ impl WatcherBackend for SolanaBackendRunner
         Ok(())
     }
 
-    async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.backend
             .send_checkpoint(self.instance, self.account, checkpoint);
         Ok(())
diff --git a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs
index 73adc2105..3ee098640 100644
--- a/architectures/decentralized/solana-client/src/command/checkpoint.rs
+++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use clap::Args;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::HubRepo;
 use psyche_core::FixedString;
 
@@ -45,7 +46,7 @@ pub async fn command_checkpoint_execute(
         &coordinator_instance,
         &coordinator_account,
         &user,
-        repo,
+        Checkpoint::Hub(repo),
     );
     let signature = backend
         .send_and_retry("Checkpoint", &[instruction], &[])
diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs
index 68e169008..b535d54f4 100644
--- a/architectures/decentralized/solana-client/src/instructions.rs
+++ b/architectures/decentralized/solana-client/src/instructions.rs
@@ -206,7 +206,7 @@ pub fn coordinator_checkpoint(
     coordinator_instance: &Pubkey,
     coordinator_account: &Pubkey,
     user: &Pubkey,
-    repo: psyche_coordinator::model::HubRepo,
+    repo: psyche_coordinator::model::Checkpoint,
 ) -> Instruction {
     anchor_instruction(
         psyche_solana_coordinator::ID,
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
index 029b40fad..8d9ddb977 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
@@ -10,7 +10,7 @@ use psyche_coordinator::RunState;
 use psyche_coordinator::SOLANA_MAX_STRING_LEN;
 use psyche_coordinator::TickResult;
 use psyche_coordinator::Witness;
-use psyche_coordinator::model::HubRepo;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::Model;
 use psyche_core::FixedString;
 use psyche_core::SmallBoolean;
@@ -389,7 +389,11 @@ impl CoordinatorInstanceState {
         self.tick()
     }
 
-    pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> {
+    pub fn checkpoint(
+        &mut self,
+        payer: &Pubkey,
+        repo: Checkpoint,
+    ) -> Result<()> {
         // O(n) on clients, reconsider
         let id = self.clients_state.find_signer(payer)?;
         let index = self
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
index 657424195..0a041a6e9 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
@@ -21,7 +21,7 @@ use psyche_coordinator::Witness;
 use psyche_coordinator::WitnessBloom;
 use psyche_coordinator::WitnessMetadata;
 use psyche_coordinator::WitnessProof;
-use psyche_coordinator::model::{HubRepo, Model};
+use psyche_coordinator::model::Model;
 use psyche_core::MerkleRoot;
 use serde::Deserialize;
 use serde::Serialize;
@@ -313,7 +313,7 @@ pub mod psyche_solana_coordinator {
 
     pub fn checkpoint(
         ctx: Context<Checkpoint>,
-        repo: HubRepo,
+        repo: psyche_coordinator::model::Checkpoint,
     ) -> Result<()> {
         let mut account = ctx.accounts.coordinator_account.load_mut()?;
         account.increment_nonce();
diff --git a/nix/lib.nix b/nix/lib.nix
index 882f17040..37ab6d844 100644
--- a/nix/lib.nix
+++ b/nix/lib.nix
@@ -36,6 +36,7 @@ let
     python312
     pkg-config
     perl
+    cargo-nextest
   ];
 
   buildInputs =
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 268ea753a..a4ef145f0 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -1,7 +1,9 @@
-use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};
+use crate::{CheckpointConfig, WandBInfo};
+use crate::UploadInfo;
 use anyhow::{Result, anyhow, bail};
 use clap::Args;
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_eval::tasktype_from_name;
 use psyche_modeling::Devices;
 use psyche_network::{DiscoveryMode, RelayKind, SecretKey};
@@ -146,6 +148,14 @@ pub struct TrainArgs {
     #[clap(long, env)]
     pub hub_repo: Option<String>,
 
+    /// Name of the GCS bucket containing model data and configuration.
+    #[clap(long, env)]
+    pub gcs_bucket: Option<String>,
+
+    /// Prefix within the GCS bucket for model data and configuration.
+    #[clap(long, env)]
+    pub gcs_prefix: Option<String>,
+
     #[clap(long, env, default_value_t = 3)]
     pub hub_max_concurrent_downloads: usize,
 
@@ -224,43 +234,72 @@ impl TrainArgs {
     pub fn checkpoint_config(&self) -> Result<Option<CheckpointConfig>> {
         let hub_read_token = std::env::var("HF_TOKEN").ok();
-        let checkpoint_upload_info = match (
-            &hub_read_token,
-            self.hub_repo.clone(),
-            self.checkpoint_dir.clone(),
-            self.delete_old_steps,
-            self.keep_steps,
-        ) {
-            (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")
+
+        if self.hub_repo.is_some() && self.gcs_bucket.is_some() {
+            bail!("Use either GCS or HF hub for checkpoint uploads, not both.");
+        }
+
+        let checkpoint_dir = match &self.checkpoint_dir {
+            Some(dir) => dir,
+            None => {
+                if self.hub_repo.is_some() || self.gcs_bucket.is_some() {
+                    bail!(
+                        "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!"
+                    );
                 }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(HubUploadInfo {
-                        hub_repo: repo,
-                        hub_token: token.to_string(),
-                    }),
-                    delete_old_steps,
-                    keep_steps,
-                })
-            }
-            (None, Some(_), Some(_), _, _) => {
-                bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
-            }
-            (_, Some(_), None, _, _) => {
-                bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
+                return Ok(None);
             }
-            (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
-                checkpoint_dir: dir,
-                hub_upload: None,
-                delete_old_steps,
-                keep_steps,
-            }),
-            (_, None, _, _, _) => None,
         };
-        Ok(checkpoint_upload_info)
+        let upload_info = self.build_upload_info(&hub_read_token)?;
+
+        if upload_info.is_some() && self.keep_steps == 0 {
+            bail!(
+                "keep_steps must be >= 1 for checkpoint uploads (got {})",
+                self.keep_steps
+            );
+        }
+
+        Ok(Some(CheckpointConfig {
+            checkpoint_dir: checkpoint_dir.clone(),
+            upload_info,
+            delete_old_steps: self.delete_old_steps,
+            keep_steps: self.keep_steps,
+        }))
+    }
+
+    fn build_upload_info(&self, hub_token: &Option<String>) -> Result<Option<UploadInfo>> {
+        if let Some(repo) = &self.hub_repo {
+            return self.build_hub_upload_info(repo, hub_token);
+        }
+
+        if let Some(bucket) = &self.gcs_bucket {
+            return self.build_gcs_upload_info(bucket);
+        }
+
+        Ok(None)
+    }
+
+    fn build_hub_upload_info(
+        &self,
+        repo: &str,
+        token: &Option<String>,
+    ) -> Result<Option<UploadInfo>> {
+        let token = token.as_ref().ok_or_else(|| {
+            anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
+        })?;
+
+        Ok(Some(UploadInfo::Hub(HubUploadInfo {
+            hub_repo: repo.to_string(),
+            hub_token: token.to_string(),
+        })))
+    }
+
+    fn build_gcs_upload_info(&self, bucket: &str) -> Result<Option<UploadInfo>> {
+        Ok(Some(UploadInfo::Gcs(GcsUploadInfo {
+            gcs_bucket: bucket.to_string(),
+            gcs_prefix: self.gcs_prefix.clone(),
+        })))
     }
 
     pub fn eval_tasks(&self) -> Result> {
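Note: the rewritten checkpoint_config()/build_upload_info() above reduces to a small decision table — hub and GCS uploads are mutually exclusive, the hub path requires HF_TOKEN, and any upload requires --checkpoint-dir and keep_steps >= 1. A standalone sketch of the same selection rules (all types here are local stand-ins, not the psyche_client ones):

#[derive(Debug)]
enum UploadInfo {
    Hub { repo: String, token: String },
    Gcs { bucket: String, prefix: Option<String> },
}

fn select_upload(
    hub_repo: Option<&str>,
    gcs_bucket: Option<&str>,
    gcs_prefix: Option<&str>,
    hf_token: Option<&str>,
) -> Result<Option<UploadInfo>, String> {
    match (hub_repo, gcs_bucket) {
        // Mirrors the bail! at the top of checkpoint_config().
        (Some(_), Some(_)) => Err("use either GCS or HF hub for uploads, not both".into()),
        (Some(repo), None) => {
            let token = hf_token.ok_or("hub repo set, but no HF_TOKEN env variable")?;
            Ok(Some(UploadInfo::Hub { repo: repo.into(), token: token.into() }))
        }
        (None, Some(bucket)) => Ok(Some(UploadInfo::Gcs {
            bucket: bucket.into(),
            prefix: gcs_prefix.map(Into::into),
        })),
        (None, None) => Ok(None),
    }
}

fn main() {
    // The GCS path needs no token; the prefix is optional.
    println!("{:?}", select_upload(None, Some("my-bucket"), Some("runs/my-run"), None));
}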
diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs
index f1c545ed7..8a0086932 100644
--- a/shared/client/src/lib.rs
+++ b/shared/client/src/lib.rs
@@ -10,7 +10,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity
 pub use client::Client;
 pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult};
 pub use state::{
-    CheckpointConfig, HubUploadInfo, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO,
+    CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig,
+    RunInitConfigAndIO, UploadInfo,
 };
 pub use testing::IntegrationTestLogMarker;
 pub use tui::{ClientTUI, ClientTUIState};
diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs
index 6e2efb6ef..ff322e298 100644
--- a/shared/client/src/state/cooldown.rs
+++ b/shared/client/src/state/cooldown.rs
@@ -1,14 +1,14 @@
-use crate::HubUploadInfo;
-
+use crate::UploadInfo;
 use psyche_coordinator::{
     Coordinator,
-    model::{self, HubRepo},
+    model::{self},
 };
-use psyche_core::{FixedString, NodeIdentity};
-use psyche_data_provider::{UploadModelError, upload_model_repo_async};
+use psyche_core::NodeIdentity;
+use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub};
+#[cfg(feature = "python")]
+use psyche_modeling::CausalLM;
 use psyche_modeling::{
-    CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError,
-    save_tensors_into_safetensors,
+    SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors,
 };
 use std::{
     cmp::Reverse,
@@ -42,7 +42,7 @@ pub enum CooldownError {
 }
 
 pub struct CooldownStepMetadata {
-    tx_checkpoint: mpsc::UnboundedSender<HubRepo>,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
     tx_model: mpsc::UnboundedSender<HashMap<String, Tensor>>,
     checkpoint_info: Option<CheckpointConfig>,
     checkpoint_extra_files: Vec<PathBuf>,
@@ -59,7 +59,7 @@ pub struct CooldownStepMetadata {
 
 impl CooldownStepMetadata {
     pub fn new(
-        tx_checkpoint: mpsc::UnboundedSender<HubRepo>,
+        tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
         tx_model: mpsc::UnboundedSender<HashMap<String, Tensor>>,
         checkpoint_info: Option<CheckpointConfig>,
         checkpoint_extra_files: Vec<PathBuf>,
@@ -93,8 +93,8 @@ pub enum CheckpointError {
     #[error("Writing extra file to disk failed: {0}")]
     WriteExtraFile(#[from] tokio::io::Error),
 
-    #[error("Couldn't upload model to huggingface: {0}")]
-    UploadError(#[from] UploadModelError),
+    #[error("Couldn't upload model to huggingface or GCS: {0}")]
+    UploadError(#[from] UploadError),
 
     #[error("Couldn't send checkpoint - channel closed")]
     SendCheckpoint,
@@ -137,6 +137,7 @@ impl CooldownStepMetadata {
         let step = state.progress.step - 1;
         let run_id = String::from(&state.run_id);
+        let epoch = state.progress.epoch as u32;
         let checkpoint_extra_files = self.checkpoint_extra_files.clone();
         let checkpoint_info = self.checkpoint_info.clone();
         let tx_checkpoint = self.tx_checkpoint.clone();
@@ -165,103 +166,51 @@ impl CooldownStepMetadata {
             .send(variables_clone)
             .map_err(|_| CheckpointError::SendCheckpoint)?;
 
-        let (variables, trainer) = if checkpoint_info.is_some() {
-            // convert from internal shape to serialized shape (e.g. torchtitan to hf)
-            tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
-                .await
-                .map_err(|_| CheckpointError::ExtractThreadCrashed)?
-        } else {
-            (variables, trainer)
+        // convert from internal shape to serialized shape (e.g. torchtitan to hf)
+        let (variables, trainer) = match trainer {
+            #[cfg(feature = "python")]
+            Trainer::PythonDistributed(_) => {
+                info!("Converting distributed trainer variables for checkpointing...");
+                tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
+                    .await
+                    .map_err(|_| CheckpointError::ExtractThreadCrashed)?
+            }
+            _ => (variables, trainer),
         };
 
         trainers.push(trainer);
         let evals = model_task_runner.start(trainers);
 
         let Some(CheckpointConfig {
-            hub_upload,
+            upload_info,
            checkpoint_dir,
             delete_old_steps,
             keep_steps,
         }) = checkpoint_info
         else {
-            // If there was no HF checkpointing configuration, return immediately
             return Ok((evals, None));
         };
 
-        // Start the upload process of the updated model parameters in a separate task
         let upload_handle = tokio::task::spawn(async move {
             let path = checkpoint_dir.join(format!("{run_id}-step{step}"));
-            info!("Saving to {}", path.display());
-            let mut local = tokio::task::spawn_blocking({
-                let path = path.clone();
-                move || save_tensors_into_safetensors(variables, path)
-            })
-            .await
-            .map_err(|_| CheckpointError::WriteThreadCrashed)??;
-
-            for extra in checkpoint_extra_files {
-                let to = path.join(extra.file_name().unwrap());
-                tokio::fs::copy(extra.clone(), to.clone())
-                    .await
-                    .map_err(CheckpointError::WriteExtraFile)?;
-                local.push(to);
+            let local =
+                save_checkpoint_locally(path, variables, checkpoint_extra_files).await?;
+
+            if let Some(upload_info) = upload_info {
+                let manifest_metadata = GcsManifestMetadata {
+                    epoch,
+                    run_id: run_id.clone(),
+                };
+                upload_checkpoint(
+                    upload_info,
+                    manifest_metadata,
+                    local.clone(),
+                    step as u64,
+                    tx_checkpoint,
+                )
+                .await?;
             }
 
-            let Some(HubUploadInfo {
-                hub_repo,
-                hub_token,
-            }) = hub_upload
-            else {
-                cleanup_dirs(
-                    delete_queue,
-                    keep_steps,
-                    run_id,
-                    delete_old_steps,
-                    step,
-                    checkpoint_dir,
-                )
-                .await;
-                return Ok::<(), CheckpointError>(());
-            };
-
-            info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
-            let revision = match upload_model_repo_async(
-                hub_repo.clone(),
-                local,
-                hub_token.clone(),
-                Some(format!("step {step}")),
-                None,
-            )
-            .await
-            {
-                Ok(revision) => {
-                    info!(
-                        repo = hub_repo,
-                        revision = revision,
-                        "Upload to HuggingFace complete"
-                    );
-                    revision
-                }
-                Err(err) => {
-                    error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}");
-                    return Err(err.into());
-                }
-            };
-
-            tx_checkpoint
-                .send(HubRepo {
-                    repo_id: FixedString::from_str_truncated(&hub_repo),
-                    revision: Some(FixedString::from_str_truncated(&revision)),
-                })
-                .map_err(|_| CheckpointError::SendCheckpoint)?;
-
-            // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work
-            // we'll just delete the dir after we've uploaded it
-            // if we fail in any of the above steps we may wind up not queueing this dir for delete
-            // but that's probably better than risking having the dir deleted from under us
-            // for a relatively low priority disk cleanup task
-            // and this may actually be preferred anyway because if we failed to upload, we may want to keep
-            // the data around locally on disk
             cleanup_dirs(
                 delete_queue,
                 keep_steps,
@@ -279,12 +228,56 @@ impl CooldownStepMetadata {
             }
             .instrument(info_span!("checkpointing")),
         );
+
         Ok(CooldownStep {
             checkpointing_and_evals,
         })
     }
 }
+
+async fn save_checkpoint_locally(
+    path: PathBuf,
+    variables: HashMap<String, Tensor>,
+    checkpoint_extra_files: Vec<PathBuf>,
+) -> Result<Vec<PathBuf>, CheckpointError> {
+    info!("Saving to {}", path.display());
+    let mut local = tokio::task::spawn_blocking({
+        let path = path.clone();
+        move || save_tensors_into_safetensors(variables, path)
+    })
+    .await
+    .map_err(|_| CheckpointError::WriteThreadCrashed)??;
+
+    for extra in checkpoint_extra_files {
+        let to = path.join(extra.file_name().unwrap());
+        tokio::fs::copy(extra.clone(), to.clone())
+            .await
+            .map_err(CheckpointError::WriteExtraFile)?;
+        local.push(to);
+    }
+
+    Ok(local)
+}
+
+async fn upload_checkpoint(
+    upload_info: UploadInfo,
+    manifest_metadata: GcsManifestMetadata,
+    local: Vec<PathBuf>,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
+) -> Result<(), CheckpointError> {
+    match upload_info {
+        UploadInfo::Gcs(gcs_info) => {
+            upload_to_gcs(gcs_info, manifest_metadata, local, step, tx_checkpoint)
+                .await
+                .map_err(CheckpointError::UploadError)
+        }
+        UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint)
+            .await
+            .map_err(CheckpointError::UploadError),
+    }
+}
+
 type CheckpointAndEvalsHandle = JoinHandle<
     Result<
         (
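Note: after this refactor the cooldown task itself never constructs a Checkpoint; upload_to_hub and upload_to_gcs take the tx_checkpoint sender and push the variant they produced on success. A self-contained sketch of that handoff (the Checkpoint type here is a local stand-in; in psyche the receiver is the client loop that forwards to the watcher Backend):

use tokio::sync::mpsc;

#[derive(Debug, Clone)]
enum Checkpoint {
    Hub(String),                 // repo id
    Gcs(String, Option<String>), // bucket, optional prefix
}

#[tokio::main]
async fn main() {
    let (tx_checkpoint, mut rx_checkpoint) = mpsc::unbounded_channel::<Checkpoint>();

    // An upload helper sends exactly one notification after a successful upload...
    tx_checkpoint
        .send(Checkpoint::Gcs("my-bucket".into(), Some("runs/my-run".into())))
        .expect("receiver alive");
    drop(tx_checkpoint);

    // ...and the client loop forwards each one via Backend::send_checkpoint.
    while let Some(ckpt) = rx_checkpoint.recv().await {
        println!("would call backend.send_checkpoint({ckpt:?})");
    }
}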
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 54073d3b2..36e19db90 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -5,9 +5,9 @@ use psyche_coordinator::{
 };
 use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize};
 use psyche_data_provider::{
-    DataProvider, DataProviderTcpClient, DummyDataProvider, GcsError, PreprocessedDataProvider,
-    Split, WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async,
-    download_model_repo_async,
+    DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider,
+    PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async,
+    download_model_from_gcs_async, download_model_repo_async,
     http::{FileURLs, HttpDataProvider},
 };
 use psyche_metrics::ClientMetrics;
@@ -91,7 +91,7 @@ pub enum InitRunError {
     HfModelLoad(#[from] hf_hub::api::tokio::ApiError),
 
     #[error("failed to download model from GCS: {0}")]
-    GcsModelLoad(#[from] GcsError),
+    GcsModelLoad(#[from] DownloadError),
 
     #[error("model loading thread crashed")]
     ModelLoadingThreadCrashed(JoinError),
@@ -154,7 +154,7 @@ pub struct RunInitConfigAndIO {
     pub tx_health_check: UnboundedSender<HealthChecks<T>>,
     pub tx_witness: UnboundedSender<OpportunisticData>,
-    pub tx_checkpoint: UnboundedSender<model::HubRepo>,
+    pub tx_checkpoint: UnboundedSender<model::Checkpoint>,
     pub tx_model: UnboundedSender<HashMap<String, Tensor>>,
     pub tx_parameters_req: UnboundedSender<(Vec<String>, OneshotModelParameterSender)>,
     pub tx_config: UnboundedSender<(String, String)>,
diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs
index e4d73d5b9..78e6cd1eb 100644
--- a/shared/client/src/state/mod.rs
+++ b/shared/client/src/state/mod.rs
@@ -14,6 +14,7 @@ mod warmup;
 mod witness;
 
 pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO};
+pub use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 pub use round_state::RoundState;
 pub use steps::{ApplyMessageOutcome, RunManager};
-pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, HubUploadInfo};
+pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, UploadInfo};
diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs
index 2edf22760..29734f1a0 100644
--- a/shared/client/src/state/types.rs
+++ b/shared/client/src/state/types.rs
@@ -2,6 +2,7 @@ use std::path::PathBuf;
 
 use psyche_coordinator::CommitteeProof;
 use psyche_core::{BatchId, MerkleRoot, NodeIdentity};
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_modeling::DistroResult;
 use psyche_network::{BlobTicket, TransmittableDistroResult};
 use tch::TchError;
@@ -9,14 +10,14 @@ use thiserror::Error;
 use tokio::task::JoinHandle;
 
 #[derive(Debug, Clone)]
-pub struct HubUploadInfo {
-    pub hub_repo: String,
-    pub hub_token: String,
+pub enum UploadInfo {
+    Hub(HubUploadInfo),
+    Gcs(GcsUploadInfo),
 }
 
 #[derive(Debug, Clone)]
 pub struct CheckpointConfig {
-    pub hub_upload: Option<HubUploadInfo>,
+    pub upload_info: Option<UploadInfo>,
     pub checkpoint_dir: PathBuf,
     pub delete_old_steps: bool,
     pub keep_steps: u32,
diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index 0a3a9975d..5b10bfcbd 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -1,6 +1,6 @@
 use crate::{
     Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof,
-    model::{Checkpoint, HubRepo, Model},
+    model::{Checkpoint, Model},
 };
 use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh};
 
@@ -596,22 +596,43 @@ impl Coordinator {
         &mut self,
         from: &T,
         index: u64,
-        hub_repo: HubRepo,
+        checkpoint_repo: Checkpoint,
     ) -> std::result::Result<(), CoordinatorError> {
         let index = index as usize;
         if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from
         {
             return Err(CoordinatorError::InvalidCommitteeProof);
         }
-        // TODO: In the case of more than one checkpointer, this will overwrite the hub repo
-        // with the last checkpointed one. We could instead have a vector of hub repos to have
+
+        // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint
+        // with the last checkpointed one. We could instead have a vector of checkpoints to have
         // more download options.
-        match &mut self.model {
-            Model::LLM(llm) => match llm.checkpoint {
-                Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo),
-                Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo),
-                _ => {}
-            },
+        let Model::LLM(llm) = &mut self.model;
+        match (&llm.checkpoint, checkpoint_repo) {
+            // If current is P2P, wrap the new checkpoint in P2P
+            (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // If current is Hub, only accept Hub updates
+            (Checkpoint::Hub(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::Hub(hub_repo);
+            }
+            // If current is Gcs, only accept Gcs updates
+            (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::Gcs(gcs_repo);
+            }
+            (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // Ignore other combinations
+            _ => {}
         }
+
         Ok(())
     }
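Note: the new match in Coordinator::checkpoint encodes one invariant worth spelling out: whether the run distributes checkpoints P2P is sticky, and an incoming Hub/Gcs checkpoint only switches the backing store, while a non-P2P run silently drops updates from the other store. The same rules as a tiny runnable check (Ckpt is a local Copy stand-in for model::Checkpoint):

#[derive(Debug, Clone, Copy, PartialEq)]
enum Ckpt { Hub(u8), Gcs(u8), P2P(u8), P2PGcs(u8) }

// Same transition table as the match on (&llm.checkpoint, checkpoint_repo) above.
fn apply(current: Ckpt, incoming: Ckpt) -> Ckpt {
    use Ckpt::*;
    match (&current, incoming) {
        (P2P(_), Hub(h)) => P2P(h),
        (P2PGcs(_), Gcs(g)) => P2PGcs(g),
        (Hub(_), Hub(h)) => Hub(h),
        (Gcs(_), Gcs(g)) => Gcs(g),
        (P2PGcs(_), Hub(h)) => P2P(h),
        (P2P(_), Gcs(g)) => P2PGcs(g),
        _ => current, // e.g. Hub current + Gcs incoming is ignored
    }
}

fn main() {
    assert_eq!(apply(Ckpt::P2P(1), Ckpt::Gcs(2)), Ckpt::P2PGcs(2)); // store flips, P2P sticks
    assert_eq!(apply(Ckpt::Hub(1), Ckpt::Gcs(2)), Ckpt::Hub(1));    // cross-store update dropped
    println!("checkpoint transitions ok");
}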
#[error("file {0} doesn't have a valid utf-8 representation")] + InvalidFilename(PathBuf), + + #[error("failed to send checkpoint notification")] + SendCheckpoint, + + // Hub-specific errors + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("failed to commit files: {0}")] + Commit(#[from] hf_hub::api::tokio::CommitError), + + // GCS-specific errors + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + // Common errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +#[derive(Error, Debug)] +pub enum DownloadError { + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index cb6848c1d..b79a850b1 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,48 +1,115 @@ +use crate::errors::{DownloadError, UploadError}; +use chrono::{DateTime, Utc}; use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::Media; +use google_cloud_storage::http::objects::upload::UploadObjectRequest; +use google_cloud_storage::http::objects::upload::UploadType; use google_cloud_storage::http::objects::{ download::Range, get::GetObjectRequest, list::ListObjectsRequest, }; -use std::path::PathBuf; -use thiserror::Error; +use psyche_coordinator::model::{self, GcsRepo}; +use psyche_core::FixedString; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; use tokio::runtime::Runtime; +use tokio::sync::mpsc; use tracing::info; -#[derive(Debug, Error)] -pub enum GcsError { - #[error("GCS authentication failed: {0}")] - Auth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), +/// Checkpoint manifest.json uploaded to GCS alongside safetensors files. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GcsCheckpointManifest { + pub metadata: ManifestMetadata, + pub files: Vec, +} - #[error("GCS operation failed: {0}")] - Storage(#[from] google_cloud_storage::http::Error), +/// Checkpoint metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManifestMetadata { + pub timestamp: DateTime, + pub epoch: u32, + pub step: u32, + pub run_id: String, +} - #[error("IO error: {0}")] - Io(#[from] std::io::Error), +/// Single file entry in the manifest. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManifestFileEntry { + pub filename: String, + pub generation: i64, + pub size_bytes: u64, } -const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; +#[derive(Debug, Clone)] +pub struct GcsUploadInfo { + pub gcs_bucket: String, + pub gcs_prefix: Option, +} -fn check_model_extension(filename: &str) -> bool { - MODEL_EXTENSIONS.iter().any(|ext| filename.ends_with(ext)) +#[derive(Debug, Clone)] +pub struct GcsManifestMetadata { + pub epoch: u32, + pub run_id: String, } -fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf { - let base = std::env::var("HOME") - .map(|h| PathBuf::from(h).join(".cache")) - .unwrap_or_else(|_| PathBuf::from(".cache")) +const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; + +fn get_cache_base(bucket: &str) -> PathBuf { + // Use HF_HOME if set, otherwise fall back to ~/.cache + std::env::var("HF_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| { + std::env::var("HOME") + .map(|h| PathBuf::from(h).join(".cache")) + .unwrap_or_else(|_| PathBuf::from(".cache")) + }) .join("psyche") .join("gcs") - .join(bucket); + .join(bucket) +} + +fn get_cache_dir( + bucket: &str, + prefix: Option<&str>, + step: u32, + manifest_generation: i64, +) -> PathBuf { + let base = get_cache_base(bucket); + let versioned_folder = format!("step-{}-{}", step, manifest_generation); match prefix { - Some(p) => base.join(p.trim_end_matches('/')), - None => base, + Some(p) => base.join(p.trim_end_matches('/')).join(versioned_folder), + None => base.join(versioned_folder), } } +fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf { + let base = get_cache_base(bucket); + + match prefix { + Some(p) => base.join(p.trim_end_matches('/')).join("no_manifest"), + None => base.join("no_manifest"), + } +} + +fn collect_cached_files( + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Option> { + let mut files = Vec::new(); + for file_entry in &manifest.files { + let path = cache_dir.join(&file_entry.filename); + if !path.exists() { + return None; + } + files.push(path); + } + Some(files) +} + pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { +) -> Result, DownloadError> { // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { info!("Using authenticated GCS client"); @@ -53,7 +120,131 @@ pub async fn download_model_from_gcs_async( }; let client = Client::new(config); - // List all objects in the bucket with optional prefix + let manifest_object_path = match prefix { + Some(p) => format!("{}/manifest.json", p), + None => "manifest.json".to_string(), + }; + + // Get manifest metadata to obtain generation number + let manifest_metadata = client + .get_object(&GetObjectRequest { + bucket: bucket.to_owned(), + object: manifest_object_path.clone(), + ..Default::default() + }) + .await; + + match manifest_metadata { + Ok(object_meta) => { + let manifest_generation = object_meta.generation; + + // Download manifest content + let manifest_data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: manifest_object_path, + ..Default::default() + }, + &Range::default(), + ) + .await?; + + let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?; + + info!( + "Found manifest: step {}, epoch {}, generation {}", + manifest.metadata.step, manifest.metadata.epoch, 
manifest_generation + ); + + // Build versioned cache path + let cache_dir = + get_cache_dir(bucket, prefix, manifest.metadata.step, manifest_generation); + + // Check if all manifest files exist in cache + let mut files = if let Some(cached) = collect_cached_files(&cache_dir, &manifest) { + info!("Using cached checkpoint at {:?}", cache_dir); + cached + } else { + info!( + "Model not found in cache, downloading checkpoint to {:?}", + cache_dir + ); + std::fs::create_dir_all(&cache_dir)?; + download_files_from_manifest(&client, bucket, prefix, &cache_dir, &manifest).await? + }; + // Download config files (json, py) - skips if already cached + let config_files = + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &[".json", ".py"]) + .await?; + files.extend(config_files); + Ok(files) + } + Err(_) => { + // Fallback for old checkpoints without manifest + info!("No manifest found, downloading model without manifest"); + let cache_dir = get_cache_dir_no_manifest(bucket, prefix); + std::fs::create_dir_all(&cache_dir)?; + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &MODEL_EXTENSIONS).await + } + } +} + +async fn download_files_from_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Result, DownloadError> { + let mut downloaded_files = Vec::new(); + + for file_entry in &manifest.files { + let object_name = match prefix { + Some(p) => format!("{}/{}", p, file_entry.filename), + None => file_entry.filename.clone(), + }; + let local_path = cache_dir.join(&file_entry.filename); + + if local_path.exists() { + info!("Using cached: {}", file_entry.filename); + downloaded_files.push(local_path); + continue; + } + + info!( + "Downloading: gs://{}/{} (generation {})", + bucket, object_name, file_entry.generation + ); + + let data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: object_name, + ..Default::default() + }, + &Range::default(), + ) + .await?; + + std::fs::write(&local_path, &data)?; + info!("Downloaded: {} ({} bytes)", file_entry.filename, data.len()); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} + +/// Download model files by listing the bucket. Skips files that already exist in cache. +/// Used for initial model download (no manifest) and to fetch config files (json, py) after manifest download. 
+async fn download_files_no_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + extensions: &[&str], +) -> Result, DownloadError> { let mut all_objects = vec![]; let mut page_token: Option = None; @@ -68,7 +259,7 @@ pub async fn download_model_from_gcs_async( .await?; for obj in results.items.iter().flatten() { - if check_model_extension(&obj.name) { + if extensions.iter().any(|ext| obj.name.ends_with(ext)) { all_objects.push(obj.name.clone()); } } @@ -80,33 +271,27 @@ pub async fn download_model_from_gcs_async( } info!( - "Found {} model files in gs://{}/{}", + "Found {} files ({}) in gs://{}/{}", all_objects.len(), + extensions.join(", "), bucket, prefix.unwrap_or("") ); - let cache_dir = get_cache_dir(bucket, prefix); - std::fs::create_dir_all(&cache_dir)?; - let mut downloaded_files = Vec::new(); for object_name in all_objects { - // Get just the filename (strip prefix if present) let filename = object_name.rsplit('/').next().unwrap_or(&object_name); - let local_path = cache_dir.join(filename); - // Skip if already cached if local_path.exists() { info!("Using cached: {}", filename); downloaded_files.push(local_path); continue; } - info!("Downloading: {}", object_name); + info!("Downloading: gs://{}/{}", bucket, object_name); - // Download the object let data = client .download_object( &GetObjectRequest { @@ -132,7 +317,130 @@ pub async fn download_model_from_gcs_async( pub fn download_model_from_gcs_sync( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { - let rt = Runtime::new().map_err(GcsError::Io)?; +) -> Result, DownloadError> { + let rt = Runtime::new().map_err(DownloadError::Io)?; rt.block_on(download_model_from_gcs_async(bucket, prefix)) } + +pub async fn upload_to_gcs( + gcs_info: GcsUploadInfo, + manifest_metadata: GcsManifestMetadata, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let GcsUploadInfo { + gcs_bucket, + gcs_prefix, + } = gcs_info; + + let GcsManifestMetadata { epoch, run_id } = manifest_metadata; + + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); + + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client"); + ClientConfig::default().with_auth().await? + } else { + info!("Using anonymous GCS client"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + let mut manifest = GcsCheckpointManifest { + metadata: ManifestMetadata { + timestamp: Utc::now(), + epoch, + step: step as u32, + run_id, + }, + files: Vec::new(), + }; + + for path in local { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? 
+ .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + + // Only upload safetensors files + if !file_name.ends_with(".safetensors") { + continue; + } + + let object_name = match &gcs_prefix { + Some(p) => format!("{}/{}", p, file_name), + None => file_name.to_string(), + }; + + let size = std::fs::metadata(&path)?.len(); + let data = tokio::fs::read(&path).await?; + + let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let uploaded = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + data, + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = object_name, + size = uploaded.size, + generation = uploaded.generation, + "Uploaded file to GCS" + ); + + manifest.files.push(ManifestFileEntry { + filename: file_name.to_string(), + generation: uploaded.generation, + size_bytes: size, + }); + } + + // Upload the manifest file + let manifest_path = match &gcs_prefix { + Some(p) => format!("{}/manifest.json", p), + None => "manifest.json".to_string(), + }; + let manifest_json = serde_json::to_string_pretty(&manifest)?; + + let upload_type = UploadType::Simple(Media::new(manifest_path.clone())); + client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + manifest_json.into_bytes(), + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = manifest_path, + "Uploaded manifest to GCS" + ); + + info!( + "Upload to GCS complete at gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + ); + + tx_checkpoint + .send(model::Checkpoint::Gcs(GcsRepo { + bucket: FixedString::from_str_truncated(&gcs_bucket), + prefix: gcs_prefix.map(|p| FixedString::from_str_truncated(&p)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) +} diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index ca090a759..13a575b84 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,12 +1,16 @@ +use crate::errors::UploadError; +use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, api::{ Siblings, - tokio::{ApiError, CommitError, UploadSource}, + tokio::{ApiError, UploadSource}, }, }; +use psyche_coordinator::model; +use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use thiserror::Error; +use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -189,42 +193,39 @@ pub fn download_dataset_repo_sync( ) } -#[derive(Error, Debug)] -pub enum UploadModelError { - #[error("path {0} is not a file")] - NotAFile(PathBuf), - - #[error("file {0} doesn't have a valid utf-8 representation")] - InvalidFilename(PathBuf), +#[derive(Debug, Clone)] +pub struct HubUploadInfo { + pub hub_repo: String, + pub hub_token: String, +} - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] ApiError), +pub async fn upload_to_hub( + hub_info: HubUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let HubUploadInfo { + hub_repo, + hub_token, + } = hub_info; - #[error("failed to commit files: {0}")] - Commit(#[from] CommitError), -} + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); -pub async fn upload_model_repo_async( - repo_id: String, - files: Vec, - token: String, - commit_message: Option, - commit_description: Option, -) -> Result { let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(token)) + 
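Note: for reference, a manifest.json produced by upload_to_gcs above would serialize roughly as follows, given the serde derives on GcsCheckpointManifest (chrono renders DateTime<Utc> as an RFC 3339 string); every value here is invented:

{
  "metadata": {
    "timestamp": "2025-01-01T12:00:00.000000Z",
    "epoch": 3,
    "step": 1200,
    "run_id": "my-run"
  },
  "files": [
    {
      "filename": "model-00001-of-00002.safetensors",
      "generation": 1712345678901234,
      "size_bytes": 4980736000
    }
  ]
}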
diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs
index ca090a759..13a575b84 100644
--- a/shared/data-provider/src/hub.rs
+++ b/shared/data-provider/src/hub.rs
@@ -1,12 +1,16 @@
+use crate::errors::UploadError;
+use crate::hub::model::HubRepo;
 use hf_hub::{
     Cache, Repo, RepoType,
     api::{
         Siblings,
-        tokio::{ApiError, CommitError, UploadSource},
+        tokio::{ApiError, UploadSource},
     },
 };
+use psyche_coordinator::model;
+use psyche_core::FixedString;
 use std::{path::PathBuf, time::Instant};
-use thiserror::Error;
+use tokio::sync::mpsc;
 use tracing::{error, info};
 
 const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
@@ -189,42 +193,39 @@ pub fn download_dataset_repo_sync(
     )
 }
 
-#[derive(Error, Debug)]
-pub enum UploadModelError {
-    #[error("path {0} is not a file")]
-    NotAFile(PathBuf),
-
-    #[error("file {0} doesn't have a valid utf-8 representation")]
-    InvalidFilename(PathBuf),
+#[derive(Debug, Clone)]
+pub struct HubUploadInfo {
+    pub hub_repo: String,
+    pub hub_token: String,
+}
 
-    #[error("failed to connect to HF hub: {0}")]
-    HfHub(#[from] ApiError),
+pub async fn upload_to_hub(
+    hub_info: HubUploadInfo,
+    local: Vec<PathBuf>,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
+) -> Result<(), UploadError> {
+    let HubUploadInfo {
+        hub_repo,
+        hub_token,
+    } = hub_info;
 
-    #[error("failed to commit files: {0}")]
-    Commit(#[from] CommitError),
-}
+    info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
 
-pub async fn upload_model_repo_async(
-    repo_id: String,
-    files: Vec<PathBuf>,
-    token: String,
-    commit_message: Option<String>,
-    commit_description: Option<String>,
-) -> Result<String, UploadModelError> {
     let api = hf_hub::api::tokio::ApiBuilder::new()
-        .with_token(Some(token))
+        .with_token(Some(hub_token.clone()))
         .build()?;
-    let repo = Repo::model(repo_id.clone());
+    let repo = Repo::model(hub_repo.clone());
     let api_repo = api.repo(repo);
-    let files: Result<Vec<_>, _> = files
+    let files: Result<Vec<_>, _> = local
         .into_iter()
         .map(|path| {
             path.file_name()
-                .ok_or(UploadModelError::NotAFile(path.clone()))
+                .ok_or(UploadError::NotAFile(path.clone()))
                 .and_then(|name| {
                     name.to_str()
-                        .ok_or(UploadModelError::InvalidFilename(path.clone()))
+                        .ok_or(UploadError::InvalidFilename(path.clone()))
                         .map(|s| s.to_string())
                 })
                 .map(|name| (path.into(), name))
@@ -233,32 +234,32 @@ pub async fn upload_model_repo_async(
 
     let files = files?;
 
-    let commit_info = match api_repo
-        .upload_files(
-            files,
-            commit_message.clone(),
-            commit_description.clone(),
-            false,
-        )
+    let commit_info = api_repo
+        .upload_files(files, Some(format!("step {step}")), None, false)
         .await
-    {
-        Ok(info) => {
-            info!(
-                repo = repo_id,
-                oid = info.oid,
-                "Successfully uploaded files to HuggingFace"
-            );
-            info
-        }
-        Err(e) => {
+        .map_err(|e| {
             error!(
-                repo = repo_id,
+                repo = hub_repo,
                 error = ?e,
-                "Failed to upload files to HuggingFace. Full error details: {:#?}",
-                e
+                "Failed to upload files to HuggingFace"
             );
-            return Err(e.into());
-        }
-    };
-    Ok(commit_info.oid)
+            e
+        })?;
+
+    let revision = commit_info.oid;
+
+    info!(
+        repo = hub_repo,
+        revision = revision,
+        "Upload to HuggingFace complete"
+    );
+
+    tx_checkpoint
+        .send(model::Checkpoint::Hub(HubRepo {
+            repo_id: FixedString::from_str_truncated(&hub_repo),
+            revision: Some(FixedString::from_str_truncated(&revision)),
+        }))
+        .map_err(|_| UploadError::SendCheckpoint)?;
+
+    Ok(())
 }
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 7b97fe101..0044d77d2 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -1,6 +1,7 @@
 mod data_provider;
 mod dataset;
 mod dummy;
+mod errors;
 mod file_extensions;
 mod gcs;
 pub mod http;
@@ -14,11 +15,15 @@ mod weighted;
 pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
+pub use errors::{DownloadError, UploadError};
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
-pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_sync};
+pub use gcs::{
+    GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata,
+    download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs,
+};
 pub use hub::{
-    UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
-    download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
+    HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync,
+    download_model_repo_async, download_model_repo_sync, upload_to_hub,
 };
 pub use local::LocalDataProvider;
 pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor};
diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs
index 1cd96a7ed..50bf317bb 100644
--- a/shared/watcher/src/traits.rs
+++ b/shared/watcher/src/traits.rs
@@ -27,5 +27,5 @@ pub trait Backend<T: NodeIdentity>: Send + Sync {
     async fn wait_for_new_state(&mut self) -> Result<Coordinator<T>>;
     async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>;
     async fn send_health_check(&mut self, health_check: HealthChecks<T>) -> Result<()>;
-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>;
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>;
 }
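Note: with the trait change above, every Backend implementation now receives the full model::Checkpoint rather than a bare HubRepo, so implementors must be prepared for either storage backend. A minimal sketch of what that obligation looks like (local stand-in types; the real trait is psyche_watcher::Backend):

use anyhow::Result;

#[derive(Debug)]
enum Checkpoint {
    Hub(String), // repo id
    Gcs(String), // bucket
}

trait CheckpointSink {
    fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()>;
}

struct LoggingBackend;

impl CheckpointSink for LoggingBackend {
    fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()> {
        match checkpoint {
            Checkpoint::Hub(repo) => println!("checkpoint published to HF hub repo {repo}"),
            Checkpoint::Gcs(bucket) => println!("checkpoint published to gs://{bucket}"),
        }
        Ok(())
    }
}

fn main() -> Result<()> {
    let mut backend = LoggingBackend;
    backend.send_checkpoint(Checkpoint::Gcs("my-bucket".into()))
}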