From 0836df877b4b781111fabf1ca4517f4f0491f6fc Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 23 Dec 2025 11:39:15 -0300 Subject: [PATCH 01/72] Make cooldown opportunistic --- .../solana-client/src/backend.rs | 6 ++++ .../solana-coordinator/src/instance_state.rs | 15 +++++++++ shared/client/src/state/cooldown.rs | 13 ++------ shared/client/src/state/steps.rs | 33 +++++++++++++++++++ shared/coordinator/src/coordinator.rs | 18 ++++++++++ shared/core/src/fixed_vec.rs | 12 +++++++ shared/watcher/src/traits.rs | 2 ++ 7 files changed, 89 insertions(+), 10 deletions(-) diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index 90f94e528..55d6d7ada 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -307,6 +307,12 @@ impl SolanaBackend { &user, witness, ), + OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( + &coordinator_instance, + &coordinator_account, + &user, + witness, + ), }; self.spawn_scheduled_send("Witness", &[instruction], &[]); } diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 029b40fad..831586c30 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,6 +233,21 @@ impl CoordinatorInstanceState { self.tick() } + pub fn cooldown_witness( + &mut self, + payer: &Pubkey, + witness: Witness, + ) -> Result<()> { + let id = self.clients_state.find_signer(payer)?; + + let clock: Clock = Clock::get()?; + self.coordinator + .cooldown_witness(id, witness, clock.unix_timestamp as u64) + .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; + + self.tick() + } + pub fn warmup_witness( &mut self, payer: &Pubkey, diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index e85120fe9..ced30f006 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -167,17 +167,10 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - let Some(CheckpointConfig { - hub_upload, - checkpoint_dir, - delete_old_steps, - keep_steps, - }) = checkpoint_info - else { - // If there was no HF checkpointing configuration, return immediately + if state.epoch_state.checkpointer != T { + info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); - }; - + } // Start the upload process of the updated model parameters in a separate task let upload_handle = tokio::task::spawn(async move { let path = checkpoint_dir.join(format!("{run_id}-step{step}")); diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index a43fedcb6..85381b921 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -338,6 +338,39 @@ impl StepStateMachine { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, + pub checkpointer: T, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -411,6 +412,7 @@ impl Default for CoordinatorEpochState { first_round: true.into(), clients: Default::default(), exited_clients: Default::default(), + checkpointer: T::default(), cold_start_epoch: false.into(), start_step: Default::default(), last_step: Default::default(), @@ -612,6 +614,21 @@ impl Coordinator { _ => {} }, } + + if self.halted() { + return Err(CoordinatorError::Halted); + } + + if !matches!(self.run_state, RunState::Cooldown) { + return Err(CoordinatorError::InvalidRunState); + } + + if self.epoch_state.checkpointer != *from { + return Err(CoordinatorError::InvalidWitness); + } else { + self.start_waiting_for_members(unix_timestamp); + } + Ok(()) } @@ -933,6 +950,7 @@ impl Coordinator { ) .unwrap(); + self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap().id.clone(); self.start_warmup(unix_timestamp); } diff --git a/shared/core/src/fixed_vec.rs b/shared/core/src/fixed_vec.rs index 4058b3e2b..955b06bff 100644 --- a/shared/core/src/fixed_vec.rs +++ b/shared/core/src/fixed_vec.rs @@ -1,6 +1,8 @@ use crate as psyche_core; use anchor_lang::{AnchorDeserialize, AnchorSerialize, prelude::borsh}; use bytemuck::Zeroable; +#[cfg(feature = "rand")] +use rand::Rng; use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut, Range, RangeFrom, RangeFull, RangeTo}; use ts_rs::TS; @@ -144,6 +146,16 @@ impl FixedVec { Ok(()) } + #[cfg(feature = "rand")] + pub fn random(&self) -> Option<&T> { + if self.len == 0 { + return None; + } + let mut rng = rand::rng(); + let random_index = rng.random_range(0..self.len) as u16; + Some(&self.data[random_index as usize]) + } + pub fn retain(&mut self, mut f: F) where F: FnMut(&T) -> bool, diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 1cd96a7ed..3d4de39d0 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; pub enum OpportunisticData { WitnessStep(Witness, WitnessMetadata), WarmupStep(Witness), + CooldownStep(Witness), } impl OpportunisticData { @@ -15,6 +16,7 @@ impl OpportunisticData { match self { OpportunisticData::WitnessStep(..) => "witness", OpportunisticData::WarmupStep(..) => "warmup", + OpportunisticData::CooldownStep(..) 
=> "cooldown", } } } From 80550e6b54ecc1bbd434117a1289d53829fb01c8 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 23 Dec 2025 07:19:06 -0800 Subject: [PATCH 02/72] Fix some things --- .../solana-coordinator/src/instance_state.rs | 27 ++++++++++--------- shared/client/src/state/cooldown.rs | 2 +- shared/coordinator/src/coordinator.rs | 5 ++-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 831586c30..05eb20d2c 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,20 +233,20 @@ impl CoordinatorInstanceState { self.tick() } - pub fn cooldown_witness( - &mut self, - payer: &Pubkey, - witness: Witness, - ) -> Result<()> { - let id = self.clients_state.find_signer(payer)?; + // pub fn cooldown_witness( + // &mut self, + // payer: &Pubkey, + // witness: Witness, + // ) -> Result<()> { + // let id = self.clients_state.find_signer(payer)?; - let clock: Clock = Clock::get()?; - self.coordinator - .cooldown_witness(id, witness, clock.unix_timestamp as u64) - .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; + // let clock: Clock = Clock::get()?; + // self.coordinator + // .cooldown_witness(id, witness, clock.unix_timestamp as u64) + // .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; - self.tick() - } + // self.tick() + // } pub fn warmup_witness( &mut self, @@ -414,9 +414,10 @@ impl CoordinatorInstanceState { .iter() .position(|x| x.id == *id) .ok_or(ProgramError::SignerNotAClient)?; + let clock = Clock::get()?; self.coordinator - .checkpoint(id, index as u64, repo) + .checkpoint(id, index as u64, repo, clock.unix_timestamp as u64) .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; // Only tick if not halted (Paused/Uninitialized/Finished) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ced30f006..4cdb55189 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -167,7 +167,7 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - if state.epoch_state.checkpointer != T { + if state.epoch_state.checkpointer != trainer { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 5b16b437b..8ba070e5d 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -274,7 +274,7 @@ pub struct CoordinatorEpochState { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointer: T, + pub checkpointer: Client, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -599,6 +599,7 @@ impl Coordinator { from: &T, index: u64, hub_repo: HubRepo, + unix_timestamp: u64, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { @@ -950,7 +951,7 @@ impl Coordinator { ) .unwrap(); - self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap().id.clone(); + self.epoch_state.checkpointer = self.epoch_state.clients[0].id.clone(); self.start_warmup(unix_timestamp); } From be24bdaa251de4030e4fa392eeab593ea86f8e92 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 23 Dec 2025 12:59:08 -0300 Subject: [PATCH 03/72] Fix compilation --- architectures/centralized/server/src/app.rs | 9 ++- .../solana-client/src/backend.rs | 12 ++-- shared/client/src/state/cooldown.rs | 14 +++- shared/client/src/state/steps.rs | 68 +++++++++---------- shared/coordinator/src/coordinator.rs | 2 +- shared/watcher/src/traits.rs | 4 +- 6 files changed, 62 insertions(+), 47 deletions(-) diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index bc034d1db..99e607945 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -422,9 +422,12 @@ impl App { .position(|x| x.id == from); match position { Some(index) => { - if let Err(error) = - self.coordinator.checkpoint(&from, index as u64, checkpoint) - { + if let Err(error) = self.coordinator.checkpoint( + &from, + index as u64, + checkpoint, + Self::get_timestamp(), + ) { warn!("Error when processing checkpoint: {error}"); } } diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index 55d6d7ada..f6ba621ba 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -307,12 +307,12 @@ impl SolanaBackend { &user, witness, ), - OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( - &coordinator_instance, - &coordinator_account, - &user, - witness, - ), + // OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( + // &coordinator_instance, + // &coordinator_account, + // &user, + // witness, + // ), }; self.spawn_scheduled_send("Witness", &[instruction], &[]); } diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 4cdb55189..dae6eaaf3 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -129,6 +129,7 @@ impl CooldownStepMetadata { &self, mut trainers: Vec, state: &Coordinator, + from: T, ) -> Result { let Some(mut trainer) = trainers.pop() else { return Err(CooldownError::NoTrainers); @@ -142,6 +143,7 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); + let checkpointer: T = state.epoch_state.checkpointer.clone(); let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( async move { @@ -167,10 +169,20 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - if state.epoch_state.checkpointer != trainer { + if checkpointer != from 
{ info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } + let Some(CheckpointConfig { + hub_upload, + checkpoint_dir, + delete_old_steps, + keep_steps, + }) = checkpoint_info + else { + // If there was no HF checkpointing configuration, return immediately + return Ok((evals, None)); + }; // Start the upload process of the updated model parameters in a separate task let upload_handle = tokio::task::spawn(async move { let path = checkpoint_dir.join(format!("{run_id}-step{step}")); diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 85381b921..70ddb1004 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -338,39 +338,39 @@ impl StepStateMachine StepStateMachine { /// `get_historical_clients` is what you actually want. pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointer: Client, + pub checkpointer: T, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 3d4de39d0..7755d8331 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; pub enum OpportunisticData { WitnessStep(Witness, WitnessMetadata), WarmupStep(Witness), - CooldownStep(Witness), + // CooldownStep(Witness), } impl OpportunisticData { @@ -16,7 +16,7 @@ impl OpportunisticData { match self { OpportunisticData::WitnessStep(..) => "witness", OpportunisticData::WarmupStep(..) => "warmup", - OpportunisticData::CooldownStep(..) => "cooldown", + // OpportunisticData::CooldownStep(..) => "cooldown", } } } From 20ac8a2f6cc7a6c71306b5aa2d96b99e4dd0d49f Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 23 Dec 2025 14:12:09 -0300 Subject: [PATCH 04/72] End epoch after checkpointer --- shared/coordinator/src/coordinator.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 742b58119..9ddc7191d 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -281,6 +281,7 @@ pub struct CoordinatorEpochState { pub start_timestamp: u64, pub first_round: SmallBoolean, pub cold_start_epoch: SmallBoolean, + pub checkpointed: bool, } #[derive( @@ -417,6 +418,7 @@ impl Default for CoordinatorEpochState { start_step: Default::default(), last_step: Default::default(), start_timestamp: Default::default(), + checkpointed: false, } } } @@ -627,7 +629,7 @@ impl Coordinator { if self.epoch_state.checkpointer != *from { return Err(CoordinatorError::InvalidWitness); } else { - self.start_waiting_for_members(unix_timestamp); + self.epoch_state.checkpointed = true; } Ok(()) @@ -1048,7 +1050,9 @@ impl Coordinator { &mut self, unix_timestamp: u64, ) -> std::result::Result { - if self.check_timeout(unix_timestamp, self.config.cooldown_time) { + if self.check_timeout(unix_timestamp, self.config.cooldown_time) + || self.epoch_state.checkpointed + { let last_round_batch_size = self.get_target_global_batch_size(self.current_round()); self.progress.epoch_start_data_index = self.current_round_unchecked().data_index + last_round_batch_size as u64; From 514c9936e435a9c65abceb249896ef489b65f113 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 23 Dec 2025 16:24:55 -0300 Subject: [PATCH 05/72] Make hub_repo checkpoint mandatory --- 
architectures/centralized/client/src/app.rs | 30 +++++---- architectures/centralized/server/src/app.rs | 9 +-- .../solana-coordinator/src/instance_state.rs | 3 +- shared/client/src/cli.rs | 40 +++++------- shared/client/src/state/cooldown.rs | 63 +++++++------------ shared/client/src/state/init.rs | 2 +- shared/client/src/state/types.rs | 2 +- shared/coordinator/src/coordinator.rs | 3 +- 8 files changed, 59 insertions(+), 93 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 3560e0a7a..9e841b7f4 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -172,22 +172,20 @@ impl App { state_options: RunInitConfig, ) -> Result<()> { // sanity checks - if let Some(checkpoint_config) = &state_options.checkpoint_config { - if let Some(hub_upload) = &checkpoint_config.hub_upload { - let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_upload.hub_token.clone())) - .build()?; - let repo_api = api.repo(Repo::new( - hub_upload.hub_repo.clone(), - hf_hub::RepoType::Model, - )); - if !repo_api.is_writable().await { - anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - hub_upload.hub_repo - ) - } - } + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some( + state_options.checkpoint_config.hub_upload.hub_token.clone(), + )) + .build()?; + let repo_api = api.repo(Repo::new( + state_options.checkpoint_config.hub_upload.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + state_options.checkpoint_config.hub_upload.hub_repo + ) } self.server_conn diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 99e607945..bc034d1db 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -422,12 +422,9 @@ impl App { .position(|x| x.id == from); match position { Some(index) => { - if let Err(error) = self.coordinator.checkpoint( - &from, - index as u64, - checkpoint, - Self::get_timestamp(), - ) { + if let Err(error) = + self.coordinator.checkpoint(&from, index as u64, checkpoint) + { warn!("Error when processing checkpoint: {error}"); } } diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 05eb20d2c..d875b4992 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -414,10 +414,9 @@ impl CoordinatorInstanceState { .iter() .position(|x| x.id == *id) .ok_or(ProgramError::SignerNotAClient)?; - let clock = Clock::get()?; self.coordinator - .checkpoint(id, index as u64, repo, clock.unix_timestamp as u64) + .checkpoint(id, index as u64, repo) .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; // Only tick if not halted (Paused/Uninitialized/Finished) diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 268ea753a..60e33203e 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -139,12 +139,12 @@ pub struct TrainArgs { pub prompt_task: bool, /// If provided, every model parameters update will be save in this directory after each epoch. 
- #[clap(long, env)] - pub checkpoint_dir: Option, + #[clap(long, env, default_value = "~/.cache/psyche/checkpoints")] + pub checkpoint_dir: PathBuf, /// Path to the Hugging Face repository containing model data and configuration. #[clap(long, env)] - pub hub_repo: Option, + pub hub_repo: String, #[clap(long, env, default_value_t = 3)] pub hub_max_concurrent_downloads: usize, @@ -222,7 +222,7 @@ impl TrainArgs { Ok(wandb_info) } - pub fn checkpoint_config(&self) -> Result> { + pub fn checkpoint_config(&self) -> Result { let hub_read_token = std::env::var("HF_TOKEN").ok(); let checkpoint_upload_info = match ( &hub_read_token, @@ -231,33 +231,21 @@ impl TrainArgs { self.delete_old_steps, self.keep_steps, ) { - (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => { - if keep_steps == 0 { - bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})") - } - Some(CheckpointConfig { - checkpoint_dir: dir, - hub_upload: Some(HubUploadInfo { - hub_repo: repo, - hub_token: token.to_string(), - }), - delete_old_steps, - keep_steps, - }) + (_, _, _, _, keep_steps) if keep_steps == 0 => { + bail!("keep_steps must be > 0 for hub repository uploads (got {keep_steps})") } - (None, Some(_), Some(_), _, _) => { - bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") + (None, _, _, _, _) => { + bail!("No HF_TOKEN env variable.") } - (_, Some(_), None, _, _) => { - bail!("--hub-repo was set, but no --checkpoint-dir was passed!") - } - (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig { + (Some(token), repo, dir, delete_old_steps, keep_steps) => CheckpointConfig { checkpoint_dir: dir, - hub_upload: None, + hub_upload: HubUploadInfo { + hub_repo: repo, + hub_token: token.to_string(), + }, delete_old_steps, keep_steps, - }), - (_, None, _, _, _) => None, + }, }; Ok(checkpoint_upload_info) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index dae6eaaf3..f330e3da9 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -43,7 +43,7 @@ pub enum CooldownError { pub struct CooldownStepMetadata { tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, - checkpoint_info: Option, + checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, model_task_runner: ModelTaskRunner, @@ -60,7 +60,7 @@ impl CooldownStepMetadata { pub fn new( tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, - checkpoint_info: Option, + checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, model_task_runner: ModelTaskRunner, ) -> Self { @@ -143,7 +143,7 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); - let checkpointer: T = state.epoch_state.checkpointer.clone(); + let checkpointer: T = state.epoch_state.checkpointer; let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( async move { @@ -173,51 +173,36 @@ impl CooldownStepMetadata { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } - let Some(CheckpointConfig { + let CheckpointConfig { hub_upload, checkpoint_dir, delete_old_steps, keep_steps, - }) = checkpoint_info - else { - // If there was no HF checkpointing configuration, return immediately - return Ok((evals, None)); - }; + } = checkpoint_info; + // Start the upload process of the updated model parameters in a separate task let 
upload_handle = tokio::task::spawn(async move { - let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - info!("Saving to {}", path.display()); - let mut local = tokio::task::spawn_blocking({ - let path = path.clone(); - move || save_tensors_into_safetensors(variables, path) - }) - .await - .map_err(|_| CheckpointError::WriteThreadCrashed)??; - - for extra in checkpoint_extra_files { - let to = path.join(extra.file_name().unwrap()); - tokio::fs::copy(extra.clone(), to.clone()) - .await - .map_err(CheckpointError::WriteExtraFile)?; - local.push(to); - } + let path = checkpoint_dir.join(format!("{run_id}-step{step}")); + info!("Saving to {}", path.display()); + let mut local = tokio::task::spawn_blocking({ + let path = path.clone(); + move || save_tensors_into_safetensors(variables, path) + }) + .await + .map_err(|_| CheckpointError::WriteThreadCrashed)??; + + for extra in checkpoint_extra_files { + let to = path.join(extra.file_name().unwrap()); + tokio::fs::copy(extra.clone(), to.clone()) + .await + .map_err(CheckpointError::WriteExtraFile)?; + local.push(to); + } - let Some(HubUploadInfo { + let HubUploadInfo { hub_repo, hub_token, - }) = hub_upload - else { - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; - return Ok::<(), CheckpointError>(()); - }; + } = hub_upload; info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); let revision = match upload_model_repo_async( diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 2b065e6ce..e7063018f 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -67,7 +67,7 @@ pub struct RunInitConfig { pub write_gradients_dir: Option, // checkpointing - pub checkpoint_config: Option, + pub checkpoint_config: CheckpointConfig, // configurable dummy training time (in seconds) for this client - relevant just for testing pub dummy_training_delay_secs: Option, diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 2edf22760..5d35e0e48 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -16,7 +16,7 @@ pub struct HubUploadInfo { #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub hub_upload: Option, + pub hub_upload: HubUploadInfo, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 9ddc7191d..e5dabede6 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -601,7 +601,6 @@ impl Coordinator { from: &T, index: u64, hub_repo: HubRepo, - unix_timestamp: u64, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { @@ -953,7 +952,7 @@ impl Coordinator { ) .unwrap(); - self.epoch_state.checkpointer = self.epoch_state.clients[0].id.clone(); + self.epoch_state.checkpointer = self.epoch_state.clients[0].id; self.start_warmup(unix_timestamp); } From 3553f7d9ae025b4b6f4c26ab4bacb44ee9073b59 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 5 Jan 2026 10:18:04 -0800 Subject: [PATCH 06/72] Use client as checkpointer and pick random --- architectures/centralized/client/src/app.rs | 1 - shared/client/src/state/cooldown.rs | 5 +++-- shared/coordinator/src/coordinator.rs | 9 ++++----- shared/core/src/fixed_vec.rs | 5 ++--- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git 
a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 9e841b7f4..65d2c7bc7 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -171,7 +171,6 @@ impl App { p2p: NC, state_options: RunInitConfig, ) -> Result<()> { - // sanity checks let api = hf_hub::api::tokio::ApiBuilder::new() .with_token(Some( state_options.checkpoint_config.hub_upload.hub_token.clone(), diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index f330e3da9..b736d46a5 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,5 +1,6 @@ use crate::HubUploadInfo; +use psyche_coordinator::Client; use psyche_coordinator::{ Coordinator, model::{self, HubRepo}, @@ -143,7 +144,7 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); - let checkpointer: T = state.epoch_state.checkpointer; + let checkpointer: Client = state.epoch_state.checkpointer; let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( async move { @@ -169,7 +170,7 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - if checkpointer != from { + if checkpointer.id != from { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index e5dabede6..240e9c226 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -274,7 +274,7 @@ pub struct CoordinatorEpochState { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointer: T, + pub checkpointer: Client, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -413,7 +413,7 @@ impl Default for CoordinatorEpochState { first_round: true.into(), clients: Default::default(), exited_clients: Default::default(), - checkpointer: T::default(), + checkpointer: Default::default(), cold_start_epoch: false.into(), start_step: Default::default(), last_step: Default::default(), @@ -625,7 +625,7 @@ impl Coordinator { return Err(CoordinatorError::InvalidRunState); } - if self.epoch_state.checkpointer != *from { + if self.epoch_state.checkpointer.id != *from { return Err(CoordinatorError::InvalidWitness); } else { self.epoch_state.checkpointed = true; @@ -951,8 +951,7 @@ impl Coordinator { .map(|x| Client::new(*x)), ) .unwrap(); - - self.epoch_state.checkpointer = self.epoch_state.clients[0].id; + self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap(); self.start_warmup(unix_timestamp); } diff --git a/shared/core/src/fixed_vec.rs b/shared/core/src/fixed_vec.rs index 955b06bff..43b4fbef5 100644 --- a/shared/core/src/fixed_vec.rs +++ b/shared/core/src/fixed_vec.rs @@ -146,14 +146,13 @@ impl FixedVec { Ok(()) } - #[cfg(feature = "rand")] - pub fn random(&self) -> Option<&T> { + pub fn random(&self) -> Option { if self.len == 0 { return None; } let mut rng = rand::rng(); let random_index = rng.random_range(0..self.len) as u16; - Some(&self.data[random_index as usize]) + Some(self.data[random_index as usize]) } pub fn retain(&mut self, mut f: F) From 00b97f1ed75c9d0d6b36de1dc5fd5952cf465036 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 5 Jan 2026 18:14:01 -0300 Subject: [PATCH 07/72] Reuse opportunistic logic --- Cargo.lock | 1 + architectures/centralized/client/src/app.rs | 8 +- architectures/centralized/server/src/app.rs | 9 ++- .../solana-client/src/backend.rs | 55 +++++++------- .../solana-client/src/instructions.rs | 24 ++++++ .../solana-coordinator/src/instance_state.rs | 34 +++++---- .../programs/solana-coordinator/src/lib.rs | 23 ++++++ shared/client/src/client.rs | 52 ++++++------- shared/client/src/state/cooldown.rs | 10 +++ shared/client/src/state/steps.rs | 74 ++++++++++--------- shared/coordinator/Cargo.toml | 1 + shared/coordinator/src/coordinator.rs | 43 ++++++++++- shared/coordinator/src/model.rs | 13 ++++ shared/core/src/fixed_vec.rs | 16 ++-- shared/watcher/src/traits.rs | 11 ++- 15 files changed, 252 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dad6cfcdc..c1f585610 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6977,6 +6977,7 @@ name = "psyche-coordinator" version = "0.1.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 65d2c7bc7..2d9793132 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -67,10 +67,10 @@ impl WatcherBackend for Backend { Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> { - self.tx.send(ToSend::Checkpoint(checkpoint))?; - Ok(()) - } + // async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> { + // self.tx.send(ToSend::Checkpoint(checkpoint))?; + // Ok(()) + // } } pub struct App { diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 
bc034d1db..3d930b767 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -80,9 +80,9 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { bail!("Server does not send health checks"); } - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> { - bail!("Server does not send checkpoints"); - } + // async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> { + // bail!("Server does not send checkpoints"); + // } } type DataServer = @@ -395,6 +395,9 @@ impl App { Self::get_timestamp(), rand::rng().next_u64(), ), + OpportunisticData::CooldownStep(witness, hub_repo) => self + .coordinator + .cooldown_witness(&from, witness, Self::get_timestamp(), hub_repo), } { warn!("Error when processing witness: {error}"); }; diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index f6ba621ba..a9be96173 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -307,12 +307,15 @@ impl SolanaBackend { &user, witness, ), - // OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( - // &coordinator_instance, - // &coordinator_account, - // &user, - // witness, - // ), + OpportunisticData::CooldownStep(witness, hub_repo) => { + instructions::coordinator_cooldown_witness( + &coordinator_instance, + &coordinator_account, + &user, + witness, + hub_repo, + ) + } }; self.spawn_scheduled_send("Witness", &[instruction], &[]); } @@ -335,21 +338,21 @@ impl SolanaBackend { self.spawn_scheduled_send("Health check", &[instruction], &[]); } - pub fn send_checkpoint( - &self, - coordinator_instance: Pubkey, - coordinator_account: Pubkey, - repo: HubRepo, - ) { - let user = self.get_payer(); - let instruction = instructions::coordinator_checkpoint( - &coordinator_instance, - &coordinator_account, - &user, - repo, - ); - self.spawn_scheduled_send("Checkpoint", &[instruction], &[]); - } + // pub fn send_checkpoint( + // &self, + // coordinator_instance: Pubkey, + // coordinator_account: Pubkey, + // repo: HubRepo, + // ) { + // let user = self.get_payer(); + // let instruction = instructions::coordinator_checkpoint( + // &coordinator_instance, + // &coordinator_account, + // &user, + // repo, + // ); + // self.spawn_scheduled_send("Checkpoint", &[instruction], &[]); + // } pub fn find_join_authorization(join_authority: &Pubkey, authorizer: Option) -> Pubkey { psyche_solana_authorizer::find_authorization( @@ -609,11 +612,11 @@ impl WatcherBackend for SolanaBackendRunner Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> { - self.backend - .send_checkpoint(self.instance, self.account, checkpoint); - Ok(()) - } + // async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> { + // self.backend + // .send_checkpoint(self.instance, self.account, checkpoint); + // Ok(()) + // } } impl SolanaBackendRunner { diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs index a8bec54a0..f23fcfb36 100644 --- a/architectures/decentralized/solana-client/src/instructions.rs +++ b/architectures/decentralized/solana-client/src/instructions.rs @@ -179,6 +179,30 @@ pub fn coordinator_warmup_witness( ) } +pub fn coordinator_cooldown_witness( + coordinator_instance: &Pubkey, + coordinator_account: &Pubkey, + user: &Pubkey, + witness: 
psyche_coordinator::Witness, + hub_repo: psyche_coordinator::model::HubRepo, +) -> Instruction { + anchor_instruction( + psyche_solana_coordinator::ID, + psyche_solana_coordinator::accounts::PermissionlessCoordinatorAccounts { + user: *user, + coordinator_instance: *coordinator_instance, + coordinator_account: *coordinator_account, + }, + psyche_solana_coordinator::instruction::CooldownWitness { + proof: witness.proof, + participant_bloom: witness.participant_bloom, + broadcast_bloom: witness.broadcast_bloom, + broadcast_merkle: witness.broadcast_merkle, + hub_repo, + }, + ) +} + pub fn coordinator_health_check( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index d875b4992..864d8d0ea 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,20 +233,26 @@ impl CoordinatorInstanceState { self.tick() } - // pub fn cooldown_witness( - // &mut self, - // payer: &Pubkey, - // witness: Witness, - // ) -> Result<()> { - // let id = self.clients_state.find_signer(payer)?; - - // let clock: Clock = Clock::get()?; - // self.coordinator - // .cooldown_witness(id, witness, clock.unix_timestamp as u64) - // .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; - - // self.tick() - // } + pub fn cooldown_witness( + &mut self, + payer: &Pubkey, + witness: Witness, + hub_repo: HubRepo, + ) -> Result<()> { + let id = self.clients_state.find_signer(payer)?; + + let clock: Clock = Clock::get()?; + self.coordinator + .cooldown_witness( + id, + witness, + clock.unix_timestamp as u64, + hub_repo, + ) + .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; + + self.tick() + } pub fn warmup_witness( &mut self, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 657424195..4dccbeec6 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -289,6 +289,29 @@ pub mod psyche_solana_coordinator { ) } + #[allow(unused_variables)] // for the metadata field. adding a _ prefix results in anchor's IDL not matching the actual types. lol. 
+ pub fn cooldown_witness( + ctx: Context, + proof: WitnessProof, + participant_bloom: WitnessBloom, + broadcast_bloom: WitnessBloom, + broadcast_merkle: MerkleRoot, + hub_repo: HubRepo, + ) -> Result<()> { + let mut account = ctx.accounts.coordinator_account.load_mut()?; + account.increment_nonce(); + account.state.cooldown_witness( + ctx.accounts.user.key, + Witness { + proof, + participant_bloom, + broadcast_bloom, + broadcast_merkle, + }, + hub_repo, + ) + } + pub fn health_check( ctx: Context, id: ClientId, diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index ffbc6abc9..ec72a554d 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -531,9 +531,9 @@ impl + 'sta Some(health_check) = rx_health_check.recv() => { watcher.backend_mut().send_health_check(health_check).await?; } - Some(checkpoint) = rx_checkpoint.recv() => { - watcher.backend_mut().send_checkpoint(checkpoint).await?; - } + // Some(checkpoint) = rx_checkpoint.recv() => { + // watcher.backend_mut().send_checkpoint(checkpoint).await?; + // } Some(model) = rx_model.recv() => { sharable_model.update_parameters(model)?; }, @@ -678,29 +678,29 @@ impl + 'sta let p2p_shutdown = p2p.shutdown(); - if wait_for_checkpoint { - info!("Waiting for all pending checkpoints to finish"); - - // Keep waiting for checkpoints while there are uploads pending - let mut checkpoint_check_interval = interval(Duration::from_secs(10)); - while run.doing_checkpoint() { - tokio::select! { - checkpoint = rx_checkpoint.recv() => { - if let Some(checkpoint) = checkpoint { - info!("Checkpoint upload completed, sending to Solana"); - watcher.backend_mut().send_checkpoint(checkpoint).await?; - } else { - // Channel closed, no more checkpoints coming - break; - } - } - _ = checkpoint_check_interval.tick() => { - } - } - } - - info!("All checkpoints finished, exiting main client loop"); - } + // if wait_for_checkpoint { + // info!("Waiting for all pending checkpoints to finish"); + + // // Keep waiting for checkpoints while there are uploads pending + // let mut checkpoint_check_interval = interval(Duration::from_secs(10)); + // while run.doing_checkpoint() { + // tokio::select! 
{ + // checkpoint = rx_checkpoint.recv() => { + // if let Some(checkpoint) = checkpoint { + // info!("Checkpoint upload completed, sending to Solana"); + // watcher.backend_mut().send_checkpoint(checkpoint).await?; + // } else { + // // Channel closed, no more checkpoints coming + // break; + // } + // } + // _ = checkpoint_check_interval.tick() => { + // } + // } + // } + + // info!("All checkpoints finished, exiting main client loop"); + // } p2p_shutdown .await diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index b736d46a5..ac6f05d55 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -10,6 +10,7 @@ use psyche_data_provider::{UploadModelError, upload_model_repo_async}; use psyche_modeling::{ SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; +use std::str::FromStr; use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap}, @@ -264,6 +265,11 @@ impl CooldownStepMetadata { checkpointing_and_evals, }) } + + pub fn get_repo(&self) -> Option { + let repo = HubRepo::from_str(&self.checkpoint_info.hub_upload.hub_repo).ok()?; + Some(repo) + } } type CheckpointAndEvalsHandle = JoinHandle< @@ -298,4 +304,8 @@ impl CooldownStep { Ok((running_evals, upload_handle)) } + + pub fn is_finished(&self) -> bool { + self.checkpointing_and_evals.is_finished() + } } diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 70ddb1004..c4b05e3d8 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -338,39 +338,47 @@ impl StepStateMachine Coordinator { Ok(()) } + pub fn cooldown_witness( + &mut self, + from: &T, + witness: Witness, + unix_timestamp: u64, + hub_repo: HubRepo, + ) -> std::result::Result<(), CoordinatorError> { + match &mut self.model { + Model::LLM(llm) => match llm.checkpoint { + Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo), + Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo), + _ => {} + }, + } + + if self.halted() { + return Err(CoordinatorError::Halted); + } + + if !matches!(self.run_state, RunState::Cooldown) { + return Err(CoordinatorError::InvalidRunState); + } + + if self.epoch_state.checkpointer.id != *from { + return Err(CoordinatorError::InvalidWitness); + } else { + self.start_waiting_for_members(unix_timestamp); + } + + Ok(()) + } + pub fn witness( &mut self, from: &T, @@ -606,9 +638,6 @@ impl Coordinator { if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { return Err(CoordinatorError::InvalidCommitteeProof); } - // TODO: In the case of more than one checkpointer, this will overwrite the hub repo - // with the last checkpointed one. We could instead have a vector of hub repos to have - // more download options. 
match &mut self.model { Model::LLM(llm) => match llm.checkpoint { Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo), @@ -951,7 +980,13 @@ impl Coordinator { .map(|x| Client::new(*x)), ) .unwrap(); - self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap(); + // self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap(); + self.epoch_state.checkpointer = self + .epoch_state + .clients + .get(0) + .cloned() + .expect("at least one client"); self.start_warmup(unix_timestamp); } diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 82efdd8cc..8b8d82eef 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use crate::{SOLANA_MAX_STRING_LEN, coordinator::SOLANA_MAX_URL_STRING_LEN}; use anchor_lang::{ @@ -226,6 +228,17 @@ impl HubRepo { } } +impl FromStr for HubRepo { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + Ok(HubRepo { + repo_id: FixedString::from_str_truncated(s), + revision: None, + }) + } +} + #[derive( AnchorSerialize, AnchorDeserialize, diff --git a/shared/core/src/fixed_vec.rs b/shared/core/src/fixed_vec.rs index 43b4fbef5..18b52dbc6 100644 --- a/shared/core/src/fixed_vec.rs +++ b/shared/core/src/fixed_vec.rs @@ -146,14 +146,14 @@ impl FixedVec { Ok(()) } - pub fn random(&self) -> Option { - if self.len == 0 { - return None; - } - let mut rng = rand::rng(); - let random_index = rng.random_range(0..self.len) as u16; - Some(self.data[random_index as usize]) - } + // pub fn random(&self) -> Option { + // if self.len == 0 { + // return None; + // } + // let mut rng = rand::rng(); + // let random_index = rng.random_range(0..self.len) as u16; + // Some(self.data[random_index as usize]) + // } pub fn retain(&mut self, mut f: F) where diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 7755d8331..d993bfc38 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -1,5 +1,8 @@ use anyhow::Result; -use psyche_coordinator::{Coordinator, HealthChecks, Witness, WitnessMetadata, model}; +use psyche_coordinator::{ + Coordinator, HealthChecks, Witness, WitnessMetadata, + model::{self, HubRepo}, +}; use psyche_core::NodeIdentity; use serde::{Deserialize, Serialize}; @@ -8,7 +11,7 @@ use serde::{Deserialize, Serialize}; pub enum OpportunisticData { WitnessStep(Witness, WitnessMetadata), WarmupStep(Witness), - // CooldownStep(Witness), + CooldownStep(Witness, HubRepo), } impl OpportunisticData { @@ -16,7 +19,7 @@ impl OpportunisticData { match self { OpportunisticData::WitnessStep(..) => "witness", OpportunisticData::WarmupStep(..) => "warmup", - // OpportunisticData::CooldownStep(..) => "cooldown", + OpportunisticData::CooldownStep(..) 
=> "cooldown", } } } @@ -29,5 +32,5 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; + // async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; } From 0566a4aaf3e6482b6009cac212859c9b407770a1 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 5 Jan 2026 13:32:48 -0800 Subject: [PATCH 08/72] Add anyhow to cargo lock --- architectures/decentralized/solana-coordinator/Cargo.lock | 1 + architectures/decentralized/solana-treasurer/Cargo.lock | 1 + 2 files changed, 2 insertions(+) diff --git a/architectures/decentralized/solana-coordinator/Cargo.lock b/architectures/decentralized/solana-coordinator/Cargo.lock index 22d64fadc..001c29943 100644 --- a/architectures/decentralized/solana-coordinator/Cargo.lock +++ b/architectures/decentralized/solana-coordinator/Cargo.lock @@ -1605,6 +1605,7 @@ name = "psyche-coordinator" version = "0.1.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", diff --git a/architectures/decentralized/solana-treasurer/Cargo.lock b/architectures/decentralized/solana-treasurer/Cargo.lock index 5d56eb74d..06fc99a96 100644 --- a/architectures/decentralized/solana-treasurer/Cargo.lock +++ b/architectures/decentralized/solana-treasurer/Cargo.lock @@ -1605,6 +1605,7 @@ name = "psyche-coordinator" version = "0.1.0" dependencies = [ "anchor-lang", + "anyhow", "async-trait", "bytemuck", "cfg_eval", From 203d0c47c614a6bef6844a9824676ed6f397694c Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 6 Jan 2026 12:33:27 -0800 Subject: [PATCH 09/72] Fix opportunistic cooldown message --- shared/client/src/state/cooldown.rs | 12 ++++++------ shared/client/src/state/steps.rs | 9 ++++++++- shared/core/src/fixed_vec.rs | 11 ----------- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ac6f05d55..70f7b212d 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -230,12 +230,12 @@ impl CooldownStepMetadata { } }; - tx_checkpoint - .send(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - }) - .map_err(|_| CheckpointError::SendCheckpoint)?; + // tx_checkpoint + // .send(HubRepo { + // repo_id: FixedString::from_str_truncated(&hub_repo), + // revision: Some(FixedString::from_str_truncated(&revision)), + // }) + // .map_err(|_| CheckpointError::SendCheckpoint)?; // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work // we'll just delete the dir after we've uploaded it diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index c4b05e3d8..83126e4b6 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -171,8 +171,15 @@ impl StepStateMachine Result<(), OpportunisticWitnessError> { - if let Some(committee_info) = &self.current_round.committee_info { + println!("CURRENT STATE = {:?}", self.coordinator_state.run_state); + if self.current_round.committee_info.is_some() + && !matches!( + self.coordinator_state.run_state, + RunState::Warmup | RunState::Cooldown + ) + { // trace!("Checking for opprotunistic witness with committee info"); + let committee_info = 
self.current_round.committee_info.as_ref().unwrap(); if let ActiveStep::Training(step) = &self.active_step { let all_prev_round_batches_are_trained = self .previous_round diff --git a/shared/core/src/fixed_vec.rs b/shared/core/src/fixed_vec.rs index 18b52dbc6..4058b3e2b 100644 --- a/shared/core/src/fixed_vec.rs +++ b/shared/core/src/fixed_vec.rs @@ -1,8 +1,6 @@ use crate as psyche_core; use anchor_lang::{AnchorDeserialize, AnchorSerialize, prelude::borsh}; use bytemuck::Zeroable; -#[cfg(feature = "rand")] -use rand::Rng; use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut, Range, RangeFrom, RangeFull, RangeTo}; use ts_rs::TS; @@ -146,15 +144,6 @@ impl FixedVec { Ok(()) } - // pub fn random(&self) -> Option { - // if self.len == 0 { - // return None; - // } - // let mut rng = rand::rng(); - // let random_index = rng.random_range(0..self.len) as u16; - // Some(self.data[random_index as usize]) - // } - pub fn retain(&mut self, mut f: F) where F: FnMut(&T) -> bool, From e8b1dd84d6c0db294d24082ea2b7151011c1fa66 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 6 Jan 2026 18:08:30 -0300 Subject: [PATCH 10/72] Add random index for client --- Cargo.lock | 1 + shared/coordinator/Cargo.toml | 1 + shared/coordinator/src/coordinator.rs | 11 ++++++++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1f585610..47b62c565 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6982,6 +6982,7 @@ dependencies = [ "bytemuck", "cfg_eval", "psyche-core", + "rand 0.9.2", "serde", "serde_with", "ts-rs", diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index 024696555..c980571f3 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" psyche-core.workspace = true async-trait.workspace = true anchor-lang.workspace = true +rand.workspace = true bytemuck.workspace = true serde_with.workspace = true anyhow.workspace = true diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 2b525f8a8..96784fe3c 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -5,7 +5,11 @@ use crate::{ use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; use bytemuck::{Pod, Zeroable}; -use psyche_core::{Bloom, FixedString, FixedVec, MerkleRoot, NodeIdentity, SmallBoolean, sha256}; +use psyche_core::{ + Bloom, FixedString, FixedVec, MerkleRoot, NodeIdentity, SmallBoolean, compute_shuffled_index, + sha256, sha256v, +}; +use rand::prelude::*; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, hash::Hash}; use ts_rs::TS; @@ -980,11 +984,12 @@ impl Coordinator { .map(|x| Client::new(*x)), ) .unwrap(); - // self.epoch_state.checkpointer = self.epoch_state.clients.random().unwrap(); + + let index = rand::rng().random_range(0..self.epoch_state.clients.len()); self.epoch_state.checkpointer = self .epoch_state .clients - .get(0) + .get(index) .cloned() .expect("at least one client"); self.start_warmup(unix_timestamp); From 0b6e9f506d1da6bac6594ae908ff8296ee3a5d3e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 6 Jan 2026 18:52:54 -0300 Subject: [PATCH 11/72] Remove rand and get checkpointers at cooldown state --- Cargo.lock | 1 - shared/client/src/state/cooldown.rs | 17 ++++++++++++++++- shared/client/src/state/steps.rs | 9 +++++++-- shared/coordinator/Cargo.toml | 1 - shared/coordinator/src/coordinator.rs | 15 +++++++-------- 5 files changed, 30 insertions(+), 13 deletions(-) diff 
--git a/Cargo.lock b/Cargo.lock index 47b62c565..c1f585610 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6982,7 +6982,6 @@ dependencies = [ "bytemuck", "cfg_eval", "psyche-core", - "rand 0.9.2", "serde", "serde_with", "ts-rs", diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 70f7b212d..f9fce65ed 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -132,12 +132,14 @@ impl CooldownStepMetadata { mut trainers: Vec, state: &Coordinator, from: T, + client_index: u64, ) -> Result { let Some(mut trainer) = trainers.pop() else { return Err(CooldownError::NoTrainers); }; let step = state.progress.step - 1; + let current_round = state.current_round().ok_or(CooldownError::NoTrainers)?; let run_id = String::from(&state.run_id); let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); @@ -145,7 +147,20 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); - let checkpointer: Client = state.epoch_state.checkpointer; + let mut seed = [0u8; 32]; + seed.copy_from_slice(&sha256v(&[ + &sha256(&seed.to_le_bytes()), + "COOLDOWN".as_bytes(), + ])); + let random_index = + compute_shuffled_index(client_index, state.epoch_state.clients.len() as u64, &seed) + as usize; + let checkpointer = state + .epoch_state + .clients + .get(random_index) + .cloned() + .ok_or(CooldownError::NoTrainers)?; let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( async move { diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 83126e4b6..4a0e1984a 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -1,5 +1,5 @@ use crate::{ - Broadcast, BroadcastType, ClientTUIState, IntegrationTestLogMarker, + Broadcast, BroadcastType, ClientTUIState, IntegrationTestLogMarker, client, state::{train::FinishedTrainers, types::DeserializeError}, }; @@ -887,7 +887,12 @@ impl StepStateMachine Coordinator { ) .unwrap(); - let index = rand::rng().random_range(0..self.epoch_state.clients.len()); - self.epoch_state.checkpointer = self - .epoch_state - .clients - .get(index) - .cloned() - .expect("at least one client"); - self.start_warmup(unix_timestamp); + // self.epoch_state.checkpointer = self + // .epoch_state + // .clients + // .get(0) + // .cloned() + // .expect("at least one client"); + // self.start_warmup(unix_timestamp); } Ok(TickResult::Ticked) From c001fa51b3e7e60c434082bb58d5fbc5ace1ff47 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 7 Jan 2026 11:40:12 -0800 Subject: [PATCH 12/72] Rework on cooldown to select pseudo random --- .../centralized/testing/src/server.rs | 1 + config/solana-test/light-config.toml | 1 + shared/client/src/state/cooldown.rs | 22 ++++++----- shared/client/src/state/steps.rs | 25 +++++++++++- shared/coordinator/src/committee_selection.rs | 36 ++++++++++++++++++ shared/coordinator/src/coordinator.rs | 38 ++++++++++++------- shared/coordinator/src/lib.rs | 3 +- 7 files changed, 100 insertions(+), 26 deletions(-) diff --git a/architectures/centralized/testing/src/server.rs b/architectures/centralized/testing/src/server.rs index 0c0ec60a8..06062e5c5 100644 --- a/architectures/centralized/testing/src/server.rs +++ b/architectures/centralized/testing/src/server.rs @@ -77,6 +77,7 @@ impl CoordinatorServer { global_batch_size_end: global_batch_size, 
global_batch_size_warmup_tokens: 0, verification_percent: 0, + checkpointer_nodes: 0, witness_nodes, total_steps: 100, waiting_for_members_extra_time: 2, diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index 5de8ca8a4..922adcc9e 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -8,6 +8,7 @@ min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 0 +checkpointer_nodes = 1 global_batch_size_start = 8 global_batch_size_end = 8 global_batch_size_warmup_tokens = 0 diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index f9fce65ed..2a12de34f 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,10 +1,14 @@ use crate::HubUploadInfo; +use psyche_coordinator::CheckpointerSelection; use psyche_coordinator::Client; use psyche_coordinator::{ Coordinator, model::{self, HubRepo}, }; +use psyche_core::compute_shuffled_index; +use psyche_core::sha256; +use psyche_core::sha256v; use psyche_core::{FixedString, NodeIdentity}; use psyche_data_provider::{UploadModelError, upload_model_repo_async}; use psyche_modeling::{ @@ -147,18 +151,16 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); - let mut seed = [0u8; 32]; - seed.copy_from_slice(&sha256v(&[ - &sha256(&seed.to_le_bytes()), - "COOLDOWN".as_bytes(), - ])); - let random_index = - compute_shuffled_index(client_index, state.epoch_state.clients.len() as u64, &seed) - as usize; + let checkpointer_selection = CheckpointerSelection::from_coordinator(state, 0) + .map_err(|e| CooldownError::NoTrainers)?; + let is_checkpointer = checkpointer_selection + .get_checkpointer(client_index, state.epoch_state.clients.len() as u64); + println!("CHECKPOINTER INDEX: {}", is_checkpointer); + println!("CLIENT INDEX: {}", client_index); let checkpointer = state .epoch_state .clients - .get(random_index) + .get(client_index as usize) .cloned() .ok_or(CooldownError::NoTrainers)?; @@ -186,7 +188,7 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - if checkpointer.id != from { + if !is_checkpointer { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 4a0e1984a..4a30c7d61 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -4,6 +4,7 @@ use crate::{ }; use iroh_blobs::api::Tag; +use psyche_coordinator::CheckpointerSelection; use psyche_coordinator::{Committee, Coordinator, RunState, Witness, WitnessProof}; use psyche_core::{MerkleRoot, MerkleTree, NodeIdentity, sha256}; use psyche_modeling::{DistroResult, Trainer}; @@ -351,7 +352,29 @@ impl StepStateMachine( + coordinator: &Coordinator, + offset: isize, + ) -> Result { + let round = match offset { + -2 => coordinator.previous_previous_round(), + -1 => coordinator.previous_round(), + 0 => coordinator.current_round(), + _ => { + return Err(CoordinatorError::NoActiveRound); + } + } + .ok_or(CoordinatorError::NoActiveRound)?; + let seed = sha256(&round.random_seed.to_le_bytes()); + Ok(Self { + cooldown_nodes: coordinator.config.checkpointer_nodes as u64, + seed, + }) + } + + pub fn get_checkpointer(&self, client_index: u64, total_clients: u64) -> bool { + let mut final_seed = [0u8; 32]; + 
final_seed.copy_from_slice(&sha256v(&[&sha256(&self.seed), COOLDOWN_SALT.as_bytes()])); + let index = compute_shuffled_index(client_index, total_clients, &final_seed); + index < self.cooldown_nodes + } +} #[derive( Clone, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index ce61b5d00..76a8117ea 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,5 +1,5 @@ use crate::{ - Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, + CheckpointerSelection, Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, model::{Checkpoint, HubRepo, Model}, }; @@ -9,7 +9,6 @@ use psyche_core::{ Bloom, FixedString, FixedVec, MerkleRoot, NodeIdentity, SmallBoolean, compute_shuffled_index, sha256, sha256v, }; -use rand::prelude::*; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, hash::Hash}; use ts_rs::TS; @@ -255,6 +254,7 @@ pub struct CoordinatorConfig { pub init_min_clients: u16, pub min_clients: u16, pub witness_nodes: u16, + pub checkpointer_nodes: u16, pub global_batch_size_start: u16, pub global_batch_size_end: u16, @@ -519,17 +519,9 @@ impl Coordinator { &mut self, from: &T, witness: Witness, - unix_timestamp: u64, + _unix_timestamp: u64, hub_repo: HubRepo, ) -> std::result::Result<(), CoordinatorError> { - match &mut self.model { - Model::LLM(llm) => match llm.checkpoint { - Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo), - Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo), - _ => {} - }, - } - if self.halted() { return Err(CoordinatorError::Halted); } @@ -538,10 +530,27 @@ impl Coordinator { return Err(CoordinatorError::InvalidRunState); } - if self.epoch_state.checkpointer.id != *from { + let client_index = self + .epoch_state + .clients + .iter() + .position(|x| x.id == *from) + .unwrap(); + + let current_round = self.current_round().unwrap(); + let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; + let is_checkpointer = checkpointer_selection + .get_checkpointer(client_index as u64, self.epoch_state.clients.len() as u64); + let checkpointer = self + .epoch_state + .clients + .get(client_index as usize) + .cloned() + .unwrap(); + if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); } else { - self.start_waiting_for_members(unix_timestamp); + self.epoch_state.checkpointed = true; } Ok(()) @@ -991,7 +1000,7 @@ impl Coordinator { // .get(0) // .cloned() // .expect("at least one client"); - // self.start_warmup(unix_timestamp); + self.start_warmup(unix_timestamp); } Ok(TickResult::Ticked) @@ -1235,6 +1244,7 @@ impl CoordinatorConfig { && self.global_batch_size_end >= self.global_batch_size_start && self.total_steps != 0 && self.witness_nodes <= self.min_clients + && self.checkpointer_nodes <= self.min_clients && self.witness_nodes as usize <= SOLANA_MAX_NUM_WITNESSES && self.cooldown_time > 0 && self.waiting_for_members_extra_time > 0 diff --git a/shared/coordinator/src/lib.rs b/shared/coordinator/src/lib.rs index bef26863e..3f3dc2c11 100644 --- a/shared/coordinator/src/lib.rs +++ b/shared/coordinator/src/lib.rs @@ -8,7 +8,8 @@ pub mod model; pub use commitment::Commitment; pub use committee_selection::{ - COMMITTEE_SALT, Committee, CommitteeProof, CommitteeSelection, WITNESS_SALT, WitnessProof, + COMMITTEE_SALT, CheckpointerSelection, Committee, CommitteeProof, CommitteeSelection, + WITNESS_SALT, WitnessProof, }; pub use coordinator::{ BLOOM_FALSE_RATE, Client, ClientState, 
Coordinator, CoordinatorConfig, CoordinatorEpochState, From ef3b46817ad4172518fdefb97b6ebfa42911bbda Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 7 Jan 2026 12:40:48 -0800 Subject: [PATCH 13/72] Fix clippy --- architectures/centralized/client/src/app.rs | 4 +- architectures/centralized/server/src/app.rs | 8 ++-- .../solana-client/src/backend.rs | 17 +++----- .../solana-client/src/instructions.rs | 2 - .../solana-coordinator/src/instance_state.rs | 10 +---- .../programs/solana-coordinator/src/lib.rs | 2 - shared/client/src/client.rs | 9 +--- shared/client/src/state/cooldown.rs | 43 ++----------------- shared/client/src/state/init.rs | 3 -- shared/client/src/state/steps.rs | 24 +++-------- shared/coordinator/src/coordinator.rs | 16 +------ shared/watcher/src/traits.rs | 8 +--- 12 files changed, 28 insertions(+), 118 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 2d9793132..25540bc0e 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -5,7 +5,7 @@ use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientM use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_coordinator::{Coordinator, HealthChecks}; use psyche_metrics::ClientMetrics; use psyche_network::{ AuthenticatableIdentity, EndpointId, NetworkTUIState, NetworkTui, SecretKey, TcpClient, @@ -29,7 +29,6 @@ pub type TabsData = ::Data; pub enum ToSend { Witness(Box), HealthCheck(HealthChecks), - Checkpoint(model::HubRepo), } struct Backend { @@ -227,7 +226,6 @@ impl App { match to_send { ToSend::Witness(witness) => self.server_conn.send(ClientToServerMessage::Witness(witness)).await?, ToSend::HealthCheck(health_checks) => self.server_conn.send(ClientToServerMessage::HealthCheck(health_checks)).await?, - ToSend::Checkpoint(checkpoint) => self.server_conn.send(ClientToServerMessage::Checkpoint(checkpoint)).await?, }; } } diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 3d930b767..15c2de53a 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow, bail}; use async_trait::async_trait; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; -use psyche_coordinator::model::{self, Checkpoint, LLM, LLMTrainingDataLocation, Model}; +use psyche_coordinator::model::{Checkpoint, LLM, LLMTrainingDataLocation, Model}; use psyche_coordinator::{ Client, ClientState, Coordinator, CoordinatorError, HealthChecks, Round, RunState, SOLANA_MAX_NUM_CLIENTS, TickResult, @@ -395,9 +395,9 @@ impl App { Self::get_timestamp(), rand::rng().next_u64(), ), - OpportunisticData::CooldownStep(witness, hub_repo) => self - .coordinator - .cooldown_witness(&from, witness, Self::get_timestamp(), hub_repo), + OpportunisticData::CooldownStep(witness) => { + self.coordinator.cooldown_witness(&from, witness) + } } { warn!("Error when processing witness: {error}"); }; diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index a9be96173..bf35d5496 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -20,7 +20,7 @@ use anchor_client::{ use 
anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; use psyche_client::IntegrationTestLogMarker; -use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo}; +use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding}; use solana_transaction_status_client_types::UiTransactionEncoding; @@ -307,15 +307,12 @@ impl SolanaBackend { &user, witness, ), - OpportunisticData::CooldownStep(witness, hub_repo) => { - instructions::coordinator_cooldown_witness( - &coordinator_instance, - &coordinator_account, - &user, - witness, - hub_repo, - ) - } + OpportunisticData::CooldownStep(witness) => instructions::coordinator_cooldown_witness( + &coordinator_instance, + &coordinator_account, + &user, + witness, + ), }; self.spawn_scheduled_send("Witness", &[instruction], &[]); } diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs index f23fcfb36..5ea7d5ffa 100644 --- a/architectures/decentralized/solana-client/src/instructions.rs +++ b/architectures/decentralized/solana-client/src/instructions.rs @@ -184,7 +184,6 @@ pub fn coordinator_cooldown_witness( coordinator_account: &Pubkey, user: &Pubkey, witness: psyche_coordinator::Witness, - hub_repo: psyche_coordinator::model::HubRepo, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, @@ -198,7 +197,6 @@ pub fn coordinator_cooldown_witness( participant_bloom: witness.participant_bloom, broadcast_bloom: witness.broadcast_bloom, broadcast_merkle: witness.broadcast_merkle, - hub_repo, }, ) } diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 864d8d0ea..d8c8b5584 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -237,18 +237,10 @@ impl CoordinatorInstanceState { &mut self, payer: &Pubkey, witness: Witness, - hub_repo: HubRepo, ) -> Result<()> { let id = self.clients_state.find_signer(payer)?; - - let clock: Clock = Clock::get()?; self.coordinator - .cooldown_witness( - id, - witness, - clock.unix_timestamp as u64, - hub_repo, - ) + .cooldown_witness(id, witness) .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; self.tick() diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 4dccbeec6..06f9d7856 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -296,7 +296,6 @@ pub mod psyche_solana_coordinator { participant_bloom: WitnessBloom, broadcast_bloom: WitnessBloom, broadcast_merkle: MerkleRoot, - hub_repo: HubRepo, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); @@ -308,7 +307,6 @@ pub mod psyche_solana_coordinator { broadcast_bloom, broadcast_merkle, }, - hub_repo, ) } diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index ec72a554d..afa123c31 100644 --- a/shared/client/src/client.rs +++ 
b/shared/client/src/client.rs @@ -89,7 +89,6 @@ impl + 'sta // From Run let (tx_witness, mut rx_witness) = mpsc::unbounded_channel(); let (tx_health_check, mut rx_health_check) = mpsc::unbounded_channel(); - let (tx_checkpoint, mut rx_checkpoint) = mpsc::unbounded_channel(); let (tx_model, mut rx_model) = mpsc::unbounded_channel(); let (tx_distro_result, mut rx_distro_result) = mpsc::unbounded_channel(); let (tx_request_download, mut rx_request_download) = mpsc::unbounded_channel(); @@ -112,7 +111,6 @@ impl + 'sta metrics: metrics.clone(), tx_witness, tx_health_check, - tx_checkpoint, tx_model, tx_parameters_req, tx_config, @@ -135,7 +133,7 @@ impl + 'sta let mut retry_check_interval = interval(DOWNLOAD_RETRY_CHECK_INTERVAL); let mut opportunistic_witness_interval = interval(OPPROTUNISTIC_WITNESS_INTERVAL); let mut check_connection_interval = interval(CHECK_CONNECTION_INTERVAL); - let mut wait_for_checkpoint = false; + let mut _wait_for_checkpoint = false; let mut last_gossip_connection_time = SystemTime::now(); debug!("Starting client loop"); @@ -144,7 +142,7 @@ impl + 'sta _ = cancel.cancelled() => { info!("Got request to cancel main client loop"); if run.doing_checkpoint() { - wait_for_checkpoint = true; + _wait_for_checkpoint = true; } break; } @@ -531,9 +529,6 @@ impl + 'sta Some(health_check) = rx_health_check.recv() => { watcher.backend_mut().send_health_check(health_check).await?; } - // Some(checkpoint) = rx_checkpoint.recv() => { - // watcher.backend_mut().send_checkpoint(checkpoint).await?; - // } Some(model) = rx_model.recv() => { sharable_model.update_parameters(model)?; }, diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 2a12de34f..6777ca815 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,20 +1,12 @@ use crate::HubUploadInfo; use psyche_coordinator::CheckpointerSelection; -use psyche_coordinator::Client; -use psyche_coordinator::{ - Coordinator, - model::{self, HubRepo}, -}; -use psyche_core::compute_shuffled_index; -use psyche_core::sha256; -use psyche_core::sha256v; -use psyche_core::{FixedString, NodeIdentity}; +use psyche_coordinator::Coordinator; +use psyche_core::NodeIdentity; use psyche_data_provider::{UploadModelError, upload_model_repo_async}; use psyche_modeling::{ SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; -use std::str::FromStr; use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap}, @@ -47,7 +39,6 @@ pub enum CooldownError { } pub struct CooldownStepMetadata { - tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, @@ -64,14 +55,12 @@ pub struct CooldownStepMetadata { impl CooldownStepMetadata { pub fn new( - tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: CheckpointConfig, checkpoint_extra_files: Vec, model_task_runner: ModelTaskRunner, ) -> Self { Self { - tx_checkpoint, tx_model, checkpoint_info, checkpoint_extra_files, @@ -135,7 +124,6 @@ impl CooldownStepMetadata { &self, mut trainers: Vec, state: &Coordinator, - from: T, client_index: u64, ) -> Result { let Some(mut trainer) = trainers.pop() else { @@ -143,26 +131,16 @@ impl CooldownStepMetadata { }; let step = state.progress.step - 1; - let current_round = state.current_round().ok_or(CooldownError::NoTrainers)?; let run_id = String::from(&state.run_id); let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let 
checkpoint_info = self.checkpoint_info.clone(); - let tx_checkpoint = self.tx_checkpoint.clone(); let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); let checkpointer_selection = CheckpointerSelection::from_coordinator(state, 0) - .map_err(|e| CooldownError::NoTrainers)?; + .map_err(|_| CooldownError::NoTrainers)?; let is_checkpointer = checkpointer_selection .get_checkpointer(client_index, state.epoch_state.clients.len() as u64); - println!("CHECKPOINTER INDEX: {}", is_checkpointer); - println!("CLIENT INDEX: {}", client_index); - let checkpointer = state - .epoch_state - .clients - .get(client_index as usize) - .cloned() - .ok_or(CooldownError::NoTrainers)?; let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn( async move { @@ -224,7 +202,7 @@ impl CooldownStepMetadata { } = hub_upload; info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); - let revision = match upload_model_repo_async( + match upload_model_repo_async( hub_repo.clone(), local, hub_token.clone(), @@ -239,7 +217,6 @@ impl CooldownStepMetadata { revision = revision, "Upload to HuggingFace complete" ); - revision } Err(err) => { error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}"); @@ -247,13 +224,6 @@ impl CooldownStepMetadata { } }; - // tx_checkpoint - // .send(HubRepo { - // repo_id: FixedString::from_str_truncated(&hub_repo), - // revision: Some(FixedString::from_str_truncated(&revision)), - // }) - // .map_err(|_| CheckpointError::SendCheckpoint)?; - // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work // we'll just delete the dir after we've uploaded it // if we fail in any of the above steps we may wind up not queueing this dir for delete @@ -282,11 +252,6 @@ impl CooldownStepMetadata { checkpointing_and_evals, }) } - - pub fn get_repo(&self) -> Option { - let repo = HubRepo::from_str(&self.checkpoint_info.hub_upload.hub_repo).ok()?; - Some(repo) - } } type CheckpointAndEvalsHandle = JoinHandle< diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index e7063018f..e4af817bf 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -150,7 +150,6 @@ pub struct RunInitConfigAndIO { pub tx_health_check: UnboundedSender>, pub tx_witness: UnboundedSender, - pub tx_checkpoint: UnboundedSender, pub tx_model: UnboundedSender>, pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>, pub tx_config: UnboundedSender<(String, String)>, @@ -172,7 +171,6 @@ impl RunInitConfigAndIO RunInitConfigAndIO StepStateMachine StepStateMachine StepStateMachine Coordinator { pub fn cooldown_witness( &mut self, from: &T, - witness: Witness, - _unix_timestamp: u64, - hub_repo: HubRepo, + _witness: Witness, ) -> std::result::Result<(), CoordinatorError> { if self.halted() { return Err(CoordinatorError::Halted); @@ -537,16 +532,9 @@ impl Coordinator { .position(|x| x.id == *from) .unwrap(); - let current_round = self.current_round().unwrap(); let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; let is_checkpointer = checkpointer_selection .get_checkpointer(client_index as u64, self.epoch_state.clients.len() as u64); - let checkpointer = self - .epoch_state - .clients - .get(client_index as usize) - .cloned() - .unwrap(); if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); } else { diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 
d993bfc38..944f35657 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -1,8 +1,5 @@ use anyhow::Result; -use psyche_coordinator::{ - Coordinator, HealthChecks, Witness, WitnessMetadata, - model::{self, HubRepo}, -}; +use psyche_coordinator::{Coordinator, HealthChecks, Witness, WitnessMetadata}; use psyche_core::NodeIdentity; use serde::{Deserialize, Serialize}; @@ -11,7 +8,7 @@ use serde::{Deserialize, Serialize}; pub enum OpportunisticData { WitnessStep(Witness, WitnessMetadata), WarmupStep(Witness), - CooldownStep(Witness, HubRepo), + CooldownStep(Witness), } impl OpportunisticData { @@ -32,5 +29,4 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - // async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; } From 8d8c3f416bd1a9d7e27115553d558b5c72fcb59f Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 05:31:51 -0800 Subject: [PATCH 14/72] Add support for multiple checkpointers --- shared/client/src/state/steps.rs | 5 ++++- shared/coordinator/src/coordinator.rs | 30 ++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index b55be746f..bf5dfb4e3 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -59,6 +59,7 @@ pub struct StepStateMachine, sent_warmup_finished: bool, sent_warmup_witness: bool, + sent_cooldown_witness: bool, coordinator_state: Coordinator, @@ -166,6 +167,7 @@ impl StepStateMachine StepStateMachine StepStateMachine { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointer: Client, + pub checkpointers: FixedVec, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -414,7 +414,7 @@ impl Default for CoordinatorEpochState { first_round: true.into(), clients: Default::default(), exited_clients: Default::default(), - checkpointer: Default::default(), + checkpointers: Default::default(), cold_start_epoch: false.into(), start_step: Default::default(), last_step: Default::default(), @@ -514,8 +514,8 @@ impl Coordinator { pub fn cooldown_witness( &mut self, - from: &T, - _witness: Witness, + _from: &T, + witness: Witness, ) -> std::result::Result<(), CoordinatorError> { if self.halted() { return Err(CoordinatorError::Halted); @@ -525,19 +525,19 @@ impl Coordinator { return Err(CoordinatorError::InvalidRunState); } - let client_index = self - .epoch_state - .clients - .iter() - .position(|x| x.id == *from) - .unwrap(); - let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; let is_checkpointer = checkpointer_selection - .get_checkpointer(client_index as u64, self.epoch_state.clients.len() as u64); + .get_checkpointer(witness.proof.index, self.epoch_state.clients.len() as u64); if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); } else { + self.epoch_state + .checkpointers + .push(witness) + .map_err(|_| CoordinatorError::WitnessesFull)?; + } + + if self.epoch_state.checkpointers.len() == self.config.checkpointer_nodes as usize { self.epoch_state.checkpointed = true; } @@ -654,8 +654,10 @@ impl Coordinator { if !matches!(self.run_state, RunState::Cooldown) { return Err(CoordinatorError::InvalidRunState); } - - if self.epoch_state.checkpointer.id != *from { + let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; + let is_checkpointer = checkpointer_selection + .get_checkpointer(index as u64, self.epoch_state.clients.len() as u64); + if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); } else { self.epoch_state.checkpointed = true; From 410c192217051fb40f01930f44424d16a13ec1b4 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 12:43:46 -0300 Subject: [PATCH 15/72] Fix lint --- .../tests/suites/memnet_coordinator_full_round.rs | 1 + .../tests/suites/memnet_coordinator_rewards.rs | 1 + .../tests/suites/memnet_treasurer_create_update.rs | 1 + .../tests/suites/memnet_treasurer_full_epoch.rs | 1 + website/backend/src/coordinatorChainLoop.ts | 11 +++++++++++ 5 files changed, 15 insertions(+) diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index fc03202cf..12f53a6de 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -106,6 +106,7 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 1, + checkpointer_nodes: 0, epoch_time: 30, total_steps: 100, waiting_for_members_extra_time: 3, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index f69bbf8c5..9670078a1 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ 
b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -102,6 +102,7 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 0, + checkpointer_nodes: 0, epoch_time, waiting_for_members_extra_time: WAITING_FOR_MEMBERS_EXTRA_SECONDS as u8, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs index e51ced2dd..863f881f7 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs @@ -49,6 +49,7 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 1, + checkpointer_nodes: 0, epoch_time: 30, total_steps: 100, waiting_for_members_extra_time: WAITING_FOR_MEMBERS_EXTRA_SECONDS diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 014772e32..592b0819b 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -225,6 +225,7 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 0, + checkpointer_nodes: 0, epoch_time, total_steps: 100, waiting_for_members_extra_time: 3, diff --git a/website/backend/src/coordinatorChainLoop.ts b/website/backend/src/coordinatorChainLoop.ts index 21aa04dc2..687e407da 100644 --- a/website/backend/src/coordinatorChainLoop.ts +++ b/website/backend/src/coordinatorChainLoop.ts @@ -342,6 +342,17 @@ export async function startWatchCoordinatorChainLoop( }) break } + case 'cooldown_witness': { + const runPdaAddr = i.accounts[1].toString() + const coordinatorAddr = i.accounts[2].toString() + runUpdates.getAndTouchCurrentRun({ + runPdaAddr, + coordinatorAddr, + decoded, + tx, + }) + break + } case 'update_client_version': { const runPdaAddr = i.accounts[1].toString() const coordinatorAddr = i.accounts[2].toString() From 3ee7cb68f52f8049ac3a37dffc48cd24eaea838e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 12:44:27 -0300 Subject: [PATCH 16/72] Update nano config --- config/solana-test/nano-config.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index c275feea3..d16a02f5b 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -8,6 +8,7 @@ min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 1 +checkpoint_nodes = 0 global_batch_size_start = 4 global_batch_size_end = 4 global_batch_size_warmup_tokens = 0 From 2b07c7dd07004d0891554b5332198879e5095cd9 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 12:59:15 -0300 Subject: [PATCH 17/72] Remove send_checkpoint function from backend --- architectures/centralized/client/src/app.rs | 5 ---- architectures/centralized/server/src/app.rs | 4 ---- .../solana-client/src/backend.rs | 22 ----------------- shared/client/src/client.rs | 24 ------------------- shared/data-provider/examples/tcp.rs | 4 ---- 5 files changed, 59 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 25540bc0e..1c16826bf 
100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -65,11 +65,6 @@ impl WatcherBackend for Backend { self.tx.send(ToSend::HealthCheck(health_checks))?; Ok(()) } - - // async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> { - // self.tx.send(ToSend::Checkpoint(checkpoint))?; - // Ok(()) - // } } pub struct App { diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 15c2de53a..6537e87c8 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -79,10 +79,6 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { async fn send_health_check(&mut self, _health_checks: HealthChecks) -> Result<()> { bail!("Server does not send health checks"); } - - // async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> { - // bail!("Server does not send checkpoints"); - // } } type DataServer = diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index bf35d5496..d8f68072e 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -335,22 +335,6 @@ impl SolanaBackend { self.spawn_scheduled_send("Health check", &[instruction], &[]); } - // pub fn send_checkpoint( - // &self, - // coordinator_instance: Pubkey, - // coordinator_account: Pubkey, - // repo: HubRepo, - // ) { - // let user = self.get_payer(); - // let instruction = instructions::coordinator_checkpoint( - // &coordinator_instance, - // &coordinator_account, - // &user, - // repo, - // ); - // self.spawn_scheduled_send("Checkpoint", &[instruction], &[]); - // } - pub fn find_join_authorization(join_authority: &Pubkey, authorizer: Option) -> Pubkey { psyche_solana_authorizer::find_authorization( join_authority, @@ -608,12 +592,6 @@ impl WatcherBackend for SolanaBackendRunner } Ok(()) } - - // async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> { - // self.backend - // .send_checkpoint(self.instance, self.account, checkpoint); - // Ok(()) - // } } impl SolanaBackendRunner { diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index afa123c31..479c3a2a9 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -673,30 +673,6 @@ impl + 'sta let p2p_shutdown = p2p.shutdown(); - // if wait_for_checkpoint { - // info!("Waiting for all pending checkpoints to finish"); - - // // Keep waiting for checkpoints while there are uploads pending - // let mut checkpoint_check_interval = interval(Duration::from_secs(10)); - // while run.doing_checkpoint() { - // tokio::select! 
{ - // checkpoint = rx_checkpoint.recv() => { - // if let Some(checkpoint) = checkpoint { - // info!("Checkpoint upload completed, sending to Solana"); - // watcher.backend_mut().send_checkpoint(checkpoint).await?; - // } else { - // // Channel closed, no more checkpoints coming - // break; - // } - // } - // _ = checkpoint_check_interval.tick() => { - // } - // } - // } - - // info!("All checkpoints finished, exiting main client loop"); - // } - p2p_shutdown .await .map_err(|e| anyhow!("Error shutting down p2p: {}", e)) diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 13dcf8b44..38cb5e88b 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -36,10 +36,6 @@ impl WatcherBackend for DummyBackend { async fn send_health_check(&mut self, _health_checks: HealthChecks) -> anyhow::Result<()> { bail!("Data provider does not send health check"); } - - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> anyhow::Result<()> { - bail!("Data provider does not send checkpoints"); - } } #[derive( From 7fb8add398720c72abb5541295d0742c17b253f8 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 09:24:37 -0800 Subject: [PATCH 18/72] Reduce total amount of checkpointers --- architectures/centralized/testing/src/test_utils.rs | 2 ++ shared/client/src/state/steps.rs | 1 - shared/coordinator/src/coordinator.rs | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index 1ec700b4e..7bed4561e 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -127,6 +127,7 @@ pub fn dummy_client_app_params_with_training_delay( run_id: &str, training_delay_secs: u64, ) -> AppParams { + std::env::set_var("HF_TOKEN", "dummy_token"); AppParams { cancel: CancellationToken::default(), server_addr: format!("localhost:{server_port}").to_string(), @@ -134,6 +135,7 @@ pub fn dummy_client_app_params_with_training_delay( "dummy", "--run-id", run_id, "--iroh-relay", "disabled", + "--hub-repo", "dummy/repo", "--iroh-discovery", "local", "--data-parallelism", "1", "--tensor-parallelism", "1", diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index bf5dfb4e3..f0e093c8e 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -397,7 +397,6 @@ impl StepStateMachine { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointers: FixedVec, + pub checkpointers: FixedVec, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -1235,6 +1236,7 @@ impl CoordinatorConfig { && self.total_steps != 0 && self.witness_nodes <= self.min_clients && self.checkpointer_nodes <= self.min_clients + && self.checkpointer_nodes as usize <= SOLANA_MAX_NUM_CHECKPOINTERS && self.witness_nodes as usize <= SOLANA_MAX_NUM_WITNESSES && self.cooldown_time > 0 && self.waiting_for_members_extra_time > 0 From e7421ed83d7c98810d26faaca980e718c36c7f5c Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 15:40:16 -0300 Subject: [PATCH 19/72] Remove check for permissions on hf repo --- architectures/centralized/client/src/app.rs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 1c16826bf..cfebd1cfb 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -1,6 +1,5 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; -use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, @@ -165,22 +164,6 @@ impl App { p2p: NC, state_options: RunInitConfig, ) -> Result<()> { - let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some( - state_options.checkpoint_config.hub_upload.hub_token.clone(), - )) - .build()?; - let repo_api = api.repo(Repo::new( - state_options.checkpoint_config.hub_upload.hub_repo.clone(), - hf_hub::RepoType::Model, - )); - if !repo_api.is_writable().await { - anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - state_options.checkpoint_config.hub_upload.hub_repo - ) - } - self.server_conn .send(ClientToServerMessage::Join { run_id: self.run_id.clone(), From a66f0896fffcbd435853ae50e53e3f79881c2a61 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 16:09:37 -0300 Subject: [PATCH 20/72] Fix compilation error after merge --- shared/client/src/state/cooldown.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 6114db07f..579520215 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -164,14 +164,10 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - let (variables, trainer) = if checkpoint_info.is_some() { - // convert from internal shape to serialized shape (e.g. torchtitan to hf) - tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)? - } else { - (variables, trainer) - }; + // convert from internal shape to serialized shape (e.g. 
torchtitan to hf) + let (variables, trainer) = tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)?; trainers.push(trainer); let evals = model_task_runner.start(trainers); From 5d99bf041a9dcca5439bd990859b7eb8b3c24150 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 17:05:31 -0300 Subject: [PATCH 21/72] Use convert function only on python trainer --- shared/client/src/state/cooldown.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 579520215..a5fdec2f7 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -5,8 +5,7 @@ use psyche_coordinator::Coordinator; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadModelError, upload_model_repo_async}; use psyche_modeling::{ - CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, - save_tensors_into_safetensors, + SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; use std::{ cmp::Reverse, @@ -165,9 +164,16 @@ impl CooldownStepMetadata { .map_err(|_| CheckpointError::SendCheckpoint)?; // convert from internal shape to serialized shape (e.g. torchtitan to hf) - let (variables, trainer) = tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)?; + let (variables, trainer) = match trainer { + #[cfg(feature = "python")] + Trainer::PythonDistributed(_) => { + info!("Converting distributed trainer variables for checkpointing..."); + tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)? 
+ } + _ => (variables, trainer), + }; trainers.push(trainer); let evals = model_task_runner.start(trainers); From 2f77f46d7dcc116baef88eb7dcb52be98367dd39 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 17:32:14 -0300 Subject: [PATCH 22/72] Add conditional import for python feature --- shared/client/src/state/cooldown.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index a5fdec2f7..ed4badd53 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -4,6 +4,8 @@ use psyche_coordinator::CheckpointerSelection; use psyche_coordinator::Coordinator; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadModelError, upload_model_repo_async}; +#[cfg(feature = "python")] +use psyche_modeling::CausalLM; use psyche_modeling::{ SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; From 284ae5d5d84a394cb120453090e4c3b680d7930d Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 8 Jan 2026 14:01:07 -0800 Subject: [PATCH 23/72] Fix decentralized integration test config and entrypoints --- config/solana-test/nano-config.toml | 2 +- docker/test/client_test_entrypoint.sh | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index d16a02f5b..0d05faff2 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -8,7 +8,7 @@ min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 1 -checkpoint_nodes = 0 +checkpointer_nodes = 0 global_batch_size_start = 4 global_batch_size_end = 4 global_batch_size_warmup_tokens = 0 diff --git a/docker/test/client_test_entrypoint.sh b/docker/test/client_test_entrypoint.sh index 64e739841..a67210b0b 100644 --- a/docker/test/client_test_entrypoint.sh +++ b/docker/test/client_test_entrypoint.sh @@ -11,20 +11,22 @@ echo "USING SIDECAR PORT: ${SIDECAR_PORT}" # Build the command based on environment variable if [ "${PYTHON_ENABLED}" = "true" ]; then echo "Starting client with Python features enabled" - psyche-solana-client train \ + HF_TOKEN="test" psyche-solana-client train \ --wallet-private-key-path "/root/.config/solana/id.json" \ --rpc "${RPC}" \ --ws-rpc "${WS_RPC}" \ --run-id "${RUN_ID}" \ + --hub-repo "dummy/test-hub-repo" \ --data-parallelism 8 \ --sidecar-port "${SIDECAR_PORT}" \ --logs "json" else echo "Starting client without Python features" - psyche-solana-client train \ + HF_TOKEN="test" psyche-solana-client train \ --wallet-private-key-path "/root/.config/solana/id.json" \ --rpc "${RPC}" \ --ws-rpc "${WS_RPC}" \ + --hub-repo "dummy/test-hub-repo" \ --run-id "${RUN_ID}" \ --logs "json" fi From ba70ece897695ce9b7b2c169fabc137714aa37b8 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Fri, 9 Jan 2026 12:00:40 -0800 Subject: [PATCH 24/72] fix bug and remove debug prints --- shared/client/src/state/steps.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index f0e093c8e..571dc907c 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -174,7 +174,6 @@ impl StepStateMachine Result<(), OpportunisticWitnessError> { - println!("CURRENT STATE = {:?}", self.coordinator_state.run_state); if self.current_round.committee_info.is_some() && !matches!( self.coordinator_state.run_state, @@ -369,6 +368,7 @@ impl StepStateMachine 
StepStateMachine {
+                self.sent_cooldown_witness = false;
                 let (trainers, upload_handle) = cooldown.finish().await?;
                 if let Some(handle) = upload_handle {
                     self.pending_upload_handles.push(handle);

From f82e387c3c51fc9c603c598496e7d07565a51185 Mon Sep 17 00:00:00 2001
From: Mariano Nicolini
Date: Fri, 9 Jan 2026 12:07:15 -0800
Subject: [PATCH 25/72] Remove unnecessary code

---
 shared/coordinator/src/coordinator.rs | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index b99bb8717..a1e9977c5 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -985,12 +985,6 @@ impl Coordinator {
             )
             .unwrap();
 
-            // self.epoch_state.checkpointer = self
-            //     .epoch_state
-            //     .clients
-            //     .get(0)
-            //     .cloned()
-            //     .expect("at least one client");
             self.start_warmup(unix_timestamp);
         }
 
From 10e81f3407e6943a83377c397fec414bf6a6ece4 Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Fri, 9 Jan 2026 18:45:10 -0300
Subject: [PATCH 26/72] Add GCS checkpoint variant

---
 architectures/centralized/client/src/app.rs   |  23 +--
 architectures/centralized/server/src/app.rs   |   2 +-
 .../centralized/shared/src/protocol.rs        |   2 +-
 .../solana-client/src/backend.rs              |   5 +-
 .../solana-client/src/command/checkpoint.rs   |   3 +-
 .../solana-client/src/instructions.rs         |   2 +-
 .../solana-coordinator/src/instance_state.rs  |   8 +-
 .../programs/solana-coordinator/src/lib.rs    |   3 +-
 python/default.nix                            |   2 +-
 shared/client/src/cli.rs                      |  60 +++++++-
 shared/client/src/lib.rs                      |   3 +-
 shared/client/src/state/cooldown.rs           | 138 +++++++++++-------
 shared/client/src/state/init.rs               |   2 +-
 shared/client/src/state/mod.rs                |   5 +-
 shared/client/src/state/types.rs              |  14 +-
 shared/coordinator/src/coordinator.rs         |  15 +-
 shared/data-provider/src/hub.rs               |  54 +++++++
 shared/data-provider/src/lib.rs               |   1 +
 shared/watcher/src/traits.rs                  |   2 +-
 19 files changed, 258 insertions(+), 86 deletions(-)

diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index 3560e0a7a..e3eb19019 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -2,8 +2,10 @@ use anyhow::{Error, Result};
 use bytemuck::Zeroable;
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
+use psyche_client::UploadInfo;
 use psyche_client::{
-    Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
+    Client, ClientTUI, ClientTUIState, HubUploadInfo, NC, RunInitConfig, TrainArgs,
+    read_identity_secret_key,
 };
 use psyche_coordinator::{Coordinator, HealthChecks, model};
 use psyche_metrics::ClientMetrics;
@@ -29,7 +31,7 @@ pub type TabsData = ::Data;
 pub enum ToSend {
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }
 
 struct Backend {
@@ -67,7 +69,7 @@ impl WatcherBackend for Backend {
         Ok(())
     }
 
-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.tx.send(ToSend::Checkpoint(checkpoint))?;
         Ok(())
     }
@@ -173,18 +175,19 @@ impl App {
     ) -> Result<()> {
         // sanity checks
         if let Some(checkpoint_config) = &state_options.checkpoint_config {
-            if let Some(hub_upload) = &checkpoint_config.hub_upload {
+            if let Some(UploadInfo::Hub(HubUploadInfo {
+                hub_repo,
+                hub_token,
+            })) = &checkpoint_config.hub_upload
+            {
                 let api = hf_hub::api::tokio::ApiBuilder::new()
- .with_token(Some(hub_upload.hub_token.clone())) + .with_token(Some(hub_token.clone())) .build()?; - let repo_api = api.repo(Repo::new( - hub_upload.hub_repo.clone(), - hf_hub::RepoType::Model, - )); + let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model)); if !repo_api.is_writable().await { anyhow::bail!( "Checkpoint upload repo {} is not writable with the passed API key.", - hub_upload.hub_repo + hub_repo ) } } diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 6707b49bb..402146817 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -81,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { bail!("Server does not send health checks"); } - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> { bail!("Server does not send checkpoints"); } } diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs index c71c7524f..a96c64b8f 100644 --- a/architectures/centralized/shared/src/protocol.rs +++ b/architectures/centralized/shared/src/protocol.rs @@ -15,7 +15,7 @@ pub enum ClientToServerMessage { Join { run_id: String }, Witness(Box), HealthCheck(HealthChecks), - Checkpoint(model::HubRepo), + Checkpoint(model::Checkpoint), } #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index 90f94e528..f948cde94 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -20,6 +20,7 @@ use anchor_client::{ use anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; use psyche_client::IntegrationTestLogMarker; +use psyche_coordinator::model::{self, Checkpoint}; use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo}; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding}; @@ -333,7 +334,7 @@ impl SolanaBackend { &self, coordinator_instance: Pubkey, coordinator_account: Pubkey, - repo: HubRepo, + repo: Checkpoint, ) { let user = self.get_payer(); let instruction = instructions::coordinator_checkpoint( @@ -603,7 +604,7 @@ impl WatcherBackend for SolanaBackendRunner Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> { self.backend .send_checkpoint(self.instance, self.account, checkpoint); Ok(()) diff --git a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs index 73adc2105..3f8cc9f2b 100644 --- a/architectures/decentralized/solana-client/src/command/checkpoint.rs +++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs @@ -1,5 +1,6 @@ use anyhow::Result; use clap::Args; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::HubRepo; use psyche_core::FixedString; @@ -45,7 +46,7 @@ pub async fn command_checkpoint_execute( &coordinator_instance, &coordinator_account, &user, - repo, + Checkpoint::Hub(repo.clone()), ); let signature = backend .send_and_retry("Checkpoint", &[instruction], &[]) diff --git 
a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs index 68e169008..b535d54f4 100644 --- a/architectures/decentralized/solana-client/src/instructions.rs +++ b/architectures/decentralized/solana-client/src/instructions.rs @@ -206,7 +206,7 @@ pub fn coordinator_checkpoint( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - repo: psyche_coordinator::model::HubRepo, + repo: psyche_coordinator::model::Checkpoint, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 029b40fad..7ea079bd0 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -10,8 +10,8 @@ use psyche_coordinator::RunState; use psyche_coordinator::SOLANA_MAX_STRING_LEN; use psyche_coordinator::TickResult; use psyche_coordinator::Witness; -use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::Model; +use psyche_coordinator::model::{Checkpoint, HubRepo}; use psyche_core::FixedString; use psyche_core::SmallBoolean; use psyche_core::sha256v; @@ -389,7 +389,11 @@ impl CoordinatorInstanceState { self.tick() } - pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> { + pub fn checkpoint( + &mut self, + payer: &Pubkey, + repo: Checkpoint, + ) -> Result<()> { // O(n) on clients, reconsider let id = self.clients_state.find_signer(payer)?; let index = self diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 657424195..f7feb7981 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -165,6 +165,7 @@ impl CoordinatorInstance { pub mod psyche_solana_coordinator { use super::*; + use psyche_coordinator::model::Checkpoint; use psyche_core::FixedString; pub fn init_coordinator( @@ -313,7 +314,7 @@ pub mod psyche_solana_coordinator { pub fn checkpoint( ctx: Context, - repo: HubRepo, + repo: psyche_coordinator::model::Checkpoint, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); diff --git a/python/default.nix b/python/default.nix index 0ae9670bd..196b07867 100644 --- a/python/default.nix +++ b/python/default.nix @@ -32,7 +32,6 @@ let # packages that we provide to the venv via nix derivations topLevelNixPkgs = [ "torch" - "vllm" ] ++ lib.optionals stdenvNoCC.hostPlatform.isLinux [ "flash-attn" @@ -40,6 +39,7 @@ let # i'm really not a fan of providing torchtitan like this. i'd much rather have it be built as a git dep via uv2nix. # i think there's room to figure out how to provide setuptools for it. 
"torchtitan" + "vllm" ]; nixProvidedPythonPkgs = getAllTransitiveDeps topLevelNixPkgs; diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 268ea753a..2f6f2b080 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -1,5 +1,7 @@ use crate::{CheckpointConfig, HubUploadInfo, WandBInfo}; +use crate::GcsUploadInfo; +use crate::UploadInfo; use anyhow::{Result, anyhow, bail}; use clap::Args; use psyche_eval::tasktype_from_name; @@ -146,6 +148,14 @@ pub struct TrainArgs { #[clap(long, env)] pub hub_repo: Option, + /// Path to the Hugging Face repository containing model data and configuration. + #[clap(long, env)] + pub gcs_bucket: Option, + + /// Path to the Hugging Face repository containing model data and configuration. + #[clap(long, env)] + pub gcs_prefix: Option, + #[clap(long, env, default_value_t = 3)] pub hub_max_concurrent_downloads: usize, @@ -227,37 +237,73 @@ impl TrainArgs { let checkpoint_upload_info = match ( &hub_read_token, self.hub_repo.clone(), + self.gcs_bucket.clone(), + self.gcs_prefix.clone(), self.checkpoint_dir.clone(), self.delete_old_steps, self.keep_steps, ) { - (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => { + (_, Some(_), Some(_), _, _, _, _) => { + bail!("Use either GCS or HF hub for checkpoint uploads, not both.") + } + (Some(token), Some(repo), None, _, Some(dir), delete_old_steps, keep_steps) => { if keep_steps == 0 { bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})") } Some(CheckpointConfig { checkpoint_dir: dir, - hub_upload: Some(HubUploadInfo { + hub_upload: Some(UploadInfo::Hub(HubUploadInfo { hub_repo: repo, hub_token: token.to_string(), - }), + })), + delete_old_steps, + keep_steps, + }) + } + (_, _, Some(gcp_bucket), Some(gcs_prefix), Some(dir), delete_old_steps, keep_steps) => { + if keep_steps == 0 { + bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})") + } + Some(CheckpointConfig { + checkpoint_dir: dir, + hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: gcp_bucket, + gcs_prefix: Some(gcs_prefix), + })), + delete_old_steps, + keep_steps, + }) + } + (_, _, Some(gcp_bucket), None, Some(dir), delete_old_steps, keep_steps) => { + if keep_steps == 0 { + bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})") + } + Some(CheckpointConfig { + checkpoint_dir: dir, + hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: gcp_bucket, + gcs_prefix: None, + })), delete_old_steps, keep_steps, }) } - (None, Some(_), Some(_), _, _) => { + (None, Some(_), None, _, _, _, _) => { bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") } - (_, Some(_), None, _, _) => { + (_, None, Some(_), _, None, _, _) => { + bail!("gcs-bucket and checkpoint-dir set, but no GCS_TOKEN env variable.") + } + (_, Some(_), None, _, None, _, _) => { bail!("--hub-repo was set, but no --checkpoint-dir was passed!") } - (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig { + (_, None, None, _, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig { checkpoint_dir: dir, hub_upload: None, delete_old_steps, keep_steps, }), - (_, None, _, _, _) => None, + (_, None, None, _, _, _, _) => None, }; Ok(checkpoint_upload_info) diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs index f1c545ed7..8a0086932 100644 --- a/shared/client/src/lib.rs +++ b/shared/client/src/lib.rs @@ -10,7 +10,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity pub use client::Client; pub use 
protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult}; pub use state::{ - CheckpointConfig, HubUploadInfo, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO, + CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig, + RunInitConfigAndIO, UploadInfo, }; pub use testing::IntegrationTestLogMarker; pub use tui::{ClientTUI, ClientTUIState}; diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 6e2efb6ef..5b0a567e2 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,11 +1,12 @@ -use crate::HubUploadInfo; +use crate::{HubUploadInfo, state::types::GcsUploadInfo}; +use crate::UploadInfo; use psyche_coordinator::{ Coordinator, - model::{self, HubRepo}, + model::{self, GcsRepo, HubRepo}, }; use psyche_core::{FixedString, NodeIdentity}; -use psyche_data_provider::{UploadModelError, upload_model_repo_async}; +use psyche_data_provider::{UploadModelError, upload_model_repo_async, upload_to_gcs_async}; use psyche_modeling::{ CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, @@ -42,7 +43,7 @@ pub enum CooldownError { } pub struct CooldownStepMetadata { - tx_checkpoint: mpsc::UnboundedSender, + tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: Option, checkpoint_extra_files: Vec, @@ -59,7 +60,7 @@ pub struct CooldownStepMetadata { impl CooldownStepMetadata { pub fn new( - tx_checkpoint: mpsc::UnboundedSender, + tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: Option, checkpoint_extra_files: Vec, @@ -207,53 +208,90 @@ impl CooldownStepMetadata { local.push(to); } - let Some(HubUploadInfo { - hub_repo, - hub_token, - }) = hub_upload - else { - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; - return Ok::<(), CheckpointError>(()); - }; - - info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); - let revision = match upload_model_repo_async( - hub_repo.clone(), - local, - hub_token.clone(), - Some(format!("step {step}")), - None, - ) - .await - { - Ok(revision) => { - info!( - repo = hub_repo, - revision = revision, - "Upload to HuggingFace complete" - ); - revision + match hub_upload { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket, + gcs_prefix, + })) => { + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); + match upload_to_gcs_async(gcs_bucket.clone(), local, gcs_prefix.clone()) + .await + { + Ok(path) => { + info!( + "Upload to GCS complete at gs://{}/{}", + gcs_bucket, + gcs_prefix.clone().unwrap_or_default() + ); + path + } + Err(err) => { + error!(bucket = gcs_bucket, "Error uploading to GCS: {err:#}"); + return Err(err.into()); + } + }; + tx_checkpoint + .send(model::Checkpoint::Gcs(GcsRepo { + bucket: FixedString::from_str_truncated(&format!( + "gs://{}/{:?}", + gcs_bucket, gcs_prefix + )), + prefix: Some(FixedString::from_str_truncated(&format!( + "step-{step}" + ))), + })) + .map_err(|_| CheckpointError::SendCheckpoint)?; } - Err(err) => { - error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}"); - return Err(err.into()); + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo, + hub_token, + })) => { + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); + let revision = match upload_model_repo_async( + hub_repo.clone(), + local, + hub_token.clone(), + Some(format!("step {step}")), + None, + ) + .await + { + Ok(revision) => { + 
info!( + repo = hub_repo, + revision = revision, + "Upload to HuggingFace complete" + ); + revision + } + Err(err) => { + error!( + repo = hub_repo, + "Error uploading to HuggingFace: {err:#}" + ); + return Err(err.into()); + } + }; + tx_checkpoint + .send(model::Checkpoint::Hub(HubRepo { + repo_id: FixedString::from_str_truncated(&hub_repo), + revision: Some(FixedString::from_str_truncated(&revision)), + })) + .map_err(|_| CheckpointError::SendCheckpoint)?; } - }; - - tx_checkpoint - .send(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - }) - .map_err(|_| CheckpointError::SendCheckpoint)?; + None => { + cleanup_dirs( + delete_queue, + keep_steps, + run_id, + delete_old_steps, + step, + checkpoint_dir, + ) + .await; + return Ok::<(), CheckpointError>(()); + } + } // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work // we'll just delete the dir after we've uploaded it diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 54073d3b2..3d396d904 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -154,7 +154,7 @@ pub struct RunInitConfigAndIO { pub tx_health_check: UnboundedSender>, pub tx_witness: UnboundedSender, - pub tx_checkpoint: UnboundedSender, + pub tx_checkpoint: UnboundedSender, pub tx_model: UnboundedSender>, pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>, pub tx_config: UnboundedSender<(String, String)>, diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs index e4d73d5b9..a148eac59 100644 --- a/shared/client/src/state/mod.rs +++ b/shared/client/src/state/mod.rs @@ -16,4 +16,7 @@ mod witness; pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO}; pub use round_state::RoundState; pub use steps::{ApplyMessageOutcome, RunManager}; -pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, HubUploadInfo}; +pub use types::{ + CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, GcsUploadInfo, HubUploadInfo, + UploadInfo, +}; diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 2edf22760..68e5e09f8 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -14,9 +14,21 @@ pub struct HubUploadInfo { pub hub_token: String, } +#[derive(Debug, Clone)] +pub struct GcsUploadInfo { + pub gcs_bucket: String, + pub gcs_prefix: Option, +} + +#[derive(Debug, Clone)] +pub enum UploadInfo { + Hub(HubUploadInfo), + Gcs(GcsUploadInfo), +} + #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub hub_upload: Option, + pub hub_upload: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 0a3a9975d..8d18f199e 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, HubRepo, Model}, + model::{Checkpoint, GcsRepo, HubRepo, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -596,7 +596,7 @@ impl Coordinator { &mut self, from: &T, index: u64, - hub_repo: HubRepo, + hub_repo: Checkpoint, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || 
self.epoch_state.clients[index].id != *from { @@ -607,8 +607,15 @@ impl Coordinator { // more download options. match &mut self.model { Model::LLM(llm) => match llm.checkpoint { - Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo), - Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo), + Checkpoint::P2P(HubRepo { repo_id, revision }) => { + llm.checkpoint = Checkpoint::P2P(HubRepo { repo_id, revision }) + } + Checkpoint::Hub(HubRepo { repo_id, revision }) => { + llm.checkpoint = Checkpoint::Hub(HubRepo { repo_id, revision }) + } + Checkpoint::Gcs(GcsRepo { bucket, prefix }) => { + llm.checkpoint = Checkpoint::Gcs(GcsRepo { bucket, prefix }) + } _ => {} }, } diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index ca090a759..df6938f74 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,3 +1,5 @@ +use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType}; use hf_hub::{ Cache, Repo, RepoType, api::{ @@ -262,3 +264,55 @@ pub async fn upload_model_repo_async( }; Ok(commit_info.oid) } + +pub async fn upload_to_gcs_async( + bucket: String, + files: Vec, + prefix: Option, +) -> Result<(), UploadModelError> { + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client"); + ClientConfig::default().with_auth().await? + } else { + info!("Using anonymous GCS client"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + for path in files { + let file_name = path + .file_name() + .ok_or_else(|| UploadModelError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadModelError::InvalidFilename(path.clone()))?; + + let object_name = match &prefix { + Some(p) => format!("{}/{}", p, file_name), + None => file_name.to_string(), + }; + + let data = tokio::fs::read(&path).await.unwrap(); + + let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let uploaded = client + .upload_object( + &UploadObjectRequest { + bucket: bucket.clone(), + ..Default::default() + }, + data, + &upload_type, + ) + .await + .unwrap(); + + info!( + bucket = bucket, + object = object_name, + size = uploaded.size, + "Successfully uploaded file to GCS" + ); + } + + Ok(()) +} diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 7b97fe101..0c884bea3 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -19,6 +19,7 @@ pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_s pub use hub::{ UploadModelError, download_dataset_repo_async, download_dataset_repo_sync, download_model_repo_async, download_model_repo_sync, upload_model_repo_async, + upload_to_gcs_async, }; pub use local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 1cd96a7ed..50bf317bb 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -27,5 +27,5 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>; } 
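Note on the patch above: `Backend::send_checkpoint` now carries the full
`model::Checkpoint` enum rather than `model::HubRepo`, and `upload_to_gcs_async`
gives checkpoints a GCS target alongside the Hub path. A minimal caller sketch
for the new helper, assuming the elided generics are `Vec<PathBuf>` and
`Option<String>`; the bucket, prefix, and file paths are hypothetical and error
handling is left to the caller:

    use psyche_data_provider::upload_to_gcs_async;
    use std::path::PathBuf;

    async fn push_step_checkpoint() -> anyhow::Result<()> {
        // Files produced for one training step, e.g. by save_tensors_into_safetensors.
        let files = vec![
            PathBuf::from("checkpoints/my-run-step100/model.safetensors"),
            PathBuf::from("checkpoints/my-run-step100/config.json"),
        ];
        // Objects land under gs://my-bucket/my-run/. With GOOGLE_APPLICATION_CREDENTIALS
        // set the client authenticates; otherwise it falls back to anonymous uploads.
        upload_to_gcs_async("my-bucket".to_string(), files, Some("my-run".to_string())).await?;
        Ok(())
    }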
From 9fcf8083693845ba7e2c77521d790e5d54235238 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Mon, 12 Jan 2026 06:36:24 -0800 Subject: [PATCH 27/72] polish nits on coordinator code --- shared/coordinator/src/coordinator.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index a1e9977c5..e8d94e558 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -531,13 +531,13 @@ impl Coordinator { .get_checkpointer(witness.proof.index, self.epoch_state.clients.len() as u64); if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); - } else { - self.epoch_state - .checkpointers - .push(witness) - .map_err(|_| CoordinatorError::WitnessesFull)?; } + self.epoch_state + .checkpointers + .push(witness) + .map_err(|_| CoordinatorError::WitnessesFull)?; + if self.epoch_state.checkpointers.len() == self.config.checkpointer_nodes as usize { self.epoch_state.checkpointed = true; } @@ -660,10 +660,10 @@ impl Coordinator { .get_checkpointer(index as u64, self.epoch_state.clients.len() as u64); if !is_checkpointer { return Err(CoordinatorError::InvalidWitness); - } else { - self.epoch_state.checkpointed = true; } + self.epoch_state.checkpointed = true; + Ok(()) } From e893c7c01481cf99cfff0ca74e16986f3ee69117 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 09:30:53 -0800 Subject: [PATCH 28/72] Fix convert call --- shared/client/src/cli.rs | 4 +--- shared/client/src/state/cooldown.rs | 17 ++++++++++------- shared/data-provider/src/hub.rs | 2 +- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 2f6f2b080..73e4f8a4c 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -291,9 +291,6 @@ impl TrainArgs { (None, Some(_), None, _, _, _, _) => { bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") } - (_, None, Some(_), _, None, _, _) => { - bail!("gcs-bucket and checkpoint-dir set, but no GCS_TOKEN env variable.") - } (_, Some(_), None, _, None, _, _) => { bail!("--hub-repo was set, but no --checkpoint-dir was passed!") } @@ -304,6 +301,7 @@ impl TrainArgs { keep_steps, }), (_, None, None, _, _, _, _) => None, + _ => todo!(), }; Ok(checkpoint_upload_info) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 5b0a567e2..94c5ac685 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -166,13 +166,16 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - let (variables, trainer) = if checkpoint_info.is_some() { - // convert from internal shape to serialized shape (e.g. torchtitan to hf) - tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)? - } else { - (variables, trainer) + // convert from internal shape to serialized shape (e.g. torchtitan to hf) + let (variables, trainer) = match trainer { + #[cfg(feature = "python")] + Trainer::PythonDistributed(_) => { + info!("Converting distributed trainer variables for checkpointing..."); + tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)? 
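+                // Only the distributed Python trainer holds its parameters in the
+                // internal (torchtitan) layout; the other trainer variants already
+                // store HF-shaped tensors, so they skip this blocking conversion.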
+ } + _ => (variables, trainer), }; trainers.push(trainer); diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index df6938f74..d112e661e 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -272,7 +272,7 @@ pub async fn upload_to_gcs_async( ) -> Result<(), UploadModelError> { let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await? + ClientConfig::default().with_auth().await.unwrap() } else { info!("Using anonymous GCS client"); ClientConfig::default().anonymous() From df99802ac1e7de6602ab3737c182fa0b15967fd4 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 17:07:55 -0300 Subject: [PATCH 29/72] Refactor on model download and upload --- architectures/centralized/client/src/app.rs | 6 +- .../solana-client/src/backend.rs | 2 +- .../solana-client/src/command/checkpoint.rs | 2 +- .../solana-coordinator/src/instance_state.rs | 2 +- .../programs/solana-coordinator/src/lib.rs | 3 +- shared/client/src/cli.rs | 137 +++++++------- shared/client/src/state/cooldown.rs | 170 ++++++------------ shared/client/src/state/init.rs | 8 +- shared/client/src/state/mod.rs | 6 +- shared/client/src/state/types.rs | 15 +- shared/coordinator/src/coordinator.rs | 48 +++-- shared/data-provider/src/errors.rs | 47 +++++ shared/data-provider/src/gcs.rs | 104 +++++++++-- shared/data-provider/src/hub.rs | 145 +++++---------- shared/data-provider/src/lib.rs | 11 +- 15 files changed, 355 insertions(+), 351 deletions(-) create mode 100644 shared/data-provider/src/errors.rs diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index e3eb19019..8f6a9bcca 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -2,10 +2,10 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; +use psyche_client::HubUploadInfo; use psyche_client::UploadInfo; use psyche_client::{ - Client, ClientTUI, ClientTUIState, HubUploadInfo, NC, RunInitConfig, TrainArgs, - read_identity_secret_key, + Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; use psyche_coordinator::{Coordinator, HealthChecks, model}; use psyche_metrics::ClientMetrics; @@ -178,7 +178,7 @@ impl App { if let Some(UploadInfo::Hub(HubUploadInfo { hub_repo, hub_token, - })) = &checkpoint_config.hub_upload + })) = &checkpoint_config.upload_info { let api = hf_hub::api::tokio::ApiBuilder::new() .with_token(Some(hub_token.clone())) diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index f948cde94..629ead188 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -21,7 +21,7 @@ use anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; use psyche_client::IntegrationTestLogMarker; use psyche_coordinator::model::{self, Checkpoint}; -use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo}; +use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding}; use solana_transaction_status_client_types::UiTransactionEncoding; diff --git 
a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs index 3f8cc9f2b..3ee098640 100644 --- a/architectures/decentralized/solana-client/src/command/checkpoint.rs +++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs @@ -46,7 +46,7 @@ pub async fn command_checkpoint_execute( &coordinator_instance, &coordinator_account, &user, - Checkpoint::Hub(repo.clone()), + Checkpoint::Hub(repo), ); let signature = backend .send_and_retry("Checkpoint", &[instruction], &[]) diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 7ea079bd0..8d9ddb977 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -10,8 +10,8 @@ use psyche_coordinator::RunState; use psyche_coordinator::SOLANA_MAX_STRING_LEN; use psyche_coordinator::TickResult; use psyche_coordinator::Witness; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::Model; -use psyche_coordinator::model::{Checkpoint, HubRepo}; use psyche_core::FixedString; use psyche_core::SmallBoolean; use psyche_core::sha256v; diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index f7feb7981..0a041a6e9 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -21,7 +21,7 @@ use psyche_coordinator::Witness; use psyche_coordinator::WitnessBloom; use psyche_coordinator::WitnessMetadata; use psyche_coordinator::WitnessProof; -use psyche_coordinator::model::{HubRepo, Model}; +use psyche_coordinator::model::Model; use psyche_core::MerkleRoot; use serde::Deserialize; use serde::Serialize; @@ -165,7 +165,6 @@ impl CoordinatorInstance { pub mod psyche_solana_coordinator { use super::*; - use psyche_coordinator::model::Checkpoint; use psyche_core::FixedString; pub fn init_coordinator( diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 73e4f8a4c..a4ef145f0 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -1,9 +1,9 @@ -use crate::{CheckpointConfig, HubUploadInfo, WandBInfo}; +use crate::{CheckpointConfig, WandBInfo}; -use crate::GcsUploadInfo; use crate::UploadInfo; use anyhow::{Result, anyhow, bail}; use clap::Args; +use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_eval::tasktype_from_name; use psyche_modeling::Devices; use psyche_network::{DiscoveryMode, RelayKind, SecretKey}; @@ -148,11 +148,11 @@ pub struct TrainArgs { #[clap(long, env)] pub hub_repo: Option, - /// Path to the Hugging Face repository containing model data and configuration. + /// Name of the GCS bucket containing model data and configuration. #[clap(long, env)] pub gcs_bucket: Option, - /// Path to the Hugging Face repository containing model data and configuration. + /// Prefix within the GCS bucket for model data and configuration. 
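+    // Bare `env` makes clap read this from the GCS_PREFIX environment variable
+    // (and gcs_bucket above from GCS_BUCKET), derived from the field names.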
#[clap(long, env)] pub gcs_prefix: Option, @@ -234,77 +234,72 @@ impl TrainArgs { pub fn checkpoint_config(&self) -> Result> { let hub_read_token = std::env::var("HF_TOKEN").ok(); - let checkpoint_upload_info = match ( - &hub_read_token, - self.hub_repo.clone(), - self.gcs_bucket.clone(), - self.gcs_prefix.clone(), - self.checkpoint_dir.clone(), - self.delete_old_steps, - self.keep_steps, - ) { - (_, Some(_), Some(_), _, _, _, _) => { - bail!("Use either GCS or HF hub for checkpoint uploads, not both.") - } - (Some(token), Some(repo), None, _, Some(dir), delete_old_steps, keep_steps) => { - if keep_steps == 0 { - bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})") - } - Some(CheckpointConfig { - checkpoint_dir: dir, - hub_upload: Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: repo, - hub_token: token.to_string(), - })), - delete_old_steps, - keep_steps, - }) - } - (_, _, Some(gcp_bucket), Some(gcs_prefix), Some(dir), delete_old_steps, keep_steps) => { - if keep_steps == 0 { - bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})") - } - Some(CheckpointConfig { - checkpoint_dir: dir, - hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: gcp_bucket, - gcs_prefix: Some(gcs_prefix), - })), - delete_old_steps, - keep_steps, - }) - } - (_, _, Some(gcp_bucket), None, Some(dir), delete_old_steps, keep_steps) => { - if keep_steps == 0 { - bail!("keep_steps must be >= 1 for GCS uploads (got {keep_steps})") + + if self.hub_repo.is_some() && self.gcs_bucket.is_some() { + bail!("Use either GCS or HF hub for checkpoint uploads, not both."); + } + + let checkpoint_dir = match &self.checkpoint_dir { + Some(dir) => dir, + None => { + if self.hub_repo.is_some() || self.gcs_bucket.is_some() { + bail!( + "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!" 
+ ); } - Some(CheckpointConfig { - checkpoint_dir: dir, - hub_upload: Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: gcp_bucket, - gcs_prefix: None, - })), - delete_old_steps, - keep_steps, - }) + return Ok(None); } - (None, Some(_), None, _, _, _, _) => { - bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") - } - (_, Some(_), None, _, None, _, _) => { - bail!("--hub-repo was set, but no --checkpoint-dir was passed!") - } - (_, None, None, _, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig { - checkpoint_dir: dir, - hub_upload: None, - delete_old_steps, - keep_steps, - }), - (_, None, None, _, _, _, _) => None, - _ => todo!(), }; - Ok(checkpoint_upload_info) + let upload_info = self.build_upload_info(&hub_read_token)?; + + if upload_info.is_some() && self.keep_steps == 0 { + bail!( + "keep_steps must be >= 1 for checkpoint uploads (got {})", + self.keep_steps + ); + } + + Ok(Some(CheckpointConfig { + checkpoint_dir: checkpoint_dir.clone(), + upload_info, + delete_old_steps: self.delete_old_steps, + keep_steps: self.keep_steps, + })) + } + + fn build_upload_info(&self, hub_token: &Option) -> Result> { + if let Some(repo) = &self.hub_repo { + return self.build_hub_upload_info(repo, hub_token); + } + + if let Some(bucket) = &self.gcs_bucket { + return self.build_gcs_upload_info(bucket); + } + + Ok(None) + } + + fn build_hub_upload_info( + &self, + repo: &str, + token: &Option, + ) -> Result> { + let token = token.as_ref().ok_or_else(|| { + anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") + })?; + + Ok(Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: repo.to_string(), + hub_token: token.to_string(), + }))) + } + + fn build_gcs_upload_info(&self, bucket: &str) -> Result> { + Ok(Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: bucket.to_string(), + gcs_prefix: self.gcs_prefix.clone(), + }))) } pub fn eval_tasks(&self) -> Result> { diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 94c5ac685..488c86d62 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,15 +1,12 @@ -use crate::{HubUploadInfo, state::types::GcsUploadInfo}; - use crate::UploadInfo; use psyche_coordinator::{ Coordinator, - model::{self, GcsRepo, HubRepo}, + model::{self}, }; -use psyche_core::{FixedString, NodeIdentity}; -use psyche_data_provider::{UploadModelError, upload_model_repo_async, upload_to_gcs_async}; +use psyche_core::NodeIdentity; +use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; use psyche_modeling::{ - CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, - save_tensors_into_safetensors, + SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; use std::{ cmp::Reverse, @@ -94,8 +91,8 @@ pub enum CheckpointError { #[error("Writing extra file to disk failed: {0}")] WriteExtraFile(#[from] tokio::io::Error), - #[error("Couldn't upload model to huggingface: {0}")] - UploadError(#[from] UploadModelError), + #[error("Couldn't upload model to huggingface or GCS: {0}")] + UploadError(#[from] UploadError), #[error("Couldn't send checkpoint - channel closed")] SendCheckpoint, @@ -182,127 +179,25 @@ impl CooldownStepMetadata { let evals = model_task_runner.start(trainers); let Some(CheckpointConfig { - hub_upload, + upload_info, checkpoint_dir, delete_old_steps, keep_steps, }) = checkpoint_info else { - // If there was no HF checkpointing configuration, return immediately return 
Ok((evals, None)); }; - // Start the upload process of the updated model parameters in a separate task let upload_handle = tokio::task::spawn(async move { let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - info!("Saving to {}", path.display()); - let mut local = tokio::task::spawn_blocking({ - let path = path.clone(); - move || save_tensors_into_safetensors(variables, path) - }) - .await - .map_err(|_| CheckpointError::WriteThreadCrashed)??; - - for extra in checkpoint_extra_files { - let to = path.join(extra.file_name().unwrap()); - tokio::fs::copy(extra.clone(), to.clone()) - .await - .map_err(CheckpointError::WriteExtraFile)?; - local.push(to); - } + let local = + save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; - match hub_upload { - Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket, - gcs_prefix, - })) => { - info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); - match upload_to_gcs_async(gcs_bucket.clone(), local, gcs_prefix.clone()) - .await - { - Ok(path) => { - info!( - "Upload to GCS complete at gs://{}/{}", - gcs_bucket, - gcs_prefix.clone().unwrap_or_default() - ); - path - } - Err(err) => { - error!(bucket = gcs_bucket, "Error uploading to GCS: {err:#}"); - return Err(err.into()); - } - }; - tx_checkpoint - .send(model::Checkpoint::Gcs(GcsRepo { - bucket: FixedString::from_str_truncated(&format!( - "gs://{}/{:?}", - gcs_bucket, gcs_prefix - )), - prefix: Some(FixedString::from_str_truncated(&format!( - "step-{step}" - ))), - })) - .map_err(|_| CheckpointError::SendCheckpoint)?; - } - Some(UploadInfo::Hub(HubUploadInfo { - hub_repo, - hub_token, - })) => { - info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); - let revision = match upload_model_repo_async( - hub_repo.clone(), - local, - hub_token.clone(), - Some(format!("step {step}")), - None, - ) - .await - { - Ok(revision) => { - info!( - repo = hub_repo, - revision = revision, - "Upload to HuggingFace complete" - ); - revision - } - Err(err) => { - error!( - repo = hub_repo, - "Error uploading to HuggingFace: {err:#}" - ); - return Err(err.into()); - } - }; - tx_checkpoint - .send(model::Checkpoint::Hub(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - })) - .map_err(|_| CheckpointError::SendCheckpoint)?; - } - None => { - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; - return Ok::<(), CheckpointError>(()); - } + if let Some(upload_info) = upload_info { + upload_checkpoint(upload_info, local.clone(), step as u64, tx_checkpoint) + .await?; } - // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work - // we'll just delete the dir after we've uploaded it - // if we fail in any of the above steps we may wind up not queueing this dir for delete - // but that's probably better than risking having the dir deleted from under us - // for a relatively low priority disk cleanup task - // and this may actually be preferred anyway because if we failed to upload, we may want to keep - // the data around locally on disk cleanup_dirs( delete_queue, keep_steps, @@ -320,12 +215,53 @@ impl CooldownStepMetadata { } .instrument(info_span!("checkpointing")), ); + Ok(CooldownStep { checkpointing_and_evals, }) } } +async fn save_checkpoint_locally( + path: PathBuf, + variables: HashMap, + checkpoint_extra_files: Vec, +) -> Result, CheckpointError> { + info!("Saving to {}", path.display()); + let mut local = 
tokio::task::spawn_blocking({ + let path = path.clone(); + move || save_tensors_into_safetensors(variables, path) + }) + .await + .map_err(|_| CheckpointError::WriteThreadCrashed)??; + + for extra in checkpoint_extra_files { + let to = path.join(extra.file_name().unwrap()); + tokio::fs::copy(extra.clone(), to.clone()) + .await + .map_err(CheckpointError::WriteExtraFile)?; + local.push(to); + } + + Ok(local) +} + +async fn upload_checkpoint( + upload_info: UploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), CheckpointError> { + match upload_info { + UploadInfo::Gcs(gcs_info) => upload_to_gcs(gcs_info, local, step, tx_checkpoint) + .await + .map_err(CheckpointError::UploadError), + UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint) + .await + .map_err(CheckpointError::UploadError), + } +} + type CheckpointAndEvalsHandle = JoinHandle< Result< ( diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 3d396d904..36e19db90 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -5,9 +5,9 @@ use psyche_coordinator::{ }; use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize}; use psyche_data_provider::{ - DataProvider, DataProviderTcpClient, DummyDataProvider, GcsError, PreprocessedDataProvider, - Split, WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async, - download_model_repo_async, + DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider, + PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async, + download_model_from_gcs_async, download_model_repo_async, http::{FileURLs, HttpDataProvider}, }; use psyche_metrics::ClientMetrics; @@ -91,7 +91,7 @@ pub enum InitRunError { HfModelLoad(#[from] hf_hub::api::tokio::ApiError), #[error("failed to download model from GCS: {0}")] - GcsModelLoad(#[from] GcsError), + GcsModelLoad(#[from] DownloadError), #[error("model loading thread crashed")] ModelLoadingThreadCrashed(JoinError), diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs index a148eac59..78e6cd1eb 100644 --- a/shared/client/src/state/mod.rs +++ b/shared/client/src/state/mod.rs @@ -14,9 +14,7 @@ mod warmup; mod witness; pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO}; +pub use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; pub use round_state::RoundState; pub use steps::{ApplyMessageOutcome, RunManager}; -pub use types::{ - CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, GcsUploadInfo, HubUploadInfo, - UploadInfo, -}; +pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, UploadInfo}; diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 68e5e09f8..29734f1a0 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -2,24 +2,13 @@ use std::path::PathBuf; use psyche_coordinator::CommitteeProof; use psyche_core::{BatchId, MerkleRoot, NodeIdentity}; +use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_modeling::DistroResult; use psyche_network::{BlobTicket, TransmittableDistroResult}; use tch::TchError; use thiserror::Error; use tokio::task::JoinHandle; -#[derive(Debug, Clone)] -pub struct HubUploadInfo { - pub hub_repo: String, - pub hub_token: String, -} - -#[derive(Debug, Clone)] -pub struct GcsUploadInfo { - pub gcs_bucket: String, - pub gcs_prefix: Option, -} - #[derive(Debug, Clone)] pub enum UploadInfo 
{ Hub(HubUploadInfo), @@ -28,7 +17,7 @@ pub enum UploadInfo { #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub hub_upload: Option, + pub upload_info: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 8d18f199e..5b10bfcbd 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, GcsRepo, HubRepo, Model}, + model::{Checkpoint, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -596,29 +596,43 @@ impl Coordinator { &mut self, from: &T, index: u64, - hub_repo: Checkpoint, + checkpoint_repo: Checkpoint, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { return Err(CoordinatorError::InvalidCommitteeProof); } - // TODO: In the case of more than one checkpointer, this will overwrite the hub repo - // with the last checkpointed one. We could instead have a vector of hub repos to have + + // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint + // with the last checkpointed one. We could instead have a vector of checkpoints to have // more download options. - match &mut self.model { - Model::LLM(llm) => match llm.checkpoint { - Checkpoint::P2P(HubRepo { repo_id, revision }) => { - llm.checkpoint = Checkpoint::P2P(HubRepo { repo_id, revision }) - } - Checkpoint::Hub(HubRepo { repo_id, revision }) => { - llm.checkpoint = Checkpoint::Hub(HubRepo { repo_id, revision }) - } - Checkpoint::Gcs(GcsRepo { bucket, prefix }) => { - llm.checkpoint = Checkpoint::Gcs(GcsRepo { bucket, prefix }) - } - _ => {} - }, + let Model::LLM(llm) = &mut self.model; + match (&llm.checkpoint, checkpoint_repo) { + // If current is P2P, wrap the new checkpoint in P2P + (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::P2P(hub_repo); + } + (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); + } + // If current is Hub, only accept Hub updates + (Checkpoint::Hub(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::Hub(hub_repo); + } + // If current is Gcs, only accept Gcs updates + (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::Gcs(gcs_repo); + } + (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::P2P(hub_repo); + } + (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); + } + // Ignore other combinations + _ => {} } + Ok(()) } diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs new file mode 100644 index 000000000..20e601b25 --- /dev/null +++ b/shared/data-provider/src/errors.rs @@ -0,0 +1,47 @@ +use std::path::PathBuf; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum UploadError { + #[error("path {0} is not a file")] + NotAFile(PathBuf), + + #[error("file {0} doesn't have a valid utf-8 representation")] + InvalidFilename(PathBuf), + + #[error("failed to send checkpoint notification")] + SendCheckpoint, + + // Hub-specific errors + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("failed to commit files: {0}")] + 
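+    // thiserror's `#[from]` on these variants derives From impls, so the upload
+    // paths can bubble hf-hub and GCS errors up with `?`.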
Commit(#[from] hf_hub::api::tokio::CommitError), + + // GCS-specific errors + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + // Common errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +#[derive(Error, Debug)] +pub enum DownloadError { + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index cb6848c1d..adf67a5fb 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,22 +1,22 @@ +use crate::errors::{DownloadError, UploadError}; use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::Media; +use google_cloud_storage::http::objects::upload::UploadObjectRequest; +use google_cloud_storage::http::objects::upload::UploadType; use google_cloud_storage::http::objects::{ download::Range, get::GetObjectRequest, list::ListObjectsRequest, }; +use psyche_coordinator::model::{self, GcsRepo}; +use psyche_core::FixedString; use std::path::PathBuf; -use thiserror::Error; use tokio::runtime::Runtime; +use tokio::sync::mpsc; use tracing::info; -#[derive(Debug, Error)] -pub enum GcsError { - #[error("GCS authentication failed: {0}")] - Auth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), - - #[error("GCS operation failed: {0}")] - Storage(#[from] google_cloud_storage::http::Error), - - #[error("IO error: {0}")] - Io(#[from] std::io::Error), +#[derive(Debug, Clone)] +pub struct GcsUploadInfo { + pub gcs_bucket: String, + pub gcs_prefix: Option, } const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -42,7 +42,7 @@ fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf { pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { +) -> Result, DownloadError> { // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { info!("Using authenticated GCS client"); @@ -132,7 +132,83 @@ pub async fn download_model_from_gcs_async( pub fn download_model_from_gcs_sync( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { - let rt = Runtime::new().map_err(GcsError::Io)?; +) -> Result, DownloadError> { + let rt = Runtime::new().map_err(DownloadError::Io)?; rt.block_on(download_model_from_gcs_async(bucket, prefix)) } + +pub async fn upload_to_gcs( + gcs_info: GcsUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let GcsUploadInfo { + gcs_bucket, + gcs_prefix, + } = gcs_info; + + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); + + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client"); + ClientConfig::default().with_auth().await? 
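+        // with_auth() resolves Google application-default credentials, which is
+        // why the GOOGLE_APPLICATION_CREDENTIALS check gates the authenticated path.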
+ } else { + info!("Using anonymous GCS client"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + for path in local { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + + let object_name = match &gcs_prefix { + Some(p) => format!("{}/{}", p, file_name), + None => file_name.to_string(), + }; + + let data = tokio::fs::read(&path).await?; + + let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let uploaded = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + data, + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = object_name, + size = uploaded.size, + "Successfully uploaded file to GCS" + ); + } + + info!( + "Upload to GCS complete at gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + ); + + tx_checkpoint + .send(model::Checkpoint::Gcs(GcsRepo { + bucket: FixedString::from_str_truncated(&format!( + "gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + )), + prefix: Some(FixedString::from_str_truncated(&format!("step-{step}"))), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) +} diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index d112e661e..13a575b84 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,14 +1,16 @@ -use google_cloud_storage::client::{Client, ClientConfig}; -use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType}; +use crate::errors::UploadError; +use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, api::{ Siblings, - tokio::{ApiError, CommitError, UploadSource}, + tokio::{ApiError, UploadSource}, }, }; +use psyche_coordinator::model; +use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use thiserror::Error; +use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -191,42 +193,39 @@ pub fn download_dataset_repo_sync( ) } -#[derive(Error, Debug)] -pub enum UploadModelError { - #[error("path {0} is not a file")] - NotAFile(PathBuf), - - #[error("file {0} doesn't have a valid utf-8 representation")] - InvalidFilename(PathBuf), +#[derive(Debug, Clone)] +pub struct HubUploadInfo { + pub hub_repo: String, + pub hub_token: String, +} - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] ApiError), +pub async fn upload_to_hub( + hub_info: HubUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let HubUploadInfo { + hub_repo, + hub_token, + } = hub_info; - #[error("failed to commit files: {0}")] - Commit(#[from] CommitError), -} + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); -pub async fn upload_model_repo_async( - repo_id: String, - files: Vec, - token: String, - commit_message: Option, - commit_description: Option, -) -> Result { let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(token)) + .with_token(Some(hub_token.clone())) .build()?; - let repo = Repo::model(repo_id.clone()); + let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result, _> = files + let files: Result, _> = local .into_iter() .map(|path| { path.file_name() - .ok_or(UploadModelError::NotAFile(path.clone())) + .ok_or(UploadError::NotAFile(path.clone())) .and_then(|name| { 
name.to_str() - .ok_or(UploadModelError::InvalidFilename(path.clone())) + .ok_or(UploadError::InvalidFilename(path.clone())) .map(|s| s.to_string()) }) .map(|name| (path.into(), name)) @@ -235,84 +234,32 @@ pub async fn upload_model_repo_async( let files = files?; - let commit_info = match api_repo - .upload_files( - files, - commit_message.clone(), - commit_description.clone(), - false, - ) + let commit_info = api_repo + .upload_files(files, Some(format!("step {step}")), None, false) .await - { - Ok(info) => { - info!( - repo = repo_id, - oid = info.oid, - "Successfully uploaded files to HuggingFace" - ); - info - } - Err(e) => { + .map_err(|e| { error!( - repo = repo_id, + repo = hub_repo, error = ?e, - "Failed to upload files to HuggingFace. Full error details: {:#?}", - e + "Failed to upload files to HuggingFace" ); - return Err(e.into()); - } - }; - Ok(commit_info.oid) -} - -pub async fn upload_to_gcs_async( - bucket: String, - files: Vec, - prefix: Option, -) -> Result<(), UploadModelError> { - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await.unwrap() - } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); + e + })?; - for path in files { - let file_name = path - .file_name() - .ok_or_else(|| UploadModelError::NotAFile(path.clone()))? - .to_str() - .ok_or_else(|| UploadModelError::InvalidFilename(path.clone()))?; + let revision = commit_info.oid; - let object_name = match &prefix { - Some(p) => format!("{}/{}", p, file_name), - None => file_name.to_string(), - }; + info!( + repo = hub_repo, + revision = revision, + "Upload to HuggingFace complete" + ); - let data = tokio::fs::read(&path).await.unwrap(); - - let upload_type = UploadType::Simple(Media::new(object_name.clone())); - let uploaded = client - .upload_object( - &UploadObjectRequest { - bucket: bucket.clone(), - ..Default::default() - }, - data, - &upload_type, - ) - .await - .unwrap(); - - info!( - bucket = bucket, - object = object_name, - size = uploaded.size, - "Successfully uploaded file to GCS" - ); - } + tx_checkpoint + .send(model::Checkpoint::Hub(HubRepo { + repo_id: FixedString::from_str_truncated(&hub_repo), + revision: Some(FixedString::from_str_truncated(&revision)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; Ok(()) } diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 0c884bea3..a35b73dae 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -1,6 +1,7 @@ mod data_provider; mod dataset; mod dummy; +mod errors; mod file_extensions; mod gcs; pub mod http; @@ -14,12 +15,14 @@ mod weighted; pub use data_provider::DataProvider; pub use dataset::{Dataset, Field, Row, Split}; pub use dummy::DummyDataProvider; +pub use errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; -pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_sync}; +pub use gcs::{ + GcsUploadInfo, download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, +}; pub use hub::{ - UploadModelError, download_dataset_repo_async, download_dataset_repo_sync, - download_model_repo_async, download_model_repo_sync, upload_model_repo_async, - upload_to_gcs_async, + HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, + download_model_repo_async, download_model_repo_sync, upload_to_hub, }; pub use 
local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; From a41c4572194db31f6201775dbfca15fbdad4cc81 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 12:20:52 -0800 Subject: [PATCH 30/72] Fix import with python feature --- shared/client/src/state/cooldown.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 488c86d62..28b6294be 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -5,6 +5,8 @@ use psyche_coordinator::{ }; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; +#[cfg(feature = "python")] +use psyche_modeling::CausalLM; use psyche_modeling::{ SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors, }; From b19bc90e57b4add11fe352ce1dd583c03c08d1a9 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 12:24:16 -0800 Subject: [PATCH 31/72] Remove vllm from nix --- nix/lib.nix | 1 + python/default.nix | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/nix/lib.nix b/nix/lib.nix index 882f17040..37ab6d844 100644 --- a/nix/lib.nix +++ b/nix/lib.nix @@ -36,6 +36,7 @@ let python312 pkg-config perl + cargo-nextest ]; buildInputs = diff --git a/python/default.nix b/python/default.nix index 196b07867..ea9faa77b 100644 --- a/python/default.nix +++ b/python/default.nix @@ -39,7 +39,6 @@ let # i'm really not a fan of providing torchtitan like this. i'd much rather have it be built as a git dep via uv2nix. # i think there's room to figure out how to provide setuptools for it. "torchtitan" - "vllm" ]; nixProvidedPythonPkgs = getAllTransitiveDeps topLevelNixPkgs; From 9daad93e52d4da71c2a1d2d8fe0e66dfe97453e0 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 12 Jan 2026 14:13:31 -0800 Subject: [PATCH 32/72] Add cancellation process after one client checkpoints --- shared/client/src/cli.rs | 18 ++----- shared/client/src/state/steps.rs | 44 +++++++-------- shared/coordinator/src/coordinator.rs | 13 +---- shared/data-provider/src/errors.rs | 3 ++ shared/data-provider/src/gcs.rs | 49 +++++++++-------- shared/data-provider/src/hub.rs | 77 ++++++++++++++------------- 6 files changed, 93 insertions(+), 111 deletions(-) diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 9b996cd9f..23536330f 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -241,18 +241,6 @@ impl TrainArgs { bail!("Either --hub-repo or --gcs-bucket must be set for checkpoint uploads"); } - let checkpoint_dir = match &self.checkpoint_dir { - Some(dir) => dir, - None => { - if self.hub_repo.is_some() || self.gcs_bucket.is_some() { - bail!( - "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!" 
- ); - } - return Ok(None); - } - }; - let upload_info = self.build_upload_info(&hub_read_token)?; if upload_info.is_some() && self.keep_steps == 0 { @@ -262,12 +250,12 @@ impl TrainArgs { ); } - Ok(Some(CheckpointConfig { - checkpoint_dir: checkpoint_dir.clone(), + Ok(CheckpointConfig { + checkpoint_dir: self.checkpoint_dir.clone(), upload_info, delete_old_steps: self.delete_old_steps, keep_steps: self.keep_steps, - })) + }) } fn build_upload_info(&self, hub_token: &Option) -> Result> { diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 571dc907c..6f7962b94 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -62,10 +62,6 @@ pub struct StepStateMachine, - - // Handles for HuggingFace uploads running in background - pending_upload_handles: - Vec>>, } #[derive(Error, Debug)] @@ -168,8 +164,6 @@ impl StepStateMachine StepStateMachine StepStateMachine { + cooldown.cancel(); // Cancel any ongoing upload + self.sent_cooldown_witness = false; let (trainers, upload_handle) = cooldown.finish().await?; - if let Some(handle) = upload_handle { - self.pending_upload_handles.push(handle); - } ActiveStep::Warmup(self.warmup.start( trainers, &mut self.previous_round, @@ -949,8 +942,8 @@ impl StepStateMachine RunManager { } pub fn doing_checkpoint(&self) -> bool { - match &self.0 { - InitStage::Running(step_state_machine) => { - let has_pending_uploads = step_state_machine - .pending_upload_handles - .iter() - .any(|handle| !handle.is_finished()); - - has_pending_uploads - } - _ => false, - } + // match &self.0 { + // InitStage::Running(step_state_machine) => { + // let has_pending_uploads = step_state_machine + // .pending_upload_handles + // .iter() + // .any(|handle| !handle.is_finished()); + + // has_pending_uploads + // } + // _ => false, + // } + false } } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 18344ae35..b789cb57d 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -276,7 +276,6 @@ pub struct CoordinatorEpochState { /// `get_historical_clients` is what you actually want. 
pub clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, pub exited_clients: FixedVec, { SOLANA_MAX_NUM_CLIENTS }>, - pub checkpointers: FixedVec, pub rounds_head: u32, pub start_step: u32, pub last_step: u32, @@ -415,7 +414,6 @@ impl Default for CoordinatorEpochState { first_round: true.into(), clients: Default::default(), exited_clients: Default::default(), - checkpointers: Default::default(), cold_start_epoch: false.into(), start_step: Default::default(), last_step: Default::default(), @@ -523,7 +521,7 @@ impl Coordinator { } if !matches!(self.run_state, RunState::Cooldown) { - return Err(CoordinatorError::InvalidRunState); + return Ok(()); } let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; @@ -533,14 +531,7 @@ impl Coordinator { return Err(CoordinatorError::InvalidWitness); } - self.epoch_state - .checkpointers - .push(witness) - .map_err(|_| CoordinatorError::WitnessesFull)?; - - if self.epoch_state.checkpointers.len() == self.config.checkpointer_nodes as usize { - self.epoch_state.checkpointed = true; - } + self.epoch_state.checkpointed = true; Ok(()) } diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index 20e601b25..a649281c6 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -12,6 +12,9 @@ pub enum UploadError { #[error("failed to send checkpoint notification")] SendCheckpoint, + #[error("Upload was cancelled")] + Cancelled, + // Hub-specific errors #[error("failed to connect to HF hub: {0}")] HfHub(#[from] hf_hub::api::tokio::ApiError), diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index adf67a5fb..16873162d 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -141,7 +141,7 @@ pub async fn upload_to_gcs( gcs_info: GcsUploadInfo, local: Vec, step: u64, - tx_checkpoint: mpsc::UnboundedSender, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { let GcsUploadInfo { gcs_bucket, @@ -160,6 +160,12 @@ pub async fn upload_to_gcs( let client = Client::new(config); for path in local { + // Check for cancellation before each file upload + if cancellation_token.is_cancelled() { + info!("Upload cancelled before uploading {}", path.display()); + return Err(UploadError::Cancelled); + } + let file_name = path .file_name() .ok_or_else(|| UploadError::NotAFile(path.clone()))? @@ -174,16 +180,26 @@ pub async fn upload_to_gcs( let data = tokio::fs::read(&path).await?; let upload_type = UploadType::Simple(Media::new(object_name.clone())); - let uploaded = client - .upload_object( - &UploadObjectRequest { - bucket: gcs_bucket.clone(), - ..Default::default() - }, - data, - &upload_type, - ) - .await?; + + // Bind to a variable so it lives long enough + let upload_request = UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }; + + let upload_future = client.upload_object(&upload_request, data, &upload_type); + + let uploaded = tokio::select! { + biased; + + _ = cancellation_token.cancelled() => { + info!("Upload cancelled during upload of {}", path.display()); + return Err(UploadError::Cancelled); + } + result = upload_future => { + result? 
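+            // `biased;` makes select! poll arms top-to-bottom instead of randomly,
+            // so a pending cancellation is observed before a ready upload result.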
+ } + }; info!( bucket = gcs_bucket, @@ -199,16 +215,5 @@ pub async fn upload_to_gcs( gcs_prefix.as_deref().unwrap_or("") ); - tx_checkpoint - .send(model::Checkpoint::Gcs(GcsRepo { - bucket: FixedString::from_str_truncated(&format!( - "gs://{}/{}", - gcs_bucket, - gcs_prefix.as_deref().unwrap_or("") - )), - prefix: Some(FixedString::from_str_truncated(&format!("step-{step}"))), - })) - .map_err(|_| UploadError::SendCheckpoint)?; - Ok(()) } diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 13a575b84..0a2db93c6 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -203,13 +203,17 @@ pub async fn upload_to_hub( hub_info: HubUploadInfo, local: Vec, step: u64, - tx_checkpoint: mpsc::UnboundedSender, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { let HubUploadInfo { hub_repo, hub_token, } = hub_info; + if cancellation_token.is_cancelled() { + return Err(UploadError::Cancelled); + } + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); let api = hf_hub::api::tokio::ApiBuilder::new() @@ -218,48 +222,45 @@ pub async fn upload_to_hub( let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result, _> = local - .into_iter() - .map(|path| { - path.file_name() - .ok_or(UploadError::NotAFile(path.clone())) - .and_then(|name| { - name.to_str() - .ok_or(UploadError::InvalidFilename(path.clone())) - .map(|s| s.to_string()) - }) - .map(|name| (path.into(), name)) - }) - .collect(); + for path in local { + if cancellation_token.is_cancelled() { + info!(repo = hub_repo, "Upload to HuggingFace cancelled"); + return Err(UploadError::Cancelled); + } - let files = files?; + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))? + .to_string(); - let commit_info = api_repo - .upload_files(files, Some(format!("step {step}")), None, false) - .await - .map_err(|e| { - error!( - repo = hub_repo, - error = ?e, - "Failed to upload files to HuggingFace" - ); - e - })?; + let upload_future = api_repo.upload_files( + vec![(path.into(), file_name.clone())], + Some(format!("step {step}")), + None, + false, + ); - let revision = commit_info.oid; + tokio::select! 
{ + biased; - info!( - repo = hub_repo, - revision = revision, - "Upload to HuggingFace complete" - ); + _ = cancellation_token.cancelled() => { + info!(repo = hub_repo, file = file_name, "Upload cancelled"); + return Err(UploadError::Cancelled); + } + result = upload_future => { + result.map_err(|e| { + error!(repo = hub_repo, error = ?e, "Failed to upload file"); + e + })?; + } + } + + info!(repo = hub_repo, file = file_name, "Uploaded file"); + } - tx_checkpoint - .send(model::Checkpoint::Hub(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - })) - .map_err(|_| UploadError::SendCheckpoint)?; + info!(repo = hub_repo, "Upload to HuggingFace complete"); Ok(()) } From 4511ae4d2b472674e6d10fc47654e408f48f50fd Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 13 Jan 2026 07:41:30 -0800 Subject: [PATCH 33/72] Fix cancellation for upload task --- architectures/centralized/client/src/app.rs | 2 - shared/client/src/state/cooldown.rs | 64 ++++++++++----------- shared/client/src/state/steps.rs | 4 +- shared/data-provider/src/errors.rs | 3 - shared/data-provider/src/gcs.rs | 9 +-- shared/data-provider/src/hub.rs | 15 ++--- 6 files changed, 37 insertions(+), 60 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 1ea073de9..cfebd1cfb 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -1,8 +1,6 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; -use psyche_client::HubUploadInfo; -use psyche_client::UploadInfo; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index d7c40840f..3e9ecbbd1 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,9 +1,6 @@ use crate::UploadInfo; use psyche_coordinator::CheckpointerSelection; -use psyche_coordinator::{ - Coordinator, - model::{self}, -}; +use psyche_coordinator::Coordinator; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; #[cfg(feature = "python")] @@ -183,43 +180,40 @@ impl CooldownStepMetadata { trainers.push(trainer); let evals = model_task_runner.start(trainers); - if !is_checkpointer { + if !is_checkpointer { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); return Ok((evals, None)); } - let CheckpointConfig { - upload_info, - checkpoint_dir, - delete_old_steps, - keep_steps, - } = checkpoint_info; - - let upload_handle = tokio::task::spawn(async move { - let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - let local = - save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; - - if let Some(upload_info) = upload_info { - upload_checkpoint(upload_info, local.clone(), step as u64, cancellation_token.clone()) - .await?; - } - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; + let CheckpointConfig { + upload_info, + checkpoint_dir, + delete_old_steps, + keep_steps, + } = checkpoint_info; - Ok(()) - }); + // Do the upload inline instead of spawning + let path = checkpoint_dir.join(format!("{run_id}-step{step}")); + let local = save_checkpoint_locally(path, variables, 
checkpoint_extra_files).await?; - Ok((evals, Some(upload_handle))) + if let Some(upload_info) = upload_info { + upload_checkpoint(upload_info, local.clone(), step as u64, cancellation_token.clone()) + .await?; } - .instrument(info_span!("checkpointing")) + + cleanup_dirs( + delete_queue, + keep_steps, + run_id, + delete_old_steps, + step, + checkpoint_dir, + ) + .await; + + Ok((evals, None)) // No separate handle needed + } + .instrument(info_span!("checkpointing")) }); Ok(CooldownStep { @@ -260,7 +254,7 @@ async fn upload_checkpoint( cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), CheckpointError> { match upload_info { - UploadInfo::Gcs(gcs_info) => upload_to_gcs(gcs_info, local, step, cancellation_token) + UploadInfo::Gcs(gcs_info) => upload_to_gcs(gcs_info, local, cancellation_token) .await .map_err(CheckpointError::UploadError), UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, cancellation_token) diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index 6f7962b94..64a826ffd 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -779,7 +779,7 @@ impl StepStateMachine StepStateMachine { cooldown.cancel(); // Cancel any ongoing upload + let (trainers, _) = cooldown.finish().await?; self.sent_cooldown_witness = false; - let (trainers, upload_handle) = cooldown.finish().await?; ActiveStep::Warmup(self.warmup.start( trainers, &mut self.previous_round, diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index a649281c6..20e601b25 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -12,9 +12,6 @@ pub enum UploadError { #[error("failed to send checkpoint notification")] SendCheckpoint, - #[error("Upload was cancelled")] - Cancelled, - // Hub-specific errors #[error("failed to connect to HF hub: {0}")] HfHub(#[from] hf_hub::api::tokio::ApiError), diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index 16873162d..ab266a757 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -6,11 +6,8 @@ use google_cloud_storage::http::objects::upload::UploadType; use google_cloud_storage::http::objects::{ download::Range, get::GetObjectRequest, list::ListObjectsRequest, }; -use psyche_coordinator::model::{self, GcsRepo}; -use psyche_core::FixedString; use std::path::PathBuf; use tokio::runtime::Runtime; -use tokio::sync::mpsc; use tracing::info; #[derive(Debug, Clone)] @@ -140,7 +137,6 @@ pub fn download_model_from_gcs_sync( pub async fn upload_to_gcs( gcs_info: GcsUploadInfo, local: Vec, - step: u64, cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { let GcsUploadInfo { @@ -163,7 +159,7 @@ pub async fn upload_to_gcs( // Check for cancellation before each file upload if cancellation_token.is_cancelled() { info!("Upload cancelled before uploading {}", path.display()); - return Err(UploadError::Cancelled); + return Ok(()); } let file_name = path @@ -186,7 +182,6 @@ pub async fn upload_to_gcs( bucket: gcs_bucket.clone(), ..Default::default() }; - let upload_future = client.upload_object(&upload_request, data, &upload_type); let uploaded = tokio::select! { @@ -194,7 +189,7 @@ pub async fn upload_to_gcs( _ = cancellation_token.cancelled() => { info!("Upload cancelled during upload of {}", path.display()); - return Err(UploadError::Cancelled); + return Ok(()); } result = upload_future => { result? 
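// A minimal, self-contained sketch of the cancellation pattern that the gcs.rs
// hunk above and the hub.rs hunk below both rely on, assuming only `tokio`
// (with the "time" and "macros" features) and `tokio-util`. `mock_upload` and
// the file list are hypothetical stand-ins for the real GCS/Hub upload futures;
// returning Ok(()) on cancellation mirrors how this patch treats an interrupted
// checkpoint upload as a non-error.
use tokio_util::sync::CancellationToken;

async fn mock_upload(name: &str) {
    // Stand-in for a network transfer.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    println!("uploaded {name}");
}

async fn upload_all(files: &[&str], token: &CancellationToken) -> Result<(), &'static str> {
    for file in files {
        // Cheap fast-path check before starting each new file.
        if token.is_cancelled() {
            return Ok(());
        }
        tokio::select! {
            // `biased` polls the cancellation branch first, so a pending
            // cancel wins the race against an already-ready upload.
            biased;
            _ = token.cancelled() => return Ok(()),
            _ = mock_upload(file) => {}
        }
    }
    Ok(())
}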
diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 0a2db93c6..24809c4f8 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,16 +1,9 @@ use crate::errors::UploadError; -use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, - api::{ - Siblings, - tokio::{ApiError, UploadSource}, - }, + api::{Siblings, tokio::ApiError}, }; -use psyche_coordinator::model; -use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -211,7 +204,7 @@ pub async fn upload_to_hub( } = hub_info; if cancellation_token.is_cancelled() { - return Err(UploadError::Cancelled); + return Ok(()); } info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); @@ -225,7 +218,7 @@ pub async fn upload_to_hub( for path in local { if cancellation_token.is_cancelled() { info!(repo = hub_repo, "Upload to HuggingFace cancelled"); - return Err(UploadError::Cancelled); + return Ok(()); } let file_name = path @@ -247,7 +240,7 @@ pub async fn upload_to_hub( _ = cancellation_token.cancelled() => { info!(repo = hub_repo, file = file_name, "Upload cancelled"); - return Err(UploadError::Cancelled); + return Ok(()); } result = upload_future => { result.map_err(|e| { From 64ace6f9bceed61667d0c3b797ecb42bc5913438 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 13 Jan 2026 17:34:47 -0300 Subject: [PATCH 34/72] General refactor and cleanup for new checkpointers logic --- Cargo.lock | 1 + Cargo.toml | 1 + architectures/centralized/client/Cargo.toml | 1 + architectures/centralized/client/src/app.rs | 78 ++- architectures/centralized/server/src/app.rs | 6 +- .../solana-client/src/backend.rs | 23 + .../solana-coordinator/src/instance_state.rs | 9 +- .../programs/solana-coordinator/src/lib.rs | 17 +- shared/client/src/client.rs | 4 - shared/client/src/state/cooldown.rs | 48 +- shared/client/src/state/steps.rs | 40 +- shared/coordinator/src/checkpointer.rs | 57 +++ shared/coordinator/src/committee.rs | 169 +++++++ shared/coordinator/src/committee_selection.rs | 460 ------------------ shared/coordinator/src/coordinator.rs | 13 +- shared/coordinator/src/lib.rs | 14 +- shared/coordinator/src/tests.rs | 174 +++++++ shared/coordinator/src/types.rs | 85 ++++ shared/data-provider/Cargo.toml | 2 +- shared/watcher/src/traits.rs | 4 + 20 files changed, 645 insertions(+), 561 deletions(-) create mode 100644 shared/coordinator/src/checkpointer.rs create mode 100644 shared/coordinator/src/committee.rs delete mode 100644 shared/coordinator/src/committee_selection.rs create mode 100644 shared/coordinator/src/tests.rs create mode 100644 shared/coordinator/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index 45cf331f8..d6eb881d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6829,6 +6829,7 @@ dependencies = [ "bytemuck", "clap", "clap-markdown", + "google-cloud-storage", "hex", "hf-hub", "psyche-centralized-shared", diff --git a/Cargo.toml b/Cargo.toml index fd50e3d37..376e58700 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ indicatif = "0.17.5" tokenizers = { version = "0.20.0", default-features = false, features = [ "onig", ] } +google-cloud-storage = "0.24.0" tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } torch-sys = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } pyo3-tch = { git = 
"https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } diff --git a/architectures/centralized/client/Cargo.toml b/architectures/centralized/client/Cargo.toml index fbc6ca39a..be8454978 100644 --- a/architectures/centralized/client/Cargo.toml +++ b/architectures/centralized/client/Cargo.toml @@ -25,6 +25,7 @@ time.workspace = true bytemuck.workspace = true clap-markdown.workspace = true hex = "0.4.3" +google-cloud-storage.workspace = true psyche-python-extension-impl = { workspace = true, optional = true } [features] diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index cfebd1cfb..4e5d20d45 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -1,9 +1,15 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; +use google_cloud_storage::client::{Client as GcsClient, ClientConfig}; +use google_cloud_storage::http::objects::delete::DeleteObjectRequest; +use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType}; +use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; use psyche_client::{ - Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, + CheckpointConfig, Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, UploadInfo, + read_identity_secret_key, }; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::{Coordinator, HealthChecks}; use psyche_metrics::ClientMetrics; use psyche_network::{ @@ -28,6 +34,7 @@ pub type TabsData = ::Data; pub enum ToSend { Witness(Box), HealthCheck(HealthChecks), + Checkpoint(Checkpoint), } struct Backend { @@ -64,6 +71,11 @@ impl WatcherBackend for Backend { self.tx.send(ToSend::HealthCheck(health_checks))?; Ok(()) } + + async fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()> { + self.tx.send(ToSend::Checkpoint(checkpoint))?; + Ok(()) + } } pub struct App { @@ -164,6 +176,69 @@ impl App { p2p: NC, state_options: RunInitConfig, ) -> Result<()> { + // sanity checks + let CheckpointConfig { upload_info, .. 
} = state_options.checkpoint_config.clone(); + match upload_info { + Some(UploadInfo::Hub(hub_info)) => { + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_info.hub_token)) + .build()?; + let repo_api = api.repo(Repo::new( + hub_info.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + hub_info.hub_repo + ) + } + } + Some(UploadInfo::Gcs(gcs_info)) => { + // Create GCS client + let config = ClientConfig::default().with_auth().await?; + let client = GcsClient::new(config); + + // Test write access by attempting to upload a small test object + let test_key = format!( + "{}/.write_test", + gcs_info.gcs_prefix.clone().unwrap_or_default() + ); + + let upload_result = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_info.gcs_bucket.clone(), + ..Default::default() + }, + vec![], // empty content + &UploadType::Simple(Media::new(test_key.clone())), + ) + .await; + + match upload_result { + Ok(_) => { + // Clean up test object + let delete_request = DeleteObjectRequest { + bucket: gcs_info.gcs_bucket.clone(), + object: test_key.clone(), + ..Default::default() + }; + let _ = client.delete_object(&delete_request).await; + } + Err(e) => { + anyhow::bail!( + "GCS bucket gs://{}/{} is not writable: {}", + gcs_info.gcs_bucket, + gcs_info.gcs_prefix.clone().unwrap_or_default(), + e + ) + } + } + } + None => {} + } + self.server_conn .send(ClientToServerMessage::Join { run_id: self.run_id.clone(), @@ -204,6 +279,7 @@ impl App { match to_send { ToSend::Witness(witness) => self.server_conn.send(ClientToServerMessage::Witness(witness)).await?, ToSend::HealthCheck(health_checks) => self.server_conn.send(ClientToServerMessage::HealthCheck(health_checks)).await?, + ToSend::Checkpoint(checkpoint) => self.server_conn.send(ClientToServerMessage::Checkpoint(checkpoint)).await?, }; } } diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 7dd68bc77..4df2ca662 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -80,6 +80,10 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { async fn send_health_check(&mut self, _health_checks: HealthChecks) -> Result<()> { bail!("Server does not send health checks"); } + + async fn send_checkpoint(&mut self, _checkpoint: Checkpoint) -> Result<()> { + bail!("Server does not send checkpoints"); + } } type DataServer = @@ -399,7 +403,7 @@ impl App { rand::rng().next_u64(), ), OpportunisticData::CooldownStep(witness) => { - self.coordinator.cooldown_witness(&from, witness) + self.coordinator.cooldown_witness(witness) } } { warn!("Error when processing witness: {error}"); diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs index d8f68072e..d63c5fbd5 100644 --- a/architectures/decentralized/solana-client/src/backend.rs +++ b/architectures/decentralized/solana-client/src/backend.rs @@ -20,6 +20,7 @@ use anchor_client::{ use anyhow::{Context, Result, anyhow}; use futures_util::StreamExt; use psyche_client::IntegrationTestLogMarker; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding}; @@ -335,6 +336,22 @@ impl SolanaBackend { 
self.spawn_scheduled_send("Health check", &[instruction], &[]); } + pub fn send_checkpoint( + &self, + coordinator_instance: Pubkey, + coordinator_account: Pubkey, + repo: Checkpoint, + ) { + let user = self.get_payer(); + let instruction = instructions::coordinator_checkpoint( + &coordinator_instance, + &coordinator_account, + &user, + repo, + ); + self.spawn_scheduled_send("Checkpoint", &[instruction], &[]); + } + pub fn find_join_authorization(join_authority: &Pubkey, authorizer: Option) -> Pubkey { psyche_solana_authorizer::find_authorization( join_authority, @@ -592,6 +609,12 @@ impl WatcherBackend for SolanaBackendRunner } Ok(()) } + + async fn send_checkpoint(&mut self, checkpoint: Checkpoint) -> Result<()> { + self.backend + .send_checkpoint(self.instance, self.account, checkpoint); + Ok(()) + } } impl SolanaBackendRunner { diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index ecd365074..6751fdf0c 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,14 +233,9 @@ impl CoordinatorInstanceState { self.tick() } - pub fn cooldown_witness( - &mut self, - payer: &Pubkey, - witness: Witness, - ) -> Result<()> { - let id = self.clients_state.find_signer(payer)?; + pub fn cooldown_witness(&mut self, witness: Witness) -> Result<()> { self.coordinator - .cooldown_witness(id, witness) + .cooldown_witness(witness) .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; self.tick() diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index a6e8366a1..71ee7e949 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -268,7 +268,6 @@ pub mod psyche_solana_coordinator { ) } - #[allow(unused_variables)] // for the metadata field. adding a _ prefix results in anchor's IDL not matching the actual types. lol. pub fn warmup_witness( ctx: Context, proof: WitnessProof, @@ -289,7 +288,6 @@ pub mod psyche_solana_coordinator { ) } - #[allow(unused_variables)] // for the metadata field. adding a _ prefix results in anchor's IDL not matching the actual types. lol. 
pub fn cooldown_witness( ctx: Context, proof: WitnessProof, @@ -299,15 +297,12 @@ pub mod psyche_solana_coordinator { ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); - account.state.cooldown_witness( - ctx.accounts.user.key, - Witness { - proof, - participant_bloom, - broadcast_bloom, - broadcast_merkle, - }, - ) + account.state.cooldown_witness(Witness { + proof, + participant_bloom, + broadcast_bloom, + broadcast_merkle, + }) } pub fn health_check( diff --git a/shared/client/src/client.rs b/shared/client/src/client.rs index 479c3a2a9..db4f3af0f 100644 --- a/shared/client/src/client.rs +++ b/shared/client/src/client.rs @@ -133,7 +133,6 @@ impl + 'sta let mut retry_check_interval = interval(DOWNLOAD_RETRY_CHECK_INTERVAL); let mut opportunistic_witness_interval = interval(OPPROTUNISTIC_WITNESS_INTERVAL); let mut check_connection_interval = interval(CHECK_CONNECTION_INTERVAL); - let mut _wait_for_checkpoint = false; let mut last_gossip_connection_time = SystemTime::now(); debug!("Starting client loop"); @@ -141,9 +140,6 @@ impl + 'sta select! { _ = cancel.cancelled() => { info!("Got request to cancel main client loop"); - if run.doing_checkpoint() { - _wait_for_checkpoint = true; - } break; } diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 3e9ecbbd1..54450921d 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -37,6 +37,9 @@ pub enum CooldownError { #[error("error while checkpointing: {0}")] Checkpoint(#[from] CheckpointError), + + #[error("error in cooldown step: {0}")] + CoordinatorError(#[from] psyche_coordinator::CoordinatorError), } pub struct CooldownStepMetadata { @@ -138,15 +141,15 @@ impl CooldownStepMetadata { let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); let delete_queue = self.delete_queue.clone(); - let checkpointer_selection = CheckpointerSelection::from_coordinator(state, 0) - .map_err(|_| CooldownError::NoTrainers)?; + let checkpointer_selection = CheckpointerSelection::from_coordinator(state, 0)?; let is_checkpointer = checkpointer_selection - .get_checkpointer(client_index, state.epoch_state.clients.len() as u64); + .is_checkpointer(client_index, state.epoch_state.clients.len() as u64); let cancellation_token = tokio_util::sync::CancellationToken::new(); - let checkpointing_and_evals: CheckpointAndEvalsHandle = tokio::task::spawn({ - let cancellation_token = cancellation_token.clone(); - async move { + let checkpointing_and_evals: JoinHandle> = + tokio::task::spawn({ + let cancellation_token = cancellation_token.clone(); + async move { info!("Extracting full model..."); let (variables, trainer) = tokio::task::spawn_blocking::<_, Result<_, CheckpointError>>(|| { @@ -182,7 +185,7 @@ impl CooldownStepMetadata { let evals = model_task_runner.start(trainers); if !is_checkpointer { info!("Skipping checkpoint upload as this node is not the checkpointer for this epoch"); - return Ok((evals, None)); + return Ok(evals); } let CheckpointConfig { @@ -192,7 +195,6 @@ impl CooldownStepMetadata { keep_steps, } = checkpoint_info; - // Do the upload inline instead of spawning let path = checkpoint_dir.join(format!("{run_id}-step{step}")); let local = save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; @@ -211,10 +213,10 @@ impl CooldownStepMetadata { ) .await; - Ok((evals, None)) // No separate handle needed + Ok(evals) } .instrument(info_span!("checkpointing")) - }); + }); 
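// A minimal sketch of the deterministic checkpointer election that the
// `is_checkpointer` check above performs, assuming nothing beyond std. The
// real code seeds psyche_core::compute_shuffled_index with sha256 of the
// round's random_seed plus a "cooldown" salt; `toy_shuffled_index` below is
// a hypothetical stand-in for that shuffle, not the production algorithm.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn toy_shuffled_index(index: u64, total: u64, seed: &[u8; 32]) -> u64 {
    // Deterministic pseudo-random position; assumes total > 0.
    let mut h = DefaultHasher::new();
    seed.hash(&mut h);
    index.hash(&mut h);
    h.finish() % total
}

struct ToyCheckpointerSelection {
    cooldown_nodes: u64,
    seed: [u8; 32],
}

impl ToyCheckpointerSelection {
    // A client checkpoints iff its shuffled position lands in the first
    // `cooldown_nodes` slots. Because the predicate is a pure function of
    // on-chain data, every client and the coordinator agree on the elected
    // checkpointers without exchanging any extra messages.
    fn is_checkpointer(&self, client_index: u64, total_clients: u64) -> bool {
        toy_shuffled_index(client_index, total_clients, &self.seed) < self.cooldown_nodes
    }
}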
Ok(CooldownStep { checkpointing_and_evals, @@ -263,38 +265,20 @@ async fn upload_checkpoint( } } -type CheckpointAndEvalsHandle = JoinHandle< - Result< - ( - RunningEvals, - Option>>, - ), - CheckpointError, - >, ->; - #[derive(Debug)] pub struct CooldownStep { - checkpointing_and_evals: CheckpointAndEvalsHandle, + checkpointing_and_evals: JoinHandle>, cancellation_token: tokio_util::sync::CancellationToken, } impl CooldownStep { - pub async fn finish( - self, - ) -> Result< - ( - RunningEvals, - Option>>, - ), - CooldownError, - > { - let (running_evals, upload_handle) = self + pub async fn finish(self) -> Result { + let running_evals = self .checkpointing_and_evals .await .map_err(|_| CooldownError::CheckpointThreadCrashed)??; - Ok((running_evals, upload_handle)) + Ok(running_evals) } pub fn cancel(&self) { diff --git a/shared/client/src/state/steps.rs index 64a826ffd..29f96c79b 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -358,12 +358,11 @@ impl StepStateMachine StepStateMachine { let trainers = witnessing.finish().await?.stop_evals().await?; - // check here - self.cleanup_completed_uploads(); ActiveStep::Cooldown(self.cooldown.start(trainers, &state, client_index)?) } @@ -903,9 +897,11 @@ impl StepStateMachine { - cooldown.cancel(); // Cancel any ongoing upload + // If we reach this state, it means at least one of the clients has successfully uploaded the model checkpoint. + // We can cancel any of the other uploads in progress. + cooldown.cancel(); - let (trainers, _) = cooldown.finish().await?; + let trainers = cooldown.finish().await?; self.sent_cooldown_witness = false; ActiveStep::Warmup(self.warmup.start( trainers, @@ -940,11 +936,6 @@ impl StepStateMachine RunManager { } Ok(()) } - - pub fn doing_checkpoint(&self) -> bool { - // match &self.0 { - // InitStage::Running(step_state_machine) => { - // let has_pending_uploads = step_state_machine - // .pending_upload_handles - // .iter() - // .any(|handle| !handle.is_finished()); - - // has_pending_uploads - // } - // _ => false, - // } - false - } } impl From<&RunManager> diff --git a/shared/coordinator/src/checkpointer.rs new file mode 100644 index 000000000..29f8dd417 --- /dev/null +++ b/shared/coordinator/src/checkpointer.rs @@ -0,0 +1,57 @@ +use crate::{Coordinator, CoordinatorError}; +use psyche_core::{NodeIdentity, compute_shuffled_index, sha256, sha256v}; + +use super::types::salts; + +#[derive(Clone)] +pub struct CheckpointerSelection { + cooldown_nodes: u64, + seed: [u8; 32], +} + +impl CheckpointerSelection { + pub fn new(cooldown_nodes: u64, seed: [u8; 32]) -> Self { + Self { + cooldown_nodes, + seed, + } + } + + pub fn from_coordinator( + coordinator: &Coordinator, + offset: isize, + ) -> Result { + let round = get_round_by_offset(coordinator, offset)?; + let seed = sha256(&round.random_seed.to_le_bytes()); + + Ok(Self { + cooldown_nodes: coordinator.config.checkpointer_nodes as u64, + seed, + }) + } + + pub fn is_checkpointer(&self, client_index: u64, total_clients: u64) -> bool { + let final_seed = compute_salted_seed(&self.seed, salts::COOLDOWN); + let index = compute_shuffled_index(client_index, total_clients, &final_seed); + index < self.cooldown_nodes + } +} + +pub(crate) fn get_round_by_offset( + coordinator: &Coordinator, + offset: isize, +) -> Result<&crate::Round, CoordinatorError> { + match offset { + -2 => coordinator.previous_previous_round(), + -1 => coordinator.previous_round(),
+ 0 => coordinator.current_round(), + _ => return Err(CoordinatorError::NoActiveRound), + } + .ok_or(CoordinatorError::NoActiveRound) +} + +pub(crate) fn compute_salted_seed(seed: &[u8; 32], salt: &str) -> [u8; 32] { + let mut result = [0u8; 32]; + result.copy_from_slice(&sha256v(&[&sha256(seed), salt.as_bytes()])); + result +} diff --git a/shared/coordinator/src/committee.rs b/shared/coordinator/src/committee.rs new file mode 100644 index 000000000..e7e94dd84 --- /dev/null +++ b/shared/coordinator/src/committee.rs @@ -0,0 +1,169 @@ +use crate::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; +use psyche_core::{NodeIdentity, compute_shuffled_index, sha256}; + +use super::checkpointer::get_round_by_offset; +use super::types::{Committee, CommitteeProof, WitnessProof, salts}; + +#[derive(Clone)] +pub struct CommitteeSelection { + pub(crate) tie_breaker_nodes: u64, + pub(crate) verifier_nodes: u64, + pub(crate) total_nodes: u64, + pub(crate) witness_nodes: u64, + pub(crate) seed: [u8; 32], +} + +impl CommitteeSelection { + pub fn new( + tie_breaker_nodes: usize, + witness_nodes: usize, + verification_percent: u8, + total_nodes: usize, + seed: u64, + ) -> Result { + Self::validate_params( + tie_breaker_nodes, + witness_nodes, + verification_percent, + total_nodes, + )?; + + let free_nodes = total_nodes - tie_breaker_nodes; + let verifier_nodes = (free_nodes * verification_percent as usize) / 100; + let seed = sha256(&seed.to_le_bytes()); + + Ok(Self { + tie_breaker_nodes: tie_breaker_nodes as u64, + verifier_nodes: verifier_nodes as u64, + total_nodes: total_nodes as u64, + witness_nodes: witness_nodes as u64, + seed, + }) + } + + fn validate_params( + tie_breaker_nodes: usize, + witness_nodes: usize, + verification_percent: u8, + total_nodes: usize, + ) -> Result<(), CoordinatorError> { + if total_nodes >= u64::MAX as usize { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if total_nodes < tie_breaker_nodes { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if witness_nodes != 0 && total_nodes < witness_nodes { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + if verification_percent > 100 { + return Err(CoordinatorError::InvalidCommitteeSelection); + } + Ok(()) + } + + pub fn from_coordinator( + coordinator: &Coordinator, + offset: isize, + ) -> Result { + let round = get_round_by_offset(coordinator, offset)?; + Self::new( + round.tie_breaker_tasks as usize, + coordinator.config.witness_nodes as usize, + coordinator.config.verification_percent, + round.clients_len as usize, + round.random_seed, + ) + } + + pub fn get_witness(&self, index: u64) -> WitnessProof { + let position = self.compute_shuffled_index(index, salts::WITNESS); + let witness = self.is_witness_at_position(position); + WitnessProof { + witness: witness.into(), + position, + index, + } + } + + pub fn verify_witness(&self, proof: &WitnessProof) -> bool { + let position = self.compute_shuffled_index(proof.index, salts::WITNESS); + proof.position == position && proof.witness == self.is_witness_at_position(position).into() + } + + pub fn verify_witness_for_client( + &self, + client_id: &T, + proof: &WitnessProof, + clients: &[Client], + ) -> bool { + Self::verify_client(client_id, proof.index, clients) && self.verify_witness(proof) + } + + fn is_witness_at_position(&self, position: u64) -> bool { + match self.witness_nodes { + 0 => position < SOLANA_MAX_NUM_WITNESSES as u64, + witness_nodes => position < witness_nodes, + } + } + + pub fn get_committee(&self, index: 
u64) -> CommitteeProof { + let position = self.compute_shuffled_index(index, salts::COMMITTEE); + let committee = self.get_committee_from_position(position); + CommitteeProof { + committee, + position, + index, + } + } + + pub fn get_committee_from_position(&self, position: u64) -> Committee { + if position < self.tie_breaker_nodes { + Committee::TieBreaker + } else if position < self.tie_breaker_nodes + self.verifier_nodes { + Committee::Verifier + } else { + Committee::Trainer + } + } + + pub fn verify_committee(&self, proof: &CommitteeProof) -> bool { + let position = self.compute_shuffled_index(proof.index, salts::COMMITTEE); + proof.position == position && proof.committee == self.get_committee_from_position(position) + } + + pub fn verify_committee_for_client( + &self, + client_id: &T, + proof: &CommitteeProof, + clients: &[Client], + ) -> bool { + Self::verify_client(client_id, proof.index, clients) && self.verify_committee(proof) + } + + fn verify_client(client_id: &T, index: u64, clients: &[Client]) -> bool { + clients.get(index as usize).map(|c| &c.id) == Some(client_id) + } + + fn compute_shuffled_index(&self, index: u64, salt: &str) -> u64 { + let mut seed = [0u8; 32]; + seed.copy_from_slice(&psyche_core::sha256v(&[&self.seed, salt.as_bytes()])); + compute_shuffled_index(index, self.total_nodes, &seed) + } + + pub fn get_seed(&self) -> [u8; 32] { + self.seed + } + + pub fn get_num_tie_breaker_nodes(&self) -> u64 { + self.tie_breaker_nodes + } + + pub fn get_num_verifier_nodes(&self) -> u64 { + self.verifier_nodes + } + + pub fn get_num_trainer_nodes(&self) -> u64 { + self.total_nodes - self.tie_breaker_nodes - self.verifier_nodes + } +} diff --git a/shared/coordinator/src/committee_selection.rs b/shared/coordinator/src/committee_selection.rs deleted file mode 100644 index 9fe92262b..000000000 --- a/shared/coordinator/src/committee_selection.rs +++ /dev/null @@ -1,460 +0,0 @@ -use crate::{Client, Coordinator, CoordinatorError, SOLANA_MAX_NUM_WITNESSES}; - -use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; -use bytemuck::Zeroable; -use psyche_core::{NodeIdentity, SmallBoolean, compute_shuffled_index, sha256, sha256v}; -use serde::{Deserialize, Serialize}; -use ts_rs::TS; - -pub const COMMITTEE_SALT: &str = "committee"; -pub const WITNESS_SALT: &str = "witness"; -pub const COOLDOWN_SALT: &str = "cooldown"; - -#[derive(Clone)] -pub struct CheckpointerSelection { - cooldown_nodes: u64, - seed: [u8; 32], -} - -impl CheckpointerSelection { - pub fn from_coordinator( - coordinator: &Coordinator, - offset: isize, - ) -> Result { - let round = match offset { - -2 => coordinator.previous_previous_round(), - -1 => coordinator.previous_round(), - 0 => coordinator.current_round(), - _ => { - return Err(CoordinatorError::NoActiveRound); - } - } - .ok_or(CoordinatorError::NoActiveRound)?; - let seed = sha256(&round.random_seed.to_le_bytes()); - Ok(Self { - cooldown_nodes: coordinator.config.checkpointer_nodes as u64, - seed, - }) - } - - pub fn get_checkpointer(&self, client_index: u64, total_clients: u64) -> bool { - let mut final_seed = [0u8; 32]; - final_seed.copy_from_slice(&sha256v(&[&sha256(&self.seed), COOLDOWN_SALT.as_bytes()])); - let index = compute_shuffled_index(client_index, total_clients, &final_seed); - index < self.cooldown_nodes - } -} - -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, -)] -#[repr(C)] -pub enum Committee { - #[default] - TieBreaker, - 
Verifier, - Trainer, -} - -#[derive(Clone)] -pub struct CommitteeSelection { - tie_breaker_nodes: u64, - verifier_nodes: u64, - total_nodes: u64, - witness_nodes: u64, - seed: [u8; 32], -} - -#[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, -)] -#[repr(C)] -pub struct CommitteeProof { - pub committee: Committee, - pub position: u64, - pub index: u64, -} - -#[derive( - Clone, - Copy, - Debug, - PartialEq, - Zeroable, - Default, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - InitSpace, - TS, -)] -#[repr(C)] -pub struct WitnessProof { - // position in virtual shuffle, as determined by seed - pub position: u64, - // index into epoch_state.clients of sender - pub index: u64, - // assertion of witness membership or non-membership - pub witness: SmallBoolean, -} - -impl CommitteeSelection { - pub fn new( - tie_breaker_nodes: usize, - witness_nodes: usize, - verification_percent: u8, - total_nodes: usize, - seed: u64, - ) -> Result { - if total_nodes >= u64::MAX as usize { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - if total_nodes < tie_breaker_nodes { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - if witness_nodes != 0 && total_nodes < witness_nodes { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - if verification_percent > 100 { - return Err(CoordinatorError::InvalidCommitteeSelection); - } - - let free_nodes = total_nodes - tie_breaker_nodes; - let verifier_nodes = (free_nodes * verification_percent as usize) / 100; - - let seed = sha256(&seed.to_le_bytes()); - - Ok(Self { - tie_breaker_nodes: tie_breaker_nodes as u64, - verifier_nodes: verifier_nodes as u64, - total_nodes: total_nodes as u64, - witness_nodes: witness_nodes as u64, - seed, - }) - } - - pub fn from_coordinator( - coordinator: &Coordinator, - offset: isize, - ) -> Result { - let round = match offset { - -2 => coordinator.previous_previous_round(), - -1 => coordinator.previous_round(), - 0 => coordinator.current_round(), - _ => { - return Err(CoordinatorError::NoActiveRound); - } - } - .ok_or(CoordinatorError::NoActiveRound)?; - Self::new( - round.tie_breaker_tasks as usize, - coordinator.config.witness_nodes as usize, - coordinator.config.verification_percent, - round.clients_len as usize, - round.random_seed, - ) - } - - pub fn get_witness(&self, index: u64) -> WitnessProof { - let position = self.compute_shuffled_index(index, WITNESS_SALT); - let witness = self.get_witness_from_position(position); - WitnessProof { - witness: witness.into(), - position, - index, - } - } - - pub fn get_committee(&self, index: u64) -> CommitteeProof { - let position = self.compute_shuffled_index(index, COMMITTEE_SALT); - let committee = self.get_committee_from_position(position); - CommitteeProof { - committee, - position, - index, - } - } - - pub fn get_committee_from_position(&self, committee_position: u64) -> Committee { - if committee_position < self.tie_breaker_nodes { - Committee::TieBreaker - } else if committee_position < self.tie_breaker_nodes + self.verifier_nodes { - Committee::Verifier - } else { - Committee::Trainer - } - } - - fn get_witness_from_position(&self, witness_position: u64) -> bool { - match self.witness_nodes { - 0 => witness_position < SOLANA_MAX_NUM_WITNESSES as u64, - witness_nodes => witness_position < witness_nodes, - } - } - - pub fn verify_committee_for_client( - &self, - client_id: &T, - proof: &CommitteeProof, - clients: &[Client], - ) -> 
bool { - Self::verify_client(client_id, proof.index, clients) && self.verify_committee(proof) - } - - pub fn verify_witness_for_client( - &self, - client_id: &T, - proof: &WitnessProof, - clients: &[Client], - ) -> bool { - Self::verify_client(client_id, proof.index, clients) && self.verify_witness(proof) - } - - fn verify_client(client_id: &T, index: u64, clients: &[Client]) -> bool { - clients.get(index as usize).map(|c| &c.id) == Some(client_id) - } - - fn verify_committee(&self, proof: &CommitteeProof) -> bool { - let position = self.compute_shuffled_index(proof.index, COMMITTEE_SALT); - proof.position == position && proof.committee == self.get_committee_from_position(position) - } - - fn verify_witness(&self, proof: &WitnessProof) -> bool { - let position = self.compute_shuffled_index(proof.index, WITNESS_SALT); - proof.position == position - && proof.witness == self.get_witness_from_position(position).into() - } - - fn compute_shuffled_index(&self, index: u64, salt: &str) -> u64 { - let mut seed = [0u8; 32]; - seed.copy_from_slice(&sha256v(&[&self.seed, salt.as_bytes()])); - - compute_shuffled_index(index, self.total_nodes, &seed) - } - - pub fn get_seed(&self) -> [u8; 32] { - self.seed - } - - pub fn get_num_tie_breaker_nodes(&self) -> u64 { - self.tie_breaker_nodes - } - - pub fn get_num_verifier_nodes(&self) -> u64 { - self.verifier_nodes - } - - pub fn get_num_trainer_nodes(&self) -> u64 { - self.total_nodes - self.tie_breaker_nodes - self.verifier_nodes - } -} - -impl std::fmt::Display for Committee { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Committee::TieBreaker => write!(f, "Tie breaker"), - Committee::Verifier => write!(f, "Verifier"), - Committee::Trainer => write!(f, "Trainer"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_new_committee_selection() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - assert_eq!(cs.tie_breaker_nodes, 10); - assert_eq!(cs.witness_nodes, 20); - assert_eq!(cs.verifier_nodes, 27); // (100 - 10) * 30% = 27 - assert_eq!(cs.total_nodes, 100); - } - - #[test] - fn test_get_committee() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - // Test for all possible indexes - for i in 0..100 { - let proof = cs.get_committee(i); - assert!(proof.position < 100); - - // Verify that the committee matches the position - match proof.committee { - Committee::TieBreaker => assert!(proof.position < 10), - Committee::Verifier => assert!(proof.position >= 10 && proof.position < 37), - Committee::Trainer => assert!(proof.position >= 37), - } - } - } - - #[test] - fn test_get_witness() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - // Test for all possible indexes - for i in 0..100 { - let proof = cs.get_witness(i); - assert!(proof.position < 100); - - // Verify that the witness status matches the position - if proof.witness.is_true() { - assert!(proof.position < 20); - } else { - assert!(proof.position >= 20); - } - } - } - - #[test] - fn test_verify_committee() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - - for i in 0..100 { - let proof = cs.get_committee(i); - assert!(cs.verify_committee(&proof)); - - // Test with incorrect proof - let incorrect_proof = CommitteeProof { - committee: Committee::Verifier, - position: 99, - index: i, - }; - assert!(!cs.verify_committee(&incorrect_proof)); - } - } - - #[test] - fn test_verify_witness() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 
12345).unwrap(); - - for i in 0..100 { - let proof = cs.get_witness(i); - assert!(cs.verify_witness(&proof)); - - // Test with incorrect proof - let incorrect_proof = WitnessProof { - witness: !proof.witness, - position: 99, - index: i, - }; - assert!(!cs.verify_witness(&incorrect_proof)); - } - } - - #[test] - fn test_committee_distribution() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - let mut tie_breaker_count = 0; - let mut verifier_count = 0; - let mut trainer_count = 0; - - for i in 0..100 { - match cs.get_committee(i).committee { - Committee::TieBreaker => tie_breaker_count += 1, - Committee::Verifier => verifier_count += 1, - Committee::Trainer => trainer_count += 1, - } - } - - assert_eq!(tie_breaker_count, 10); - assert_eq!(verifier_count, 27); - assert_eq!(trainer_count, 63); - } - - #[test] - fn test_witness_distribution() { - let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); - let mut witness_count = 0; - - for i in 0..100 { - if cs.get_witness(i).witness.is_true() { - witness_count += 1; - } - } - - assert_eq!(witness_count, 20); - } - - #[test] - fn test_get_num_nodes() { - let cs = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); - assert_eq!(cs.get_num_tie_breaker_nodes(), 10); - assert_eq!(cs.get_num_verifier_nodes(), 18); - assert_eq!(cs.get_num_trainer_nodes(), 72); - } - - #[test] - fn test_seed_consistency() { - let cs1 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); - let cs2 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); - assert_eq!(cs1.get_seed(), cs2.get_seed()); - } - - #[test] - fn test_invalid_total_nodes() { - assert!(CommitteeSelection::new(10, 5, 20, 9, 12345).is_err()); - } - - #[test] - fn test_invalid_comittee_selections() { - // verification_percent > 100 - assert!(CommitteeSelection::new(10, 5, 101, 100, 12345).is_err()); - // total_nodes < tie_breaker_nodes - assert!(CommitteeSelection::new(10, 5, 101, 5, 12345).is_err()); - // total_nodes < witness_nodes - assert!(CommitteeSelection::new(10, 50, 101, 11, 12345).is_err()); - // total_nodes >= u64::MAX - assert!(CommitteeSelection::new(10, 50, 101, u64::MAX as usize, 12345).is_err()); - } - - #[test] - fn test_edge_case_all_tie_breakers() { - let cs = CommitteeSelection::new(100, 5, 20, 100, 12345).unwrap(); - for i in 0..100 { - let committee = cs.get_committee(i).committee; - assert_eq!(committee, Committee::TieBreaker); - } - } - - #[test] - fn test_edge_case_no_verifiers() { - let cs = CommitteeSelection::new(10, 5, 0, 100, 12345).unwrap(); - let mut tie_breaker_count = 0; - let mut trainer_count = 0; - for i in 0..100 { - let committee = cs.get_committee(i).committee; - match committee { - Committee::TieBreaker => tie_breaker_count += 1, - Committee::Trainer => trainer_count += 1, - _ => panic!("Unexpected committee type"), - } - } - assert_eq!(tie_breaker_count, 10); - assert_eq!(trainer_count, 90); - } -} diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index b789cb57d..84372c17b 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -513,7 +513,6 @@ impl Coordinator { pub fn cooldown_witness( &mut self, - _from: &T, witness: Witness, ) -> std::result::Result<(), CoordinatorError> { if self.halted() { @@ -525,9 +524,9 @@ impl Coordinator { } let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; - let is_checkpointer = checkpointer_selection - .get_checkpointer(witness.proof.index, self.epoch_state.clients.len() 
as u64); - if !is_checkpointer { + if !checkpointer_selection + .is_checkpointer(witness.proof.index, self.epoch_state.clients.len() as u64) + { return Err(CoordinatorError::InvalidWitness); } @@ -670,9 +669,9 @@ impl Coordinator { return Err(CoordinatorError::InvalidRunState); } let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; - let is_checkpointer = checkpointer_selection - .get_checkpointer(index as u64, self.epoch_state.clients.len() as u64); - if !is_checkpointer { + if !checkpointer_selection + .is_checkpointer(index as u64, self.epoch_state.clients.len() as u64) + { return Err(CoordinatorError::InvalidWitness); } diff --git a/shared/coordinator/src/lib.rs b/shared/coordinator/src/lib.rs index 3f3dc2c11..e366ee5cf 100644 --- a/shared/coordinator/src/lib.rs +++ b/shared/coordinator/src/lib.rs @@ -1,16 +1,19 @@ #![allow(unexpected_cfgs)] +mod checkpointer; mod commitment; -mod committee_selection; +mod committee; mod coordinator; mod data_selection; pub mod model; +mod types; +#[cfg(test)] +mod tests; + +pub use checkpointer::CheckpointerSelection; pub use commitment::Commitment; -pub use committee_selection::{ - COMMITTEE_SALT, CheckpointerSelection, Committee, CommitteeProof, CommitteeSelection, - WITNESS_SALT, WitnessProof, -}; +pub use committee::CommitteeSelection; pub use coordinator::{ BLOOM_FALSE_RATE, Client, ClientState, Coordinator, CoordinatorConfig, CoordinatorEpochState, CoordinatorError, CoordinatorProgress, HealthChecks, MAX_TOKENS_TO_SEND, NUM_STORED_ROUNDS, @@ -21,3 +24,4 @@ pub use coordinator::{ pub use data_selection::{ assign_data_for_state, get_batch_ids_for_node, get_batch_ids_for_round, get_data_index_for_step, }; +pub use types::{Committee, CommitteeProof, WitnessProof, salts}; diff --git a/shared/coordinator/src/tests.rs b/shared/coordinator/src/tests.rs new file mode 100644 index 000000000..7c7e75548 --- /dev/null +++ b/shared/coordinator/src/tests.rs @@ -0,0 +1,174 @@ +use super::*; + +#[test] +fn test_new_committee_selection() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + assert_eq!(cs.tie_breaker_nodes, 10); + assert_eq!(cs.witness_nodes, 20); + assert_eq!(cs.verifier_nodes, 27); // (100 - 10) * 30% = 27 + assert_eq!(cs.total_nodes, 100); +} + +#[test] +fn test_get_committee() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + // Test for all possible indexes + for i in 0..100 { + let proof = cs.get_committee(i); + assert!(proof.position < 100); + + // Verify that the committee matches the position + match proof.committee { + Committee::TieBreaker => assert!(proof.position < 10), + Committee::Verifier => assert!(proof.position >= 10 && proof.position < 37), + Committee::Trainer => assert!(proof.position >= 37), + } + } +} + +#[test] +fn test_get_witness() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + // Test for all possible indexes + for i in 0..100 { + let proof = cs.get_witness(i); + assert!(proof.position < 100); + + // Verify that the witness status matches the position + if proof.witness.is_true() { + assert!(proof.position < 20); + } else { + assert!(proof.position >= 20); + } + } +} + +#[test] +fn test_verify_committee() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + for i in 0..100 { + let proof = cs.get_committee(i); + assert!(cs.verify_committee(&proof)); + + // Test with incorrect proof + let incorrect_proof = CommitteeProof { + committee: Committee::Verifier, + position: 99, + index: i, + }; + 
assert!(!cs.verify_committee(&incorrect_proof)); + } +} + +#[test] +fn test_verify_witness() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + + for i in 0..100 { + let proof = cs.get_witness(i); + assert!(cs.verify_witness(&proof)); + + // Test with incorrect proof + let incorrect_proof = WitnessProof { + witness: !proof.witness, + position: 99, + index: i, + }; + assert!(!cs.verify_witness(&incorrect_proof)); + } +} + +#[test] +fn test_committee_distribution() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + let mut tie_breaker_count = 0; + let mut verifier_count = 0; + let mut trainer_count = 0; + + for i in 0..100 { + match cs.get_committee(i).committee { + Committee::TieBreaker => tie_breaker_count += 1, + Committee::Verifier => verifier_count += 1, + Committee::Trainer => trainer_count += 1, + } + } + + assert_eq!(tie_breaker_count, 10); + assert_eq!(verifier_count, 27); + assert_eq!(trainer_count, 63); +} + +#[test] +fn test_witness_distribution() { + let cs = CommitteeSelection::new(10, 20, 30, 100, 12345).unwrap(); + let mut witness_count = 0; + + for i in 0..100 { + if cs.get_witness(i).witness.is_true() { + witness_count += 1; + } + } + + assert_eq!(witness_count, 20); +} + +#[test] +fn test_get_num_nodes() { + let cs = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + assert_eq!(cs.get_num_tie_breaker_nodes(), 10); + assert_eq!(cs.get_num_verifier_nodes(), 18); + assert_eq!(cs.get_num_trainer_nodes(), 72); +} + +#[test] +fn test_seed_consistency() { + let cs1 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + let cs2 = CommitteeSelection::new(10, 5, 20, 100, 12345).unwrap(); + assert_eq!(cs1.get_seed(), cs2.get_seed()); +} + +#[test] +fn test_invalid_total_nodes() { + assert!(CommitteeSelection::new(10, 5, 20, 9, 12345).is_err()); +} + +#[test] +fn test_invalid_comittee_selections() { + // verification_percent > 100 + assert!(CommitteeSelection::new(10, 5, 101, 100, 12345).is_err()); + // total_nodes < tie_breaker_nodes + assert!(CommitteeSelection::new(10, 5, 101, 5, 12345).is_err()); + // total_nodes < witness_nodes + assert!(CommitteeSelection::new(10, 50, 101, 11, 12345).is_err()); + // total_nodes >= u64::MAX + assert!(CommitteeSelection::new(10, 50, 101, u64::MAX as usize, 12345).is_err()); +} + +#[test] +fn test_edge_case_all_tie_breakers() { + let cs = CommitteeSelection::new(100, 5, 20, 100, 12345).unwrap(); + for i in 0..100 { + let committee = cs.get_committee(i).committee; + assert_eq!(committee, Committee::TieBreaker); + } +} + +#[test] +fn test_edge_case_no_verifiers() { + let cs = CommitteeSelection::new(10, 5, 0, 100, 12345).unwrap(); + let mut tie_breaker_count = 0; + let mut trainer_count = 0; + for i in 0..100 { + let committee = cs.get_committee(i).committee; + match committee { + Committee::TieBreaker => tie_breaker_count += 1, + Committee::Trainer => trainer_count += 1, + _ => panic!("Unexpected committee type"), + } + } + assert_eq!(tie_breaker_count, 10); + assert_eq!(trainer_count, 90); +} diff --git a/shared/coordinator/src/types.rs b/shared/coordinator/src/types.rs new file mode 100644 index 000000000..ac49815d7 --- /dev/null +++ b/shared/coordinator/src/types.rs @@ -0,0 +1,85 @@ +use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; +use bytemuck::Zeroable; +use psyche_core::SmallBoolean; +use serde::{Deserialize, Serialize}; +use ts_rs::TS; + +/// Salt constants for deterministic shuffling +pub mod salts { + pub const COMMITTEE: &str = "committee"; + pub 
const WITNESS: &str = "witness"; + pub const COOLDOWN: &str = "cooldown"; +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, +)] +#[repr(C)] +pub enum Committee { + #[default] + TieBreaker, + Verifier, + Trainer, +} + +impl std::fmt::Display for Committee { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Committee::TieBreaker => write!(f, "Tie breaker"), + Committee::Verifier => write!(f, "Verifier"), + Committee::Trainer => write!(f, "Trainer"), + } + } +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, +)] +#[repr(C)] +pub struct CommitteeProof { + pub committee: Committee, + pub position: u64, + pub index: u64, +} + +#[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Zeroable, + AnchorDeserialize, + AnchorSerialize, + Serialize, + Deserialize, + InitSpace, + TS, +)] +#[repr(C)] +pub struct WitnessProof { + /// Position in virtual shuffle, as determined by seed + pub position: u64, + /// Index into epoch_state.clients of sender + pub index: u64, + /// Assertion of witness membership or non-membership + pub witness: SmallBoolean, +} diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index ea0e59791..1d6f123a7 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -26,7 +26,7 @@ thiserror.workspace = true postcard.workspace = true bytemuck.workspace = true reqwest = "0.12.12" -google-cloud-storage = "0.24.0" +google-cloud-storage.workspace = true ts-rs.workspace = true rayon.workspace = true diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 944f35657..f19c6af4d 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -29,4 +29,8 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; + async fn send_checkpoint( + &mut self, + checkpoint: psyche_coordinator::model::Checkpoint, + ) -> Result<()>; } From d19df41b7bd4cb6e1bd8040631a23b01124b0c4a Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 13 Jan 2026 18:01:19 -0300 Subject: [PATCH 35/72] Remove comments --- architectures/centralized/client/src/app.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 4e5d20d45..47b8d5870 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -195,7 +195,6 @@ impl App { } } Some(UploadInfo::Gcs(gcs_info)) => { - // Create GCS client let config = ClientConfig::default().with_auth().await?; let client = GcsClient::new(config); @@ -211,14 +210,13 @@ impl App { bucket: gcs_info.gcs_bucket.clone(), ..Default::default() }, - vec![], // empty content + vec![], &UploadType::Simple(Media::new(test_key.clone())), ) .await; match upload_result { Ok(_) => { - // Clean up test object let delete_request = DeleteObjectRequest { bucket: gcs_info.gcs_bucket.clone(), object: test_key.clone(), From 86ba846d75d7eaeae25dc01b2bee2c93e1e9bb41 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 11:22:00 -0300 Subject: [PATCH 36/72] Remove hub-repo and gcs-bucket from train args --- 
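// A minimal sketch of the CLI surface left after this change, assuming clap's
// "derive" and "env" features; `ToyTrainArgs` is a hypothetical trimmed struct,
// not the real TrainArgs. The upload destination now comes from the
// coordinator's on-chain Checkpoint config rather than CLI flags, and the Hub
// write token is read from the HF_TOKEN environment variable.
use clap::Parser;
use std::path::PathBuf;

#[derive(Parser)]
struct ToyTrainArgs {
    /// Local directory where each epoch's checkpoint is written before upload.
    #[clap(long, env, default_value = "~/.cache/psyche/checkpoints")]
    checkpoint_dir: PathBuf,
}

fn main() {
    let args = ToyTrainArgs::parse();
    // When HF_TOKEN is absent, Hub uploads are skipped with a warning.
    let hub_token = std::env::var("HF_TOKEN").ok();
    println!(
        "checkpoint_dir={:?} hub_token_set={}",
        args.checkpoint_dir,
        hub_token.is_some()
    );
}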
architectures/centralized/client/src/app.rs | 29 +++++++- shared/client/src/cli.rs | 78 ++------------------- shared/client/src/state/cooldown.rs | 38 +++++++++- shared/client/src/state/types.rs | 2 +- shared/data-provider/src/hub.rs | 2 +- 5 files changed, 70 insertions(+), 79 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 8f6a9bcca..d884a79d0 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -2,11 +2,16 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; +use psyche_client::GcsUploadInfo; use psyche_client::HubUploadInfo; use psyche_client::UploadInfo; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; +use psyche_coordinator::model::GcsRepo; +use psyche_coordinator::model::HubRepo; +use psyche_coordinator::model::LLM; +use psyche_coordinator::model::Model; use psyche_coordinator::{Coordinator, HealthChecks, model}; use psyche_metrics::ClientMetrics; use psyche_network::{ @@ -174,11 +179,31 @@ impl App { state_options: RunInitConfig, ) -> Result<()> { // sanity checks - if let Some(checkpoint_config) = &state_options.checkpoint_config { + let Model::LLM(LLM { checkpoint, .. }) = &self.coordinator_state.model; + + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { repo_id, revision }) + | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (repo_id).into(), + hub_token: (&revision.unwrap_or_default()).into(), + })) + } + model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (bucket).into(), + gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + })) + } + _ => None, + }; + + if state_options.checkpoint_config.is_some() { if let Some(UploadInfo::Hub(HubUploadInfo { hub_repo, hub_token, - })) = &checkpoint_config.upload_info + })) = &upload_info { let api = hf_hub::api::tokio::ApiBuilder::new() .with_token(Some(hub_token.clone())) diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index a4ef145f0..48f6fa210 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -1,9 +1,7 @@ use crate::{CheckpointConfig, WandBInfo}; -use crate::UploadInfo; use anyhow::{Result, anyhow, bail}; use clap::Args; -use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_eval::tasktype_from_name; use psyche_modeling::Devices; use psyche_network::{DiscoveryMode, RelayKind, SecretKey}; @@ -141,20 +139,8 @@ pub struct TrainArgs { pub prompt_task: bool, /// If provided, every model parameters update will be save in this directory after each epoch. - #[clap(long, env)] - pub checkpoint_dir: Option, - - /// Path to the Hugging Face repository containing model data and configuration. - #[clap(long, env)] - pub hub_repo: Option, - - /// Name of the GCS bucket containing model data and configuration. - #[clap(long, env)] - pub gcs_bucket: Option, - - /// Prefix within the GCS bucket for model data and configuration. 
- #[clap(long, env)] - pub gcs_prefix: Option, + #[clap(long, env, default_value = "~/.cache/psyche/checkpoints")] + pub checkpoint_dir: PathBuf, #[clap(long, env, default_value_t = 3)] pub hub_max_concurrent_downloads: usize, @@ -233,27 +219,9 @@ impl TrainArgs { } pub fn checkpoint_config(&self) -> Result> { - let hub_read_token = std::env::var("HF_TOKEN").ok(); - - if self.hub_repo.is_some() && self.gcs_bucket.is_some() { - bail!("Use either GCS or HF hub for checkpoint uploads, not both."); - } - - let checkpoint_dir = match &self.checkpoint_dir { - Some(dir) => dir, - None => { - if self.hub_repo.is_some() || self.gcs_bucket.is_some() { - bail!( - "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!" - ); - } - return Ok(None); - } - }; - - let upload_info = self.build_upload_info(&hub_read_token)?; + let hub_token = std::env::var("HF_TOKEN").ok(); - if upload_info.is_some() && self.keep_steps == 0 { + if self.keep_steps == 0 { bail!( "keep_steps must be >= 1 for checkpoint uploads (got {})", self.keep_steps @@ -261,47 +229,13 @@ impl TrainArgs { } Ok(Some(CheckpointConfig { - checkpoint_dir: checkpoint_dir.clone(), - upload_info, + checkpoint_dir: self.checkpoint_dir.clone(), delete_old_steps: self.delete_old_steps, keep_steps: self.keep_steps, + hub_token, })) } - fn build_upload_info(&self, hub_token: &Option) -> Result> { - if let Some(repo) = &self.hub_repo { - return self.build_hub_upload_info(repo, hub_token); - } - - if let Some(bucket) = &self.gcs_bucket { - return self.build_gcs_upload_info(bucket); - } - - Ok(None) - } - - fn build_hub_upload_info( - &self, - repo: &str, - token: &Option, - ) -> Result> { - let token = token.as_ref().ok_or_else(|| { - anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.") - })?; - - Ok(Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: repo.to_string(), - hub_token: token.to_string(), - }))) - } - - fn build_gcs_upload_info(&self, bucket: &str) -> Result> { - Ok(Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: bucket.to_string(), - gcs_prefix: self.gcs_prefix.clone(), - }))) - } - pub fn eval_tasks(&self) -> Result> { let eval_tasks = match &self.eval_tasks { Some(eval_tasks) => Self::eval_tasks_from_args(eval_tasks, self.eval_seed)?, diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 28b6294be..643c3c210 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,10 +1,12 @@ use crate::UploadInfo; use psyche_coordinator::{ Coordinator, - model::{self}, + model::{self, HubRepo, LLM, Model}, }; use psyche_core::NodeIdentity; -use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; +use psyche_data_provider::{ + GcsUploadInfo, HubUploadInfo, UploadError, upload_to_gcs, upload_to_hub, +}; #[cfg(feature = "python")] use psyche_modeling::CausalLM; use psyche_modeling::{ @@ -139,6 +141,7 @@ impl CooldownStepMetadata { let run_id = String::from(&state.run_id); let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); + let Model::LLM(LLM { checkpoint, .. 
}) = state.model; let tx_checkpoint = self.tx_checkpoint.clone(); let tx_model = self.tx_model.clone(); let model_task_runner = self.model_task_runner.clone(); @@ -181,15 +184,44 @@ impl CooldownStepMetadata { let evals = model_task_runner.start(trainers); let Some(CheckpointConfig { - upload_info, checkpoint_dir, delete_old_steps, keep_steps, + hub_token, }) = checkpoint_info else { return Ok((evals, None)); }; + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { + repo_id, + revision: _, + }) + | model::Checkpoint::P2P(HubRepo { + repo_id, + revision: _, + }) => { + if let Some(token) = hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: token, + })) + } else { + warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); + None + } + } + model::Checkpoint::Gcs(model::GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + })) + } + _ => None, + }; + let upload_handle = tokio::task::spawn(async move { let path = checkpoint_dir.join(format!("{run_id}-step{step}")); let local = diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 29734f1a0..20196008e 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -17,10 +17,10 @@ pub enum UploadInfo { #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub upload_info: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, + pub hub_token: Option, } #[derive(Debug)] diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 13a575b84..7afabf628 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -213,7 +213,7 @@ pub async fn upload_to_hub( info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_token.clone())) + .with_token(Some(hub_token)) .build()?; let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); From 44826237b3c928aa21b88f1df8378a4c5ac2649c Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 07:41:10 -0800 Subject: [PATCH 37/72] Fix prefix --- shared/client/src/state/cooldown.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 643c3c210..7c2ef65d5 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -216,7 +216,7 @@ impl CooldownStepMetadata { | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { Some(UploadInfo::Gcs(GcsUploadInfo { gcs_bucket: (&bucket).into(), - gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + gcs_prefix: prefix.as_ref().map(|p| p.into()), })) } _ => None, From 5ea55642654dfe51097b9d19b18bd39b56319711 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 07:57:38 -0800 Subject: [PATCH 38/72] Fix tcp example --- shared/data-provider/examples/tcp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 13dcf8b44..67cc9ab9d 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -37,7 +37,7 @@ impl WatcherBackend for DummyBackend { bail!("Data provider does not send health check"); } - async fn send_checkpoint(&mut 
self, _checkpoint: model::HubRepo) -> anyhow::Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> anyhow::Result<()> { bail!("Data provider does not send checkpoints"); } } From 8da9c66e245e0e1b8d2a8c4519e2fd34d4d4e464 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 15:09:06 -0300 Subject: [PATCH 39/72] Calculate checkpointers instead of adding new config --- .../centralized/testing/src/server.rs | 1 - .../suites/memnet_coordinator_full_round.rs | 1 - .../tests/suites/memnet_coordinator_rewards.rs | 1 - .../suites/memnet_treasurer_create_update.rs | 1 - .../suites/memnet_treasurer_full_epoch.rs | 1 - config/solana-test/light-config.toml | 1 - config/solana-test/nano-config.toml | 1 - shared/coordinator/src/checkpointer.rs | 18 ++++++++++++------ shared/coordinator/src/coordinator.rs | 3 --- 9 files changed, 12 insertions(+), 16 deletions(-) diff --git a/architectures/centralized/testing/src/server.rs b/architectures/centralized/testing/src/server.rs index 06062e5c5..0c0ec60a8 100644 --- a/architectures/centralized/testing/src/server.rs +++ b/architectures/centralized/testing/src/server.rs @@ -77,7 +77,6 @@ impl CoordinatorServer { global_batch_size_end: global_batch_size, global_batch_size_warmup_tokens: 0, verification_percent: 0, - checkpointer_nodes: 0, witness_nodes, total_steps: 100, waiting_for_members_extra_time: 2, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index 12f53a6de..fc03202cf 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -106,7 +106,6 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 1, - checkpointer_nodes: 0, epoch_time: 30, total_steps: 100, waiting_for_members_extra_time: 3, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index 9670078a1..f69bbf8c5 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -102,7 +102,6 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 0, - checkpointer_nodes: 0, epoch_time, waiting_for_members_extra_time: WAITING_FOR_MEMBERS_EXTRA_SECONDS as u8, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs index 863f881f7..e51ced2dd 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs @@ -49,7 +49,6 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 1, - checkpointer_nodes: 0, epoch_time: 30, total_steps: 100, waiting_for_members_extra_time: WAITING_FOR_MEMBERS_EXTRA_SECONDS diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 592b0819b..014772e32 100644 --- 
a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -225,7 +225,6 @@ pub async fn run() { global_batch_size_warmup_tokens: 0, verification_percent: 0, witness_nodes: 0, - checkpointer_nodes: 0, epoch_time, total_steps: 100, waiting_for_members_extra_time: 3, diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index 922adcc9e..5de8ca8a4 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -8,7 +8,6 @@ min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 0 -checkpointer_nodes = 1 global_batch_size_start = 8 global_batch_size_end = 8 global_batch_size_warmup_tokens = 0 diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index 0d05faff2..c275feea3 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -8,7 +8,6 @@ min_clients = 1 init_min_clients = 1 verification_percent = 0 witness_nodes = 1 -checkpointer_nodes = 0 global_batch_size_start = 4 global_batch_size_end = 4 global_batch_size_warmup_tokens = 0 diff --git a/shared/coordinator/src/checkpointer.rs b/shared/coordinator/src/checkpointer.rs index 29f8dd417..9cbd83383 100644 --- a/shared/coordinator/src/checkpointer.rs +++ b/shared/coordinator/src/checkpointer.rs @@ -1,18 +1,20 @@ -use crate::{Coordinator, CoordinatorError}; +use std::cmp::max; + +use crate::{Coordinator, CoordinatorError, coordinator::SOLANA_MAX_NUM_CHECKPOINTERS}; use psyche_core::{NodeIdentity, compute_shuffled_index, sha256, sha256v}; use super::types::salts; #[derive(Clone)] pub struct CheckpointerSelection { - cooldown_nodes: u64, + checkpointers: u64, seed: [u8; 32], } impl CheckpointerSelection { - pub fn new(cooldown_nodes: u64, seed: [u8; 32]) -> Self { + pub fn new(checkpointers: u64, seed: [u8; 32]) -> Self { Self { - cooldown_nodes, + checkpointers, seed, } } @@ -24,8 +26,12 @@ impl CheckpointerSelection { let round = get_round_by_offset(coordinator, offset)?; let seed = sha256(&round.random_seed.to_le_bytes()); + let checkpointers = max( + (coordinator.epoch_state.clients.len() / 3).min(SOLANA_MAX_NUM_CHECKPOINTERS), + 1, + ) as u64; Ok(Self { - cooldown_nodes: coordinator.config.checkpointer_nodes as u64, + checkpointers, seed, }) } @@ -33,7 +39,7 @@ impl CheckpointerSelection { pub fn is_checkpointer(&self, client_index: u64, total_clients: u64) -> bool { let final_seed = compute_salted_seed(&self.seed, salts::COOLDOWN); let index = compute_shuffled_index(client_index, total_clients, &final_seed); - index < self.cooldown_nodes + index < self.checkpointers } } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 84372c17b..f0084e0e7 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -252,7 +252,6 @@ pub struct CoordinatorConfig { pub init_min_clients: u16, pub min_clients: u16, pub witness_nodes: u16, - pub checkpointer_nodes: u16, pub global_batch_size_start: u16, pub global_batch_size_end: u16, @@ -1245,8 +1244,6 @@ impl CoordinatorConfig { && self.global_batch_size_end >= self.global_batch_size_start && self.total_steps != 0 && self.witness_nodes <= self.min_clients - && self.checkpointer_nodes <= self.min_clients - && self.checkpointer_nodes as usize <= SOLANA_MAX_NUM_CHECKPOINTERS && self.witness_nodes as usize <= SOLANA_MAX_NUM_WITNESSES && 
self.cooldown_time > 0 && self.waiting_for_members_extra_time > 0 From 6f7aaf264bb00e0cb4e394834d2aa43c64a95a27 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 12:08:48 -0800 Subject: [PATCH 40/72] fixing centralized tests --- architectures/centralized/client/src/app.rs | 105 +++++++++--------- architectures/centralized/client/src/main.rs | 2 +- .../centralized/testing/src/client.rs | 2 + architectures/centralized/testing/src/lib.rs | 2 +- shared/client/src/state/cooldown.rs | 3 + shared/coordinator/src/coordinator.rs | 10 ++ shared/data-provider/examples/tcp.rs | 4 + 7 files changed, 76 insertions(+), 52 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 47b8d5870..d1a51b017 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -87,6 +87,7 @@ pub struct App { server_conn: TcpClient, metrics: Arc, + skip_upload_check: bool, } pub async fn build_app( @@ -94,6 +95,7 @@ pub async fn build_app( server_addr: String, tx_tui_state: Option>, p: TrainArgs, + is_test: bool, ) -> Result<( App, allowlist::AllowDynamic, @@ -165,6 +167,7 @@ pub async fn build_app( server_conn, run_id: p.run_id, metrics, + skip_upload_check: is_test, }; Ok((app, allowlist, p2p, state_options)) } @@ -178,63 +181,65 @@ impl App { ) -> Result<()> { // sanity checks let CheckpointConfig { upload_info, .. } = state_options.checkpoint_config.clone(); - match upload_info { - Some(UploadInfo::Hub(hub_info)) => { - let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_info.hub_token)) - .build()?; - let repo_api = api.repo(Repo::new( - hub_info.hub_repo.clone(), - hf_hub::RepoType::Model, - )); - if !repo_api.is_writable().await { - anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - hub_info.hub_repo - ) + if !self.skip_upload_check { + match upload_info { + Some(UploadInfo::Hub(hub_info)) => { + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_info.hub_token)) + .build()?; + let repo_api = api.repo(Repo::new( + hub_info.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + hub_info.hub_repo + ) + } } - } - Some(UploadInfo::Gcs(gcs_info)) => { - let config = ClientConfig::default().with_auth().await?; - let client = GcsClient::new(config); + Some(UploadInfo::Gcs(gcs_info)) => { + let config = ClientConfig::default().with_auth().await?; + let client = GcsClient::new(config); - // Test write access by attempting to upload a small test object - let test_key = format!( - "{}/.write_test", - gcs_info.gcs_prefix.clone().unwrap_or_default() - ); + // Test write access by attempting to upload a small test object + let test_key = format!( + "{}/.write_test", + gcs_info.gcs_prefix.clone().unwrap_or_default() + ); - let upload_result = client - .upload_object( - &UploadObjectRequest { - bucket: gcs_info.gcs_bucket.clone(), - ..Default::default() - }, - vec![], - &UploadType::Simple(Media::new(test_key.clone())), - ) - .await; - - match upload_result { - Ok(_) => { - let delete_request = DeleteObjectRequest { - bucket: gcs_info.gcs_bucket.clone(), - object: test_key.clone(), - ..Default::default() - }; - let _ = client.delete_object(&delete_request).await; - } - Err(e) => { - anyhow::bail!( - "GCS bucket gs://{}/{} is not writable: {}", - gcs_info.gcs_bucket, - 
gcs_info.gcs_prefix.clone().unwrap_or_default(), - e + let upload_result = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_info.gcs_bucket.clone(), + ..Default::default() + }, + vec![], + &UploadType::Simple(Media::new(test_key.clone())), ) + .await; + + match upload_result { + Ok(_) => { + let delete_request = DeleteObjectRequest { + bucket: gcs_info.gcs_bucket.clone(), + object: test_key.clone(), + ..Default::default() + }; + let _ = client.delete_object(&delete_request).await; + } + Err(e) => { + anyhow::bail!( + "GCS bucket gs://{}/{} is not writable: {}", + gcs_info.gcs_bucket, + gcs_info.gcs_prefix.clone().unwrap_or_default(), + e + ) + } } } + None => {} } - None => {} } self.server_conn diff --git a/architectures/centralized/client/src/main.rs b/architectures/centralized/client/src/main.rs index fad5d3817..4a0a7d213 100644 --- a/architectures/centralized/client/src/main.rs +++ b/architectures/centralized/client/src/main.rs @@ -105,7 +105,7 @@ async fn async_main() -> Result<()> { )?; let (mut app, allowlist, p2p, state_options) = - build_app(cancel, server_addr, tx_tui_state, args) + build_app(cancel, server_addr, tx_tui_state, args, false) .await .unwrap(); diff --git a/architectures/centralized/testing/src/client.rs b/architectures/centralized/testing/src/client.rs index 41c156e6c..7f0350c90 100644 --- a/architectures/centralized/testing/src/client.rs +++ b/architectures/centralized/testing/src/client.rs @@ -33,6 +33,7 @@ impl Client { client_app_params.server_addr, None, client_app_params.train_args, + true, ) .await .unwrap(); @@ -57,6 +58,7 @@ impl Client { client_app_params.server_addr, None, client_app_params.train_args, + true, ) .await .unwrap(); diff --git a/architectures/centralized/testing/src/lib.rs b/architectures/centralized/testing/src/lib.rs index 3ff0d7a21..234194c77 100644 --- a/architectures/centralized/testing/src/lib.rs +++ b/architectures/centralized/testing/src/lib.rs @@ -6,4 +6,4 @@ pub mod test_utils; pub const WARMUP_TIME: u64 = 60; pub const MAX_ROUND_TRAIN_TIME: u64 = 5; pub const ROUND_WITNESS_TIME: u64 = 2; -pub const COOLDOWN_TIME: u64 = 3; +pub const COOLDOWN_TIME: u64 = 30; diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 54450921d..303fc9e6d 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -1,6 +1,9 @@ use crate::UploadInfo; use psyche_coordinator::CheckpointerSelection; use psyche_coordinator::Coordinator; +use psyche_coordinator::model::Checkpoint; +use psyche_coordinator::model::LLM; +use psyche_coordinator::model::Model; use psyche_core::NodeIdentity; use psyche_data_provider::{UploadError, upload_to_gcs, upload_to_hub}; #[cfg(feature = "python")] diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index f0084e0e7..166625313 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -529,6 +529,16 @@ impl Coordinator { return Err(CoordinatorError::InvalidWitness); } + let Model::LLM(llm) = &mut self.model; + match &llm.checkpoint { + Checkpoint::Hub(hub_repo) => { + llm.checkpoint = Checkpoint::P2P(*hub_repo); + } + Checkpoint::Gcs(gcs_repo) => { + llm.checkpoint = Checkpoint::P2PGcs(*gcs_repo); + } + _ => {} + } self.epoch_state.checkpointed = true; Ok(()) diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 38cb5e88b..67cc9ab9d 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs 
@@ -36,6 +36,10 @@ impl WatcherBackend for DummyBackend { async fn send_health_check(&mut self, _health_checks: HealthChecks) -> anyhow::Result<()> { bail!("Data provider does not send health check"); } + + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> anyhow::Result<()> { + bail!("Data provider does not send checkpoints"); + } } #[derive( From 3471706f382980c1f1cdb0b82a1676617dc3d949 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 17:58:54 -0300 Subject: [PATCH 41/72] Fix tcp example compilation --- shared/data-provider/examples/tcp.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 38cb5e88b..2943cde9f 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use bytemuck::Zeroable; use futures::future::try_join_all; use parquet::data_type::AsBytes; -use psyche_coordinator::{Coordinator, HealthChecks, model}; +use psyche_coordinator::{Coordinator, HealthChecks, model::Checkpoint}; use psyche_core::{BatchId, NodeIdentity}; use psyche_data_provider::{ DataProviderTcpClient, DataProviderTcpServer, LengthKnownDataProvider, TokenizedData, @@ -36,6 +36,10 @@ impl WatcherBackend for DummyBackend { async fn send_health_check(&mut self, _health_checks: HealthChecks) -> anyhow::Result<()> { bail!("Data provider does not send health check"); } + + async fn send_checkpoint(&mut self, _checkpoint: Checkpoint) -> anyhow::Result<()> { + bail!("Data provider does not send checkpoints"); + } } #[derive( From e8a4e5039bd967461e3567eda18fb587c5d9c77e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 18:21:26 -0300 Subject: [PATCH 42/72] Fix centralized tests avoiding uploading checks --- architectures/centralized/client/src/app.rs | 3 +++ architectures/centralized/testing/src/lib.rs | 2 +- architectures/centralized/testing/src/test_utils.rs | 1 + shared/client/src/cli.rs | 7 +++++++ shared/client/src/state/cooldown.rs | 4 ++++ shared/client/src/state/types.rs | 12 ++++++++++++ 6 files changed, 28 insertions(+), 1 deletion(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index e0800b7bc..1107f3a66 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -235,6 +235,9 @@ impl App { } } } + Some(UploadInfo::Dummy()) => { + // In test mode, we skip upload checks + } None => {} } diff --git a/architectures/centralized/testing/src/lib.rs b/architectures/centralized/testing/src/lib.rs index 3ff0d7a21..2b7586761 100644 --- a/architectures/centralized/testing/src/lib.rs +++ b/architectures/centralized/testing/src/lib.rs @@ -6,4 +6,4 @@ pub mod test_utils; pub const WARMUP_TIME: u64 = 60; pub const MAX_ROUND_TRAIN_TIME: u64 = 5; pub const ROUND_WITNESS_TIME: u64 = 2; -pub const COOLDOWN_TIME: u64 = 3; +pub const COOLDOWN_TIME: u64 = 40; diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index 7bed4561e..00db87f32 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -143,6 +143,7 @@ pub fn dummy_client_app_params_with_training_delay( "--max-concurrent-parameter-requests", "10", "--hub-max-concurrent-downloads", "1", "--dummy-training-delay-secs", training_delay_secs.to_string().as_str(), + "--test-mode", "true", ]) .train_args, } diff 
--git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 23536330f..c599431f4 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -204,6 +204,9 @@ pub struct TrainArgs { #[clap(long, default_value_t = 3, env)] pub keep_steps: u32, + + #[clap(long, default_value_t = false, env)] + pub test_mode: bool, } impl TrainArgs { @@ -233,6 +236,10 @@ impl TrainArgs { } pub fn checkpoint_config(&self) -> Result { + if self.test_mode { + return Ok(CheckpointConfig::dummy()); + } + let hub_read_token = std::env::var("HF_TOKEN").ok(); if self.hub_repo.is_some() && self.gcs_bucket.is_some() { diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 30e1a6b64..3c0029d32 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -269,6 +269,10 @@ async fn upload_checkpoint( UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, cancellation_token) .await .map_err(CheckpointError::UploadError), + UploadInfo::Dummy() => { + info!("Dummy upload info provided; skipping upload"); + Ok(()) + } } } diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 29734f1a0..9d8902cad 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -13,6 +13,7 @@ use tokio::task::JoinHandle; pub enum UploadInfo { Hub(HubUploadInfo), Gcs(GcsUploadInfo), + Dummy(), } #[derive(Debug, Clone)] @@ -23,6 +24,17 @@ pub struct CheckpointConfig { pub keep_steps: u32, } +impl CheckpointConfig { + pub fn dummy() -> Self { + Self { + upload_info: Some(UploadInfo::Dummy()), + checkpoint_dir: PathBuf::from("./checkpoints"), + delete_old_steps: false, + keep_steps: 1, + } + } +} + #[derive(Debug)] pub enum PayloadState { Downloading((T, BatchId, BlobTicket)), From 1f545f909566f41c2fe23c5e5c7837db4c05cbc3 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 18:42:12 -0300 Subject: [PATCH 43/72] Add test mode cli arg for training --- architectures/centralized/client/src/app.rs | 1 + architectures/decentralized/solana-client/src/app.rs | 1 + shared/client/src/state/init.rs | 1 + 3 files changed, 3 insertions(+) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 1107f3a66..892d646c9 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -154,6 +154,7 @@ pub async fn build_app( optim_stats_every_n_steps: p.optim_stats_steps, grad_accum_in_fp32: p.grad_accum_in_fp32, dummy_training_delay_secs: p.dummy_training_delay_secs, + test_mode: p.test_mode, max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 36a529bbb..767329f54 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -131,6 +131,7 @@ pub async fn build_app( optim_stats_every_n_steps: p.optim_stats_steps, grad_accum_in_fp32: p.grad_accum_in_fp32, dummy_training_delay_secs: p.dummy_training_delay_secs, + test_mode: p.test_mode, max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index d2c3a042e..fe01292e0 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -73,6 +73,7 
@@ pub struct RunInitConfig { // configurable dummy training time (in seconds) for this client - relevant just for testing pub dummy_training_delay_secs: Option, + pub test_mode: bool, pub sidecar_port: Option, } From 5706ab7e899a3bd6eafbe2610d7c6481699a6dde Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Wed, 14 Jan 2026 18:47:27 -0300 Subject: [PATCH 44/72] Fix flag --- architectures/centralized/client/src/app.rs | 1 - architectures/centralized/testing/src/test_utils.rs | 2 +- architectures/decentralized/solana-client/src/app.rs | 1 - shared/client/src/state/init.rs | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 892d646c9..1107f3a66 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -154,7 +154,6 @@ pub async fn build_app( optim_stats_every_n_steps: p.optim_stats_steps, grad_accum_in_fp32: p.grad_accum_in_fp32, dummy_training_delay_secs: p.dummy_training_delay_secs, - test_mode: p.test_mode, max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index 00db87f32..f0079af70 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -143,7 +143,7 @@ pub fn dummy_client_app_params_with_training_delay( "--max-concurrent-parameter-requests", "10", "--hub-max-concurrent-downloads", "1", "--dummy-training-delay-secs", training_delay_secs.to_string().as_str(), - "--test-mode", "true", + "--test-mode", ]) .train_args, } diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 767329f54..36a529bbb 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -131,7 +131,6 @@ pub async fn build_app( optim_stats_every_n_steps: p.optim_stats_steps, grad_accum_in_fp32: p.grad_accum_in_fp32, dummy_training_delay_secs: p.dummy_training_delay_secs, - test_mode: p.test_mode, max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index fe01292e0..d2c3a042e 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -73,7 +73,6 @@ pub struct RunInitConfig { // configurable dummy training time (in seconds) for this client - relevant just for testing pub dummy_training_delay_secs: Option, - pub test_mode: bool, pub sidecar_port: Option, } From 5f0edb4148faa3d44d4b5a6b8fc4190c43f75de9 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Thu, 15 Jan 2026 05:59:42 -0800 Subject: [PATCH 45/72] Lower cooldown time for centralized tests --- architectures/centralized/testing/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/architectures/centralized/testing/src/lib.rs b/architectures/centralized/testing/src/lib.rs index 2b7586761..58b9429a4 100644 --- a/architectures/centralized/testing/src/lib.rs +++ b/architectures/centralized/testing/src/lib.rs @@ -6,4 +6,4 @@ pub mod test_utils; pub const WARMUP_TIME: u64 = 60; pub const MAX_ROUND_TRAIN_TIME: u64 = 5; pub const ROUND_WITNESS_TIME: u64 = 2; -pub const COOLDOWN_TIME: u64 = 40; +pub const COOLDOWN_TIME: u64 = 5; From 
d53ff7b310de93a9fc2c664536710fc7fa38b232 Mon Sep 17 00:00:00 2001 From: Dylan Socolobsky Date: Wed, 14 Jan 2026 07:15:15 -0800 Subject: [PATCH 46/72] update gcp crate to 1.5.x version --- Cargo.lock | 495 +++++++++++++++++++++-------- shared/data-provider/Cargo.toml | 6 +- shared/data-provider/src/errors.rs | 42 ++- shared/data-provider/src/gcs.rs | 194 ++++------- shared/data-provider/src/http.rs | 95 +++--- 5 files changed, 516 insertions(+), 316 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 45cf331f8..11ad7e2c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -738,28 +738,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "async-trait" version = "0.1.89" @@ -1082,6 +1060,31 @@ dependencies = [ "serde_with", ] +[[package]] +name = "bon" +version = "3.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234655ec178edd82b891e262ea7cf71f6584bcd09eff94db786be23f1821825c" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ec27229c38ed0eb3c0feee3d2c1d6a4379ae44f418a29a658890e062d8f365" +dependencies = [ + "darling 0.21.3", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.106", +] + [[package]] name = "borsh" version = "0.10.4" @@ -1674,12 +1677,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - [[package]] name = "const-oid" version = "0.10.1" @@ -1825,6 +1822,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -2128,25 +2134,14 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "const-oid 0.9.6", - "pem-rfc7468 0.7.0", - "zeroize", -] - [[package]] name = "der" version = "0.8.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9d8dd2f26c86b27a2a8ea2767ec7f9df7a89516e4794e54ac01ee618dda3aa4" dependencies = [ - "const-oid 0.10.1", - "pem-rfc7468 1.0.0-rc.3", + "const-oid", + "pem-rfc7468", "zeroize", ] @@ -2309,7 +2304,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dac89f8a64533a9b0eaa73a68e424db0fb1fd6271c74cc0125336a05f090568d" dependencies = [ "block-buffer 0.11.0-rc.5", - "const-oid 
0.10.1", + "const-oid", "crypto-common 0.2.0-rc.4", ] @@ -2461,7 +2456,7 @@ version = "3.0.0-rc.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ef49c0b20c0ad088893ad2a790a29c06a012b3f05bcfc66661fd22a94b32129" dependencies = [ - "pkcs8 0.11.0-rc.7", + "pkcs8", "serde", "signature 3.0.0-rc.4", ] @@ -3297,9 +3292,9 @@ dependencies = [ [[package]] name = "google-cloud-auth" -version = "0.17.2" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e57a13fbacc5e9c41ded3ad8d0373175a6b7a6ad430d99e89d314ac121b7ab06" +checksum = "1112c453c2e155b3e683204ffff52bcc6d6495d04b68d9e90cd24161270c5058" dependencies = [ "async-trait", "base64 0.21.7", @@ -3317,6 +3312,137 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "google-cloud-auth" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "590a1c28795779d5da6fda35b149d5271bcddcf2ce1709eae9e9460faf2f2aa9" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bon", + "bytes", + "google-cloud-gax", + "http 1.3.1", + "reqwest 0.12.24", + "rustc_version", + "rustls 0.23.35", + "rustls-pemfile 2.2.0", + "serde", + "serde_json", + "thiserror 2.0.17", + "time", + "tokio", +] + +[[package]] +name = "google-cloud-gax" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "324fb97d35103787e80a33ed41ccc43d947c376d2ece68ca53e860f5844dbe24" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures", + "google-cloud-rpc", + "google-cloud-wkt", + "http 1.3.1", + "pin-project", + "rand 0.9.2", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", +] + +[[package]] +name = "google-cloud-gax-internal" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b75b810886ae872aca68a35ad1d4d5e8f2be39e40238116d8aff9d778f04b38" +dependencies = [ + "bytes", + "futures", + "google-cloud-auth 1.3.0", + "google-cloud-gax", + "google-cloud-rpc", + "google-cloud-wkt", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "opentelemetry-semantic-conventions", + "percent-encoding", + "pin-project", + "prost 0.14.3", + "prost-types", + "reqwest 0.12.24", + "rustc_version", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "tokio-stream", + "tonic 0.14.2", + "tonic-prost", + "tower", + "tracing", +] + +[[package]] +name = "google-cloud-iam-v1" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498a68e2a958e8aa9938f7db2c7147aad1b5a0ff2cd47c5ba4e10cb0dcb5bfc5" +dependencies = [ + "async-trait", + "bytes", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-type", + "google-cloud-wkt", + "lazy_static", + "reqwest 0.12.24", + "serde", + "serde_json", + "serde_with", + "tracing", +] + +[[package]] +name = "google-cloud-longrunning" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c80938e704401a47fdf36b51ec10e1a99b1ec22793d607afd0e67c7b675b8b3" +dependencies = [ + "async-trait", + "bytes", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-rpc", + "google-cloud-wkt", + "lazy_static", + "reqwest 0.12.24", + "serde", + "serde_json", + "serde_with", + "tracing", +] + +[[package]] +name = "google-cloud-lro" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49747b7b684b804a2d1040c2cdb21238b3d568a41ab9e36c423554509112f61d" +dependencies = [ + "google-cloud-gax", + 
"google-cloud-longrunning", + "google-cloud-rpc", + "google-cloud-wkt", + "serde", + "tokio", +] + [[package]] name = "google-cloud-metadata" version = "0.5.1" @@ -3328,37 +3454,60 @@ dependencies = [ "tokio", ] +[[package]] +name = "google-cloud-rpc" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd10e97751ca894f9dad6be69fcef1cb72f5bc187329e0254817778fc8235030" +dependencies = [ + "bytes", + "google-cloud-wkt", + "serde", + "serde_json", + "serde_with", +] + [[package]] name = "google-cloud-storage" -version = "0.24.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a73d9e94d35665909050f02e035d8bdc82e419241b1b027ebf1ea51dc8a470" +checksum = "043be824d1b105bfdce786c720e45cae04e66436f8e5d0168e98ca8e5715ce9f" dependencies = [ - "anyhow", - "async-stream", "async-trait", - "base64 0.21.7", + "base64 0.22.1", "bytes", - "futures-util", - "google-cloud-auth", - "google-cloud-metadata", - "google-cloud-token", - "hex", - "once_cell", + "crc32c", + "futures", + "google-cloud-auth 1.3.0", + "google-cloud-gax", + "google-cloud-gax-internal", + "google-cloud-iam-v1", + "google-cloud-longrunning", + "google-cloud-lro", + "google-cloud-rpc", + "google-cloud-type", + "google-cloud-wkt", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.7.0", + "lazy_static", + "md5", + "mime", "percent-encoding", - "pkcs8 0.10.2", - "regex", + "pin-project", + "prost 0.14.3", + "prost-types", "reqwest 0.12.24", - "reqwest-middleware 0.4.2", - "ring", "serde", "serde_json", + "serde_with", "sha2 0.10.9", - "thiserror 1.0.69", - "time", + "thiserror 2.0.17", "tokio", + "tokio-stream", + "tonic 0.14.2", "tracing", - "url", + "uuid", ] [[package]] @@ -3370,6 +3519,35 @@ dependencies = [ "async-trait", ] +[[package]] +name = "google-cloud-type" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9390ac2f3f9882ff42956b25ea65b9f546c8dd44c131726d75a96bf744ec75f6" +dependencies = [ + "bytes", + "google-cloud-wkt", + "serde", + "serde_json", + "serde_with", +] + +[[package]] +name = "google-cloud-wkt" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6f270e404be7ce76a3260abe0c3c71492ab2599ccd877f3253f3dd552f48cc9" +dependencies = [ + "base64 0.22.1", + "bytes", + "serde", + "serde_json", + "serde_with", + "thiserror 2.0.17", + "time", + "url", +] + [[package]] name = "governor" version = "0.6.3" @@ -3943,6 +4121,19 @@ dependencies = [ "webpki-roots 1.0.3", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.7.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -4389,7 +4580,7 @@ dependencies = [ "netwatch", "pin-project", "pkarr", - "pkcs8 0.11.0-rc.7", + "pkcs8", "portmapper", "rand 0.9.2", "reqwest 0.12.24", @@ -5274,6 +5465,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "memchr" version = "2.7.6" @@ -6164,7 +6361,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", - "prost", + "prost 0.13.5", "reqwest 0.12.24", "thiserror 2.0.17", "tracing", @@ 
-6178,10 +6375,16 @@ checksum = "56f8870d3024727e99212eb3bb1762ec16e255e3e6f58eeb3dc8db1aa226746d" dependencies = [ "opentelemetry 0.28.0", "opentelemetry_sdk", - "prost", - "tonic", + "prost 0.13.5", + "tonic 0.12.3", ] +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e62e29dfe041afb8ed2a6c9737ab57db4907285d999ef8ad3a59092a36bdc846" + [[package]] name = "opentelemetry_sdk" version = "0.28.0" @@ -6344,15 +6547,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "pem-rfc7468" version = "1.0.0-rc.3" @@ -6450,24 +6644,14 @@ dependencies = [ "wasm-bindgen-futures", ] -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der 0.7.10", - "spki 0.7.3", -] - [[package]] name = "pkcs8" version = "0.11.0-rc.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93eac55f10aceed84769df670ea4a32d2ffad7399400d41ee1c13b1cd8e1b478" dependencies = [ - "der 0.8.0-rc.9", - "spki 0.8.0-rc.4", + "der", + "spki", ] [[package]] @@ -6701,6 +6885,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.106", +] + [[package]] name = "preview-lr" version = "0.1.0" @@ -6804,7 +6998,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive 0.14.3", ] [[package]] @@ -6820,6 +7024,28 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost 0.14.3", +] + [[package]] name = "psyche-centralized-client" version = "0.1.0" @@ -7016,8 +7242,11 @@ dependencies = [ "anyhow", "async-trait", "bytemuck", + "bytes", "clap", "futures", + "google-cloud-auth 0.16.0", + "google-cloud-gax", "google-cloud-storage", "hf-hub", "memmap2 0.9.8", @@ -7044,6 +7273,7 @@ dependencies = [ "tokio-util 0.7.16", "tracing", "ts-rs", + "urlencoding", ] [[package]] @@ -7976,21 +8206,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "reqwest-middleware" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"57f17d28a6e6acfe1733fe24bcd30774d13bffa4b8a22535b4c8c98423088d4e" -dependencies = [ - "anyhow", - "async-trait", - "http 1.3.1", - "reqwest 0.12.24", - "serde", - "thiserror 1.0.69", - "tower-service", -] - [[package]] name = "resolv-conf" version = "0.7.5" @@ -9981,7 +10196,7 @@ dependencies = [ "indicatif", "log", "reqwest 0.11.27", - "reqwest-middleware 0.2.5", + "reqwest-middleware", "semver", "serde", "serde_derive", @@ -10006,7 +10221,7 @@ dependencies = [ "bs58", "jsonrpc-core", "reqwest 0.11.27", - "reqwest-middleware 0.2.5", + "reqwest-middleware", "semver", "serde", "serde_derive", @@ -10852,16 +11067,6 @@ dependencies = [ "lock_api", ] -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der 0.7.10", -] - [[package]] name = "spki" version = "0.8.0-rc.4" @@ -10869,7 +11074,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8baeff88f34ed0691978ec34440140e1572b68c7dd4a495fd14a3dc1944daa80" dependencies = [ "base64ct", - "der 0.8.0-rc.9", + "der", ] [[package]] @@ -12242,13 +12447,51 @@ dependencies = [ "http-body-util", "percent-encoding", "pin-project", - "prost", + "prost 0.13.5", "tokio-stream", "tower-layer", "tower-service", "tracing", ] +[[package]] +name = "tonic" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" +dependencies = [ + "base64 0.22.1", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "rustls-native-certs", + "sync_wrapper 1.0.2", + "tokio", + "tokio-rustls 0.26.4", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-prost" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" +dependencies = [ + "bytes", + "prost 0.14.3", + "tonic 0.14.2", +] + [[package]] name = "torch-sys" version = "0.22.0" @@ -12268,11 +12511,15 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", + "indexmap 2.11.4", "pin-project-lite", + "slab", "sync_wrapper 1.0.2", "tokio", + "tokio-util 0.7.16", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -12307,9 +12554,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -12319,9 +12566,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -12330,9 +12577,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index ea0e59791..acc09c821 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -26,7 +26,11 @@ thiserror.workspace = true postcard.workspace = true bytemuck.workspace = true reqwest = "0.12.12" -google-cloud-storage = "0.24.0" +google-cloud-storage = "1.5.0" +bytes = "1" +google-cloud-auth = "0.16" +google-cloud-gax = "1.4.0" +urlencoding = "2.1.3" ts-rs.workspace = true rayon.workspace = true diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index 20e601b25..e2949ff65 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -1,3 +1,4 @@ +use hf_hub::api::tokio::CommitError; use std::path::PathBuf; use thiserror::Error; @@ -9,39 +10,34 @@ pub enum UploadError { #[error("file {0} doesn't have a valid utf-8 representation")] InvalidFilename(PathBuf), - #[error("failed to send checkpoint notification")] - SendCheckpoint, - - // Hub-specific errors - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] hf_hub::api::tokio::ApiError), - - #[error("failed to commit files: {0}")] - Commit(#[from] hf_hub::api::tokio::CommitError), - - // GCS-specific errors #[error("GCS authentication failed: {0}")] - GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + GcsAuth(String), - #[error("GCS operation failed: {0}")] - GcsStorage(#[from] google_cloud_storage::http::Error), - - // Common errors + //#[error("GCS operation failed: {0}")] + //GcsStorage(#[from] google_cloud_storage::client::Error), #[error("IO error: {0}")] Io(#[from] std::io::Error), + + #[error("GCS error: {0}")] + Gcs(String), + + #[error("HuggingFace Hub API error: {0}")] + HubApi(#[from] hf_hub::api::tokio::ApiError), + + #[error("HuggingFace Hub commit error: {0}")] + HubCommit(#[from] CommitError), } #[derive(Error, Debug)] pub enum DownloadError { - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] hf_hub::api::tokio::ApiError), - #[error("GCS authentication failed: {0}")] - GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), - - #[error("GCS operation failed: {0}")] - GcsStorage(#[from] google_cloud_storage::http::Error), + GcsAuth(String), + //#[error("GCS operation failed: {0}")] + //GcsStorage(#[from] google_cloud_storage::client::Error), #[error("IO error: {0}")] Io(#[from] std::io::Error), + + #[error("GCS error: {0}")] + Gcs(String), } diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index ab266a757..08f06ffba 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,14 +1,9 @@ use crate::errors::{DownloadError, UploadError}; -use google_cloud_storage::client::{Client, ClientConfig}; -use google_cloud_storage::http::objects::upload::Media; -use google_cloud_storage::http::objects::upload::UploadObjectRequest; -use google_cloud_storage::http::objects::upload::UploadType; -use google_cloud_storage::http::objects::{ - download::Range, get::GetObjectRequest, list::ListObjectsRequest, -}; +use google_cloud_gax::paginator::ItemPaginator; +use google_cloud_storage::client::{Storage, StorageControl}; use std::path::PathBuf; use tokio::runtime::Runtime; -use tracing::info; 
+use tracing::{debug, info}; #[derive(Debug, Clone)] pub struct GcsUploadInfo { @@ -40,86 +35,76 @@ pub async fn download_model_from_gcs_async( bucket: &str, prefix: Option<&str>, ) -> Result, DownloadError> { - // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await? - } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); - - // List all objects in the bucket with optional prefix + // Automatically handles authentication via GOOGLE_APPLICATION_CREDENTIALS + let storage = Storage::builder() + .build() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + + let storage_control = StorageControl::builder() + .build() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + let mut all_objects = vec![]; - let mut page_token: Option = None; - - loop { - let results = client - .list_objects(&ListObjectsRequest { - bucket: bucket.to_owned(), - prefix: prefix.map(|s| s.to_owned()), - page_token: page_token.clone(), - ..Default::default() - }) - .await?; - - for obj in results.items.iter().flatten() { - if check_model_extension(&obj.name) { - all_objects.push(obj.name.clone()); - } - } - match results.next_page_token { - Some(token) => page_token = Some(token), - None => break, - } + let parent_name = format!("projects/_/buckets/{}", bucket); + debug!( + "Listing objects in GCS bucket: {}, parent: {}", + bucket, parent_name + ); + let mut list_request = storage_control.list_objects().set_parent(parent_name); + if let Some(p) = prefix { + list_request = list_request.set_prefix(p.to_string()); } - info!( - "Found {} model files in gs://{}/{}", - all_objects.len(), - bucket, - prefix.unwrap_or("") - ); + let mut stream = list_request.by_item(); + while let Some(obj) = stream + .next() + .await + .transpose() + .map_err(|e| DownloadError::Gcs(e.to_string()))? 
+ { + if check_model_extension(&obj.name) { + all_objects.push(obj.name); + } + } + debug!("Found {} model files", all_objects.len()); let cache_dir = get_cache_dir(bucket, prefix); - std::fs::create_dir_all(&cache_dir)?; + tokio::fs::create_dir_all(&cache_dir) + .await + .map_err(DownloadError::Io)?; let mut downloaded_files = Vec::new(); for object_name in all_objects { - // Get just the filename (strip prefix if present) let filename = object_name.rsplit('/').next().unwrap_or(&object_name); - let local_path = cache_dir.join(filename); - // Skip if already cached if local_path.exists() { - info!("Using cached: {}", filename); downloaded_files.push(local_path); continue; } - info!("Downloading: {}", object_name); - - // Download the object - let data = client - .download_object( - &GetObjectRequest { - bucket: bucket.to_owned(), - object: object_name.clone(), - ..Default::default() - }, - &Range::default(), - ) - .await?; - - // Write to cache - std::fs::write(&local_path, &data)?; - - info!("Downloaded: {} ({} bytes)", filename, data.len()); + let bucket_resource_name = format!("projects/_/buckets/{}", bucket); + let mut read_response = storage + .read_object(&bucket_resource_name, &object_name) + .send() + .await + .map_err(|e| DownloadError::Gcs(e.to_string()))?; + + let mut data = Vec::new(); + while let Some(chunk_result) = read_response.next().await { + let chunk = chunk_result.map_err(|arg0: google_cloud_storage::Error| { + DownloadError::Gcs(arg0.to_string()) + })?; + data.extend_from_slice(&chunk); + } + tokio::fs::write(&local_path, &data) + .await + .map_err(DownloadError::Io)?; downloaded_files.push(local_path); } @@ -139,76 +124,35 @@ pub async fn upload_to_gcs( local: Vec, cancellation_token: tokio_util::sync::CancellationToken, ) -> Result<(), UploadError> { - let GcsUploadInfo { - gcs_bucket, - gcs_prefix, - } = gcs_info; - - info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); - - let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { - info!("Using authenticated GCS client"); - ClientConfig::default().with_auth().await? - } else { - info!("Using anonymous GCS client"); - ClientConfig::default().anonymous() - }; - let client = Client::new(config); + let storage = Storage::builder() + .build() + .await + .map_err(|e| UploadError::Gcs(e.to_string()))?; for path in local { - // Check for cancellation before each file upload if cancellation_token.is_cancelled() { - info!("Upload cancelled before uploading {}", path.display()); return Ok(()); } let file_name = path .file_name() - .ok_or_else(|| UploadError::NotAFile(path.clone()))? - .to_str() + .and_then(|n| n.to_str()) .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; - - let object_name = match &gcs_prefix { - Some(p) => format!("{}/{}", p, file_name), + let object_name = match &gcs_info.gcs_prefix { + Some(p) => format!("{}/{}", p.trim_end_matches('/'), file_name), None => file_name.to_string(), }; + let bucket_resource_name = format!("projects/_/buckets/{}", gcs_info.gcs_bucket); - let data = tokio::fs::read(&path).await?; - - let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let data = bytes::Bytes::from(tokio::fs::read(&path).await.map_err(UploadError::Io)?); - // Bind to a variable so it lives long enough - let upload_request = UploadObjectRequest { - bucket: gcs_bucket.clone(), - ..Default::default() - }; - let upload_future = client.upload_object(&upload_request, data, &upload_type); - - let uploaded = tokio::select! 
{ - biased; - - _ = cancellation_token.cancelled() => { - info!("Upload cancelled during upload of {}", path.display()); - return Ok(()); - } - result = upload_future => { - result? - } - }; - - info!( - bucket = gcs_bucket, - object = object_name, - size = uploaded.size, - "Successfully uploaded file to GCS" - ); + let uploaded_file = storage + .write_object(&bucket_resource_name, &object_name, data) + .send_unbuffered() + .await + .map_err(|e| UploadError::Gcs(e.to_string()))?; + info!(object = %object_name, size = uploaded_file.size, "Uploaded"); } - info!( - "Upload to GCS complete at gs://{}/{}", - gcs_bucket, - gcs_prefix.as_deref().unwrap_or("") - ); - Ok(()) } diff --git a/shared/data-provider/src/http.rs b/shared/data-provider/src/http.rs index 5417f8601..dd13169b4 100644 --- a/shared/data-provider/src/http.rs +++ b/shared/data-provider/src/http.rs @@ -2,7 +2,8 @@ use std::{str::FromStr, time::Duration}; use anyhow::{Context, Result, anyhow, bail}; use futures::future::join_all; -use google_cloud_storage::http::objects::list::ListObjectsRequest; +use google_cloud_gax::paginator::ItemPaginator; +use google_cloud_storage::client::StorageControl; use psyche_coordinator::model::HttpTrainingDataLocation; use psyche_core::{BatchId, Shuffle, TokenSize}; use rand::seq::SliceRandom; @@ -10,7 +11,7 @@ use rand_chacha::ChaCha8Rng; use rand_chacha::rand_core::SeedableRng; use reqwest::IntoUrl; use tokio::task::JoinHandle; -use tracing::{info, trace}; +use tracing::{debug, info, trace}; use crate::{ TokenizedData, @@ -85,50 +86,58 @@ impl FileURLs { Ok(Self(urls_with_sizes)) } - pub async fn from_gcp_bucket(bucket_name: &str, directory: Option) -> Result { - let config = google_cloud_storage::client::ClientConfig::default().anonymous(); - let client = google_cloud_storage::client::Client::new(config); - let mut data_files_matching_directory = { - let mut all_results = vec![]; - // the outer option is if we should continue looping - // the inner option is if we have a "next page token" - let mut next_page_token: Option> = Some(None); - - while let Some(maybe_next_page_token) = next_page_token { - let this_results = client - .list_objects(&ListObjectsRequest { - bucket: bucket_name.to_owned(), - prefix: directory.clone(), - page_token: maybe_next_page_token, - ..Default::default() - }) - .await?; - all_results.extend(this_results.items.iter().flatten().filter_map(|obj| { - let file_ext = obj.name.split('.').next_back()?; - if !DATA_FILE_EXTENSIONS.contains(&file_ext) { - return None; - } - - Some( - obj.media_link - .parse::() - .map(|full_url| (full_url, obj.size as u64)) - .map_err(anyhow::Error::from), - ) - })); - - // if we have a token, Some(Some(String)), - // if not, None - next_page_token = this_results.next_page_token.map(Some) - } - all_results + pub async fn from_gcp_bucket( + bucket_name: &str, + directory: Option, + ) -> anyhow::Result { + debug!( + "http: from_gcp_bucket: bucket_name={}, directory={:?}", + bucket_name, directory + ); + let storage_control = StorageControl::builder().build().await?; + + let mut builder = storage_control + .list_objects() + .set_parent(format!("projects/_/buckets/{}", bucket_name)); + if let Some(p) = directory { + builder = builder.set_prefix(p); } - .into_iter() - .collect::>>()?; - data_files_matching_directory.sort_by(|a, b| a.0.cmp(&b.0)); + let mut items = builder.by_item(); + let mut all_results = vec![]; + + // transpose does Result> -> Option> + while let Some(obj) = items.next().await.transpose()? 
+            // Only process those files with extensions we care about
+            let file_ext = obj.name.split('.').next_back().unwrap_or("");
+            if !DATA_FILE_EXTENSIONS.contains(&file_ext) {
+                continue;
+            }
+
+            let full_url = {
+                // Transforms spaces, etc. into %20 and other url-friendly encodings
+                let encoded_name = urlencoding::encode(&obj.name);
+
+                // Just in case we have the whole "projects/_/buckets/bucket-name" prefix remove it
+                let bucket_name_only = obj
+                    .bucket
+                    .strip_prefix("projects/_/buckets/")
+                    .unwrap_or(&obj.bucket);
+
+                format!("https://www.googleapis.com/storage/v1/b/{bucket_name_only}/o/{encoded_name}?alt=media")
+                    .parse::<Url>()
+                    .map_err(anyhow::Error::from)?
+            };
+            debug!(
+                "Constructed full url: {:?} for object: {} with size {}",
+                full_url, obj.name, obj.size
+            );
+            all_results.push((full_url, obj.size as u64));
+        }
 
-        Ok(Self(data_files_matching_directory))
+        // We sort here to return in deterministic order
+        all_results.sort_by(|a, b| a.0.cmp(&b.0));
+        Ok(Self(all_results))
     }
 
     pub async fn from_location(location: &HttpTrainingDataLocation) -> Result<Self> {
From 15b2d5e0f43ade4e62ebd1c4c8f16abd3e06abbc Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Thu, 15 Jan 2026 20:07:13 -0300
Subject: [PATCH 47/72] Add docs with new cooldown behavior

---
 architectures/decentralized/justfile        |  8 +++++++
 psyche-book/src/enduser/join-run.md         |  5 +++++
 psyche-book/src/enduser/run-config.md       |  6 +++++-
 psyche-book/src/explain/general-workflow.md | 24 ++++++++++++++++-----
 psyche-book/src/explain/index.md            |  2 +-
 psyche-book/src/explain/model-sharing.md    | 10 ++++++---
 scripts/train-solana-test.sh                | 22 +++++++++++++++++--
 shared/client/src/cli.rs                    |  2 +-
 8 files changed, 66 insertions(+), 13 deletions(-)

diff --git a/architectures/decentralized/justfile b/architectures/decentralized/justfile
index 77ba7d4ed..92b272f31 100644
--- a/architectures/decentralized/justfile
+++ b/architectures/decentralized/justfile
@@ -5,6 +5,8 @@ set working-directory := '../../'
 
 # In case a recipe is not found here, it will fallback to the root justfile.
 AUTHORIZER := env_var_or_default("AUTHORIZER", "11111111111111111111111111111111")
+HF_TOKEN := env_var_or_default("HF_TOKEN", "")
+GOOGLE_APPLICATION_CREDENTIALS := env_var_or_default("GOOGLE_APPLICATION_CREDENTIALS", "")
 
 set fallback := true
 
@@ -42,6 +44,12 @@ start-training-localnet-client run_id="test" *args='':
 start-training-localnet-light-client run_id="test" *args='':
     AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }}
 
+start-training-localnet-light-client-checkpoint run_id="test" *args='':
+    HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh --checkpoint {{ args }}
+
+start-training-localnet-client-checkpoint run_id="test" *args='':
+    HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh --checkpoint {{ args }}
+
 OTLP_METRICS_URL := "http://localhost:4318/v1/metrics"
 OTLP_LOGS_URL := "http://localhost:4318/v1/logs"
 
diff --git a/psyche-book/src/enduser/join-run.md b/psyche-book/src/enduser/join-run.md
index b79678bcd..f56df1be7 100644
--- a/psyche-book/src/enduser/join-run.md
+++ b/psyche-book/src/enduser/join-run.md
@@ -67,6 +67,11 @@ WS_RPC=wss://your-primary-rpc-provider.com
 # Required: Which run id to join
 RUN_ID=your_run_id_here
 
+# Required: credentials used to upload model states to the checkpoint storage.
+# Depending on the run's config this is either a HuggingFace write token or a path to a Google Cloud credentials file; only one of the two is needed.
+HF_TOKEN=HuggingFace_write_token_for_repo
+GOOGLE_APPLICATION_CREDENTIALS=path_to_credentials_file_with_access_to_bucket
+
 # Recommended: Fallback RPC Endpoints (for reliability)
 RPC_2=https://your-backup-rpc-provider.com
 WS_RPC_2=wss://your-backup-rpc-provider.com
diff --git a/psyche-book/src/enduser/run-config.md b/psyche-book/src/enduser/run-config.md
index 8440dd0a7..f5c2f6635 100644
--- a/psyche-book/src/enduser/run-config.md
+++ b/psyche-book/src/enduser/run-config.md
@@ -16,7 +16,7 @@ Here's a sample config with some of its options documented.
 # maximum time, in seconds, to let nodes download the model from a checkpoint / other nodes
 warmup_time = 30
 
-# time, in seconds, to let nodes bring the model from the GPU to disk, and to opt to join the next round.
+# time, in seconds, to let nodes bring the model from the GPU to disk, upload it to the remote storage, and opt to join the next round.
 cooldown_time = 30
 
 # time, in seconds, that an epoch will last.
@@ -74,6 +74,10 @@ max_seq_len = 2048
 # Repo where the model is located in HugggingFace, will be used to download the model at the beginning of training.
repo_id = "emozilla/llama2-20m-init" +# Google Cloud Storage is also supported +[model.LLM.checkpoint.Gcs] +bucket = "bucket_name" + [model.LLM.data_location.Http] # Token size in bytes, can be "TwoBytes" or "FourBytes" token_size_in_bytes = "TwoBytes" diff --git a/psyche-book/src/explain/general-workflow.md b/psyche-book/src/explain/general-workflow.md index 4ac8bfa63..d3fc61662 100644 --- a/psyche-book/src/explain/general-workflow.md +++ b/psyche-book/src/explain/general-workflow.md @@ -95,18 +95,27 @@ Any clients that have failed [health checks](#health-checks) will also be remove ### Cooldown phase (state: Cooldown) -The _Cooldown_ phase is the last phase of an epoch, during which the Coordinator waits the _Cooldown_ period to elapse. At this point the clients will begin to do a new checkpoint of the model, this is saving the state of the model at that time to a external storage, such as a Hugging Face. +The **Cooldown** phase is the last phase of an epoch. At this point, clients begin creating a new checkpoint of the model. This means saving the current state of the model to external storage, such as Hugging Face or a bucket in Google Cloud Storage (GCS). -When the _Cooldown_ phase begins, the Coordinator also resets the current model checkpoint state to `Checkpoint::P2P`, indicating that new joiners should download the latest copy of the model from the other participants and not from the usual checkpoint. +At the beginning of this state, the run elects a subset of clients that will be designated as **checkpointers**. All clients are potential checkpointers: one third of the total clients in the run will be elected pseudo-randomly at this stage. If a client is elected, it will start uploading the model state to the storage declared in the run configuration by the run owner. -Upon exiting the _Cooldown_ phase, the Coordinator transitions to the next epoch, saving the previous epoch state, and moving back to the _WaitingForMembers_ phase. All the clients that were participating in the previous epoch automatically join to the new epoch unless they exit manually. +The client that finishes uploading the model sends a transaction to the coordinator, called the **opportunistic cooldown**, indicating that the entire model was uploaded successfully. + +There are two ways the coordinator can transition from this state to the next one: + +- As soon as the first opportunistic cooldown transaction arrives, the coordinator moves to the next state and cancels all upload tasks from the remaining clients, since it already knows that at least one checkpointer has uploaded the complete model correctly. +- If no transaction is received, there is a maximum cooldown time defined in the run configuration. If this time is reached, the coordinator will move to the next state even if no new checkpoint was produced. + +When the _Cooldown_ phase begins, the coordinator also resets the current model checkpoint state to `Checkpoint::P2P`, indicating that new joiners should download the latest copy of the model from other participants rather than from the usual checkpoint storage. + +Upon exiting the _Cooldown_ phase, the coordinator transitions to the next epoch, saving the previous epoch state and moving back to the _WaitingForMembers_ phase. All clients that participated in the previous epoch automatically join the new epoch unless they exit manually. 
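+
+As a minimal sketch of the two exit conditions described above (the struct and field names below are illustrative assumptions, not the actual coordinator types):
+
+```rust
+/// Illustrative model of the Cooldown exit logic; not the real coordinator API.
+struct CooldownState {
+    started_at: u64,           // unix timestamp at which Cooldown began
+    max_cooldown_time: u64,    // `cooldown_time` from the run config
+    checkpoint_uploaded: bool, // set when an opportunistic cooldown tx arrives
+}
+
+impl CooldownState {
+    fn should_advance(&self, now: u64) -> bool {
+        // Path 1: a checkpointer confirmed a complete upload.
+        // Path 2: the maximum cooldown time elapsed without a checkpoint.
+        self.checkpoint_uploaded || now >= self.started_at + self.max_cooldown_time
+    }
+}
+```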
### It all comes together
 
 Here's is an overview of how the state of the run can change depending on the situation:
 
 ```mermaid
-%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'35px'}}}%%
+%%{init: {'theme':'base', 'themeVariables': { 'fontSize':'45px'}}}%%
 flowchart LR
     WFM((Waiting For Members))
     W((Warmup))
@@ -119,6 +128,8 @@ flowchart LR
     d{Witness quorum reached}
     e{Max training time passed}
     f{End of the epoch reached}
+    g{Client checkpoints}
+    h{Max cooldown time passed}
 
     WFM --> a
     a -->|Yes| W
@@ -135,7 +146,10 @@ flowchart LR
     WI --> f
     f -->|Yes| CD
     f -->|No| T
-    CD --> WFM
+    CD --> g
+    g -->|Yes| WFM
+    g -->|No| h
+    h -->|Yes| WFM
 ```
 
 And this is how it fits with real the real clients and how they interact in each of the stages. The committee in this case is the structure that contains all the witness data for the round.
diff --git a/psyche-book/src/explain/index.md b/psyche-book/src/explain/index.md
index 16d6eb6aa..23f556a46 100644
--- a/psyche-book/src/explain/index.md
+++ b/psyche-book/src/explain/index.md
@@ -84,7 +84,7 @@ At the start of each round, one or more clients are randomly selected as witnes
 
 These bloom filters are sent to the coordinator, which then combines them into a provable consensus of which results to apply to the model.
 
-Once a witness quorum is reached, the coordinator advances to the _Training_ phase to allow all clients a brief window to download every training result of the previous round, clients are assigned new data, and the process repeats. After a fixed amount of time, a _Cooldown_ round occurs, marking the end of an **epoch**. This time is configurable in the run creation process that we'll explore in the other sections.
+Once a witness quorum is reached, the coordinator advances to the _Training_ phase to allow all clients a brief window to download every training result of the previous round, clients are assigned new data, and the process repeats. After a fixed amount of time, a _Cooldown_ round occurs, marking the end of an **epoch**. In this state, one third of the clients are randomly selected as checkpointers, and all of them start uploading the state of the model to external storage. There is a maximum amount of time the run can stay in this state, which is configurable in the run creation process that we'll explore in the other sections.
 
 ## The witness/train loop visualized
 
diff --git a/psyche-book/src/explain/model-sharing.md b/psyche-book/src/explain/model-sharing.md
index ba61ebb12..5bc0bb8cd 100644
--- a/psyche-book/src/explain/model-sharing.md
+++ b/psyche-book/src/explain/model-sharing.md
@@ -6,15 +6,19 @@ At the beginning of a run, all clients must download the model parameters, token
 
 Each client will then modify their copy of the model by receiving new training results from other clients and applying them. This keeps everyone's copy of model identical within an **epoch** without an additional full synchronization step.
 
-When a new client joins a run that has already progressed past its first epoch, it would not be correct for the client to download the original model from HuggingFace, as the model parameters would have already been updated during training. Instead, the new client must acquire a copy of the model from the peers who have been actively training it.
+When a new client joins a run that has already progressed past its first epoch, it would not be correct for the client to download the original model from the external storage, as the model parameters would have already been updated during training. Instead, the new client must acquire a copy of the model from the peers who have been actively training it.
 
 This synchronization process occurs during the _Warmup_ phase, while the coordinator waits to begin the next _Training_ phase.
 
-To address this, we **checkpoint** the model at the end of an **epoch**, where clients save and share the entire model for new peers to join. There are two checkpointing variants: HuggingFace based and P2P based.
+To address this, we **checkpoint** the model at the end of an **epoch**, where clients save and share the entire model for new peers to join. There are three checkpointing variants: HuggingFace based, Google Cloud Storage based, and P2P based.
 
 ## HuggingFace checkpoint
 
-In this approach, a client or a set of clients can optionally run as **checkpointers** if they declare a checkpoint URL when joining the run. These clients upload their copy of updated model to HuggingFace after each epoch, and send the URL for this checkpoint to the coordinator. When a new client joins the run, it retrieves the checkpoint URL from the coordinator, and connects to HuggingFace to download the latest copy of the model parameters and configuration files.
+In this approach, one or more clients are randomly elected as **checkpointers**. These clients upload their copy of the updated model to HuggingFace during the Cooldown state at the end of the epoch. The model is uploaded to the HuggingFace repository declared in the run configuration by the run owner. When a new client joins the run, it connects to HuggingFace to download the latest copy of the model parameters and configuration files.
+
+## Google Cloud Storage checkpoint
+
+This works very similarly to the previous approach, but uses a Google Cloud Storage bucket. Every elected checkpointer uploads the model at the end of an epoch. The bucket name is declared by the run owner in the initial configuration. When a client joins the run, it connects to GCS and downloads the model parameters and configuration files.
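+
+As a rough sketch of how a joining client picks its download source from the run's checkpoint variant (the enums below are simplified stand-ins, not the actual `Checkpoint` definition in `shared/coordinator`):
+
+```rust
+// Simplified stand-ins for illustration only.
+enum Checkpoint {
+    Hub(String),    // HuggingFace repo id
+    Gcs(String),    // Google Cloud Storage bucket name
+    P2P(String),    // latest weights are only available from peers
+    P2PGcs(String),
+}
+
+enum DownloadSource {
+    HuggingFace(String),
+    GcsBucket(String),
+    Peers,
+}
+
+fn download_source(checkpoint: &Checkpoint) -> DownloadSource {
+    match checkpoint {
+        Checkpoint::Hub(repo) => DownloadSource::HuggingFace(repo.clone()),
+        Checkpoint::Gcs(bucket) => DownloadSource::GcsBucket(bucket.clone()),
+        // The P2P variants mean "get the current weights from other clients".
+        Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => DownloadSource::Peers,
+    }
+}
+```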
 
 ## P2P checkpoint
 
diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh
index 55699e591..f226d448e 100755
--- a/scripts/train-solana-test.sh
+++ b/scripts/train-solana-test.sh
@@ -2,6 +2,14 @@
 
 set -eo pipefail
 
+CHECKPOINT=false
+# Parse command line arguments
+for arg in "$@"; do
+    if [[ "$arg" == "--checkpoint" ]]; then
+        CHECKPOINT=true
+    fi
+done
+
 # use the agenix provided wallet if you have it
 if [[ -n "${devnet__keypair__wallet_PATH}" && -f "${devnet__keypair__wallet_PATH}" ]]; then
     WALLET_FILE="${devnet__keypair__wallet_PATH}"
@@ -25,10 +33,18 @@ WS_RPC=${WS_RPC:-"ws://127.0.0.1:8900"}
 RUN_ID=${RUN_ID:-"test"}
 AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"}
 
+if [[ "$CHECKPOINT" == true ]]; then
+    echo -e "\n[+] Starting Solana training with checkpointing enabled..."
+else
+    echo -e "\n[+] Starting Solana training without checkpointing..."
+fi + # presets for a DGX or an HGX DP=${DP:-"8"} TP=${TP:-"1"} BATCH_SIZE=${BATCH_SIZE:-"1"} +HF_TOKEN=${HF_TOKEN:-""} +GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-""} # fine if this fails solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || true @@ -36,7 +52,7 @@ solana airdrop 10 "$(solana-keygen pubkey ${WALLET_FILE})" --url "${RPC}" || tru export RUST_LOG="info,psyche=debug" if [[ "$OTLP_METRICS_URL" == "" ]]; then - cargo run --release --bin psyche-solana-client -- \ + HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ @@ -47,9 +63,10 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then --micro-batch-size ${BATCH_SIZE} \ --authorizer ${AUTHORIZER} \ --logs "console" \ + $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \ "$@" else - cargo run --release --bin psyche-solana-client -- \ + HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ --wallet-private-key-path ${WALLET_FILE} \ --rpc ${RPC} \ @@ -62,5 +79,6 @@ else --authorizer ${AUTHORIZER} \ --oltp-metrics-url "http://localhost:4318/v1/metrics" \ --oltp-logs-url "http://localhost:4318/v1/logs" \ + $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \ "$@" fi diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index c599431f4..b22c2f047 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -205,7 +205,7 @@ pub struct TrainArgs { #[clap(long, default_value_t = 3, env)] pub keep_steps: u32, - #[clap(long, default_value_t = false, env)] + #[clap(long, default_value_t = false, env, hide = true)] pub test_mode: bool, } From 172293d724777b5e9da4dd6dc4695350f463b7af Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 11:14:56 -0300 Subject: [PATCH 48/72] Fix extra docs --- psyche-book/src/explain/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/psyche-book/src/explain/index.md b/psyche-book/src/explain/index.md index 23f556a46..a2e6baa27 100644 --- a/psyche-book/src/explain/index.md +++ b/psyche-book/src/explain/index.md @@ -62,7 +62,7 @@ These three phases constitute a **round** of training and will be looping until At the start of an **epoch**, all clients have a window of time to join the run by requesting to be added by coordinator, and then connecting to the other participating clients. This state will be known as the _Waiting for Members_ phase. -Once a minimum threshold of clients has been met, the run will transition to the _Warmup_ phase and begin a countdown to allow connected clients to update their copy of the model. To obtain a copy of the model, the Coordinator will either direct clients to a checkpoint uploaded somewhere like HuggingFace and they will have to download it from there or direct clients to [download the model from other clients](./model-sharing.md) via the p2p network. In the first epoch, all clients will download the model from HuggingFace and after that every new epoch, clients will download the model from other clients via the p2p network. +Once a minimum threshold of clients has been met, the run will transition to the _Warmup_ phase and begin a countdown to allow connected clients to update their copy of the model. 
To obtain a copy of the model, the Coordinator will either direct clients to a checkpoint uploaded to external storage such as HuggingFace or Google Cloud Storage, or direct them to [download the model from other clients](./model-sharing.md) via the p2p network. In the first epoch, all clients download the model from the external storage; in every epoch after that, clients download it from other clients via the p2p network.
 
 After the _Warmup_ phase ends, it will enter the _Training_ phase.
 
From 6254788508686f3f1fc3ea5a6f41b378cb3fd587 Mon Sep 17 00:00:00 2001
From: Dylan Socolobsky
Date: Fri, 16 Jan 2026 06:37:48 -0800
Subject: [PATCH 49/72] fix google cloud storage code

---
 Cargo.lock                                  |  1 +
 architectures/centralized/client/Cargo.toml |  1 +
 architectures/centralized/client/src/app.rs | 44 +++++++++++----------
 3 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 44010a4a2..f6083dd47 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7053,6 +7053,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "bytemuck",
+ "bytes",
  "clap",
  "clap-markdown",
  "google-cloud-storage",
diff --git a/architectures/centralized/client/Cargo.toml b/architectures/centralized/client/Cargo.toml
index be8454978..7d7f51d80 100644
--- a/architectures/centralized/client/Cargo.toml
+++ b/architectures/centralized/client/Cargo.toml
@@ -25,6 +25,7 @@ time.workspace = true
 bytemuck.workspace = true
 clap-markdown.workspace = true
 hex = "0.4.3"
+bytes.workspace = true
 google-cloud-storage.workspace = true
 
 psyche-python-extension-impl = { workspace = true, optional = true }
diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index e3f429933..81381bc19 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -1,8 +1,7 @@
 use anyhow::{Error, Result};
 use bytemuck::Zeroable;
-use google_cloud_storage::client::{Client as GcsClient, ClientConfig};
-use google_cloud_storage::http::objects::delete::DeleteObjectRequest;
-use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest, UploadType};
+use bytes::Bytes;
+use google_cloud_storage::client::{Storage, StorageControl};
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
 use psyche_client::UploadInfo;
@@ -200,8 +199,14 @@ impl App {
                 }
             }
             Some(UploadInfo::Gcs(gcs_info)) => {
-                let config = ClientConfig::default().with_auth().await?;
-                let client = GcsClient::new(config);
+                let storage = Storage::builder()
+                    .build()
+                    .await
+                    .map_err(|e| anyhow::anyhow!("Failed to create GCS client: {}", e))?;
+
+                let storage_control = StorageControl::builder().build().await.map_err(|e| {
+                    anyhow::anyhow!("Failed to create GCS control client: {}", e)
+                })?;
 
                 // Test write access by attempting to upload a small test object
                 let test_key = format!(
                     gcs_info.gcs_prefix.clone().unwrap_or_default()
                 );
 
-                let upload_result = client
-                    .upload_object(
-                        &UploadObjectRequest {
-                            bucket: gcs_info.gcs_bucket.clone(),
-                            ..Default::default()
-                        },
-                        vec![],
-                        &UploadType::Simple(Media::new(test_key.clone())),
-                    )
+                let bucket_resource_name =
+                    format!("projects/_/buckets/{}", gcs_info.gcs_bucket);
+                let test_data = Bytes::from(vec![]);
+
+                let upload_result = storage
+                    .write_object(&bucket_resource_name, &test_key, test_data)
+                    .send_unbuffered()
                     .await;
 
                 match upload_result {
                     Ok(_) => {
-                        let delete_request =
DeleteObjectRequest { - bucket: gcs_info.gcs_bucket.clone(), - object: test_key.clone(), - ..Default::default() - }; - let _ = client.delete_object(&delete_request).await; + // Test upload succeeded, the bucket is writable. Now we delete the test file + let _ = storage_control + .delete_object() + .set_bucket(bucket_resource_name.clone()) + .set_object(test_key) + .send() + .await; } Err(e) => { anyhow::bail!( From 53555920f6922b5ae455046cfd58b43810317df1 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 14:49:03 -0300 Subject: [PATCH 50/72] Remove hub-repo flag from test --- architectures/centralized/testing/src/test_utils.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index f0079af70..25390b224 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -135,7 +135,6 @@ pub fn dummy_client_app_params_with_training_delay( "dummy", "--run-id", run_id, "--iroh-relay", "disabled", - "--hub-repo", "dummy/repo", "--iroh-discovery", "local", "--data-parallelism", "1", "--tensor-parallelism", "1", From cb5b26519c4333bc37479a5c4f8db0a1b394d272 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 15:58:12 -0300 Subject: [PATCH 51/72] Add check for permissions before joining the run --- Cargo.lock | 3 + .../decentralized/solana-client/Cargo.toml | 2 + .../decentralized/solana-client/src/app.rs | 88 ++++++++++++++++++- shared/coordinator/Cargo.toml | 1 + 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c93a8132..ee0e02670 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6982,6 +6982,7 @@ dependencies = [ "async-trait", "bytemuck", "cfg_eval", + "google-cloud-storage", "psyche-core", "serde", "serde_with", @@ -7282,6 +7283,8 @@ dependencies = [ "async-trait", "clap", "clap-markdown", + "google-cloud-storage", + "hf-hub", "psyche-client", "psyche-coordinator", "psyche-core", diff --git a/architectures/decentralized/solana-client/Cargo.toml b/architectures/decentralized/solana-client/Cargo.toml index 752e3af4f..46a0d2d1e 100644 --- a/architectures/decentralized/solana-client/Cargo.toml +++ b/architectures/decentralized/solana-client/Cargo.toml @@ -29,6 +29,8 @@ time.workspace = true tokio.workspace = true tokio-util.workspace = true tracing.workspace = true +google-cloud-storage.workspace = true +hf-hub.workspace = true psyche-python-extension-impl = { workspace = true, optional = true } [features] diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 36a529bbb..0f4e91895 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -1,4 +1,9 @@ use crate::network_identity::NetworkIdentity; +use google_cloud_storage::{ + client::{Client as GcsClient, ClientConfig}, + http::buckets::test_iam_permissions::TestIamPermissionsRequest, +}; +use hf_hub::Repo; use psyche_solana_rpc::SolanaBackend; use anchor_client::{ @@ -11,9 +16,13 @@ use anchor_client::{ }; use anyhow::{Result, anyhow}; use psyche_client::{ - Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, + Client, ClientTUI, ClientTUIState, GcsUploadInfo, HubUploadInfo, NC, RunInitConfig, TrainArgs, + UploadInfo, read_identity_secret_key, +}; +use psyche_coordinator::{ + ClientState, Coordinator, CoordinatorError, RunState, + 
model::{self, GcsRepo, HubRepo, LLM, Model}, }; -use psyche_coordinator::{ClientState, Coordinator, CoordinatorError, RunState}; use psyche_core::sha256; use psyche_metrics::ClientMetrics; @@ -52,6 +61,7 @@ pub struct App { allowlist: allowlist::AllowDynamic, p2p: NC, state_options: RunInitConfig, + no_checkpoint: bool, } pub struct AppParams { @@ -152,6 +162,7 @@ pub async fn build_app( metrics, p2p, state_options, + no_checkpoint: p.test_mode, }; Ok(app) } @@ -226,6 +237,79 @@ impl App { let mut joined_run_this_epoch = None; let mut ever_joined_run = false; + // sanity checks + let Model::LLM(LLM { checkpoint, .. }) = start_coordinator_state.model; + if !self.no_checkpoint { + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { repo_id, revision }) + | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: (&revision.unwrap_or_default()).into(), + })) + } + model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + })) + } + _ => None, + }; + match upload_info { + Some(UploadInfo::Hub(hub_info)) => { + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_info.hub_token)) + .build()?; + let repo_api = api.repo(Repo::new( + hub_info.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + hub_info.hub_repo + ) + } + } + Some(UploadInfo::Gcs(gcs_info)) => { + let config = ClientConfig::default().with_auth().await?; + let client = GcsClient::new(config); + + // Test if we have the required permissions + let permissions_to_test = vec![ + "storage.objects.create".to_string(), + "storage.objects.delete".to_string(), + "storage.objects.get".to_string(), + "storage.objects.list".to_string(), + "storage.objects.update".to_string(), + ]; + + let result = client + .test_iam_permissions(&TestIamPermissionsRequest { + resource: format!("projects/_/buckets/{}", gcs_info.gcs_bucket), + permissions: permissions_to_test.clone(), + }) + .await?; + + let correct_permissions = permissions_to_test + .iter() + .all(|p| result.permissions.contains(p)); + if !correct_permissions { + anyhow::bail!( + "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", + gcs_info.gcs_bucket + ) + } + } + Some(UploadInfo::Dummy()) => { + // In test mode, we skip upload checks + } + None => {} + } + } + // if we're already in "WaitingForMembers" we won't get an update saying that // (subscription is on change), so check if it's in that state right at boot // and join the run if so diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index 024696555..91cd5fb09 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -13,3 +13,4 @@ anyhow.workspace = true serde.workspace = true cfg_eval = "0.1.2" ts-rs.workspace = true +google-cloud-storage.workspace = true From f1556903c8f218e8112098d5d75f1e5a3a1c52bd Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Fri, 16 Jan 2026 11:37:37 -0800 Subject: [PATCH 52/72] add GCS credentials to scratch dir used in run manager --- tools/rust-tools/run-manager/src/docker/manager.rs | 13 +++++++++++-- 1 file changed, 
11 insertions(+), 2 deletions(-) diff --git a/tools/rust-tools/run-manager/src/docker/manager.rs b/tools/rust-tools/run-manager/src/docker/manager.rs index cf20f480a..596946585 100644 --- a/tools/rust-tools/run-manager/src/docker/manager.rs +++ b/tools/rust-tools/run-manager/src/docker/manager.rs @@ -2,7 +2,7 @@ use anchor_client::solana_sdk::pubkey::Pubkey; use anyhow::{Context, Result, anyhow, bail}; use std::fs; use std::io::{BufRead, BufReader}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use tokio::signal; use tracing::{error, info, warn}; @@ -163,8 +163,17 @@ impl RunManager { .arg(&self.env_file); if let Some(dir) = &self.scratch_dir { + let scratch_credentials_path = format!("{dir}/application_default_credentials.json"); + if !Path::new(&scratch_credentials_path).exists() { + bail!("GCS credentials were not found in scratch dir"); + } + cmd.arg("--mount") - .arg(format!("type=bind,src={dir},dst=/scratch")); + .arg(format!("type=bind,src={dir},dst=/scratch")) + .arg("--env") + .arg( + "GOOGLE_APPLICATION_CREDENTIALS=/scratch/application_default_credentials.json", + ); } if let Some(Entrypoint { entrypoint, .. }) = entrypoint { From f361b93c2cfc97ee607160d655a8fad87c4f76dc Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 12:46:54 -0800 Subject: [PATCH 53/72] Remove google-cloud-storage from coordinator toml --- Cargo.lock | 1 - shared/coordinator/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee0e02670..149193011 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6982,7 +6982,6 @@ dependencies = [ "async-trait", "bytemuck", "cfg_eval", - "google-cloud-storage", "psyche-core", "serde", "serde_with", diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index 91cd5fb09..024696555 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -13,4 +13,3 @@ anyhow.workspace = true serde.workspace = true cfg_eval = "0.1.2" ts-rs.workspace = true -google-cloud-storage.workspace = true From 611f322c343f74995737a2a1f5dc2caf6365a15d Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 18:04:20 -0300 Subject: [PATCH 54/72] Remove hub repo arguments in test --- docker/test/client_test_entrypoint.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/test/client_test_entrypoint.sh b/docker/test/client_test_entrypoint.sh index a67210b0b..3b24ca44f 100644 --- a/docker/test/client_test_entrypoint.sh +++ b/docker/test/client_test_entrypoint.sh @@ -11,22 +11,22 @@ echo "USING SIDECAR PORT: ${SIDECAR_PORT}" # Build the command based on environment variable if [ "${PYTHON_ENABLED}" = "true" ]; then echo "Starting client with Python features enabled" - HF_TOKEN="test" psyche-solana-client train \ + psyche-solana-client train \ --wallet-private-key-path "/root/.config/solana/id.json" \ --rpc "${RPC}" \ --ws-rpc "${WS_RPC}" \ --run-id "${RUN_ID}" \ - --hub-repo "dummy/test-hub-repo" \ --data-parallelism 8 \ --sidecar-port "${SIDECAR_PORT}" \ + --test-mode \ --logs "json" else echo "Starting client without Python features" - HF_TOKEN="test" psyche-solana-client train \ + psyche-solana-client train \ --wallet-private-key-path "/root/.config/solana/id.json" \ --rpc "${RPC}" \ --ws-rpc "${WS_RPC}" \ - --hub-repo "dummy/test-hub-repo" \ --run-id "${RUN_ID}" \ + --test-mode \ --logs "json" fi From 280913541e442251e0ca28a456216cdddd302e2f Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 18:22:59 -0300 Subject: 
[PATCH 55/72] Update to version 1.6 --- Cargo.lock | 115 ++++++++++++++++++++--------------------------------- Cargo.toml | 2 +- 2 files changed, 45 insertions(+), 72 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc1449341..435e922d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1060,31 +1060,6 @@ dependencies = [ "serde_with", ] -[[package]] -name = "bon" -version = "3.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "234655ec178edd82b891e262ea7cf71f6584bcd09eff94db786be23f1821825c" -dependencies = [ - "bon-macros", - "rustversion", -] - -[[package]] -name = "bon-macros" -version = "3.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ec27229c38ed0eb3c0feee3d2c1d6a4379ae44f418a29a658890e062d8f365" -dependencies = [ - "darling 0.21.3", - "ident_case", - "prettyplease", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.106", -] - [[package]] name = "borsh" version = "0.10.4" @@ -3302,7 +3277,7 @@ dependencies = [ "google-cloud-token", "home", "jsonwebtoken", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "thiserror 1.0.69", @@ -3314,20 +3289,19 @@ dependencies = [ [[package]] name = "google-cloud-auth" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "590a1c28795779d5da6fda35b149d5271bcddcf2ce1709eae9e9460faf2f2aa9" +checksum = "34f8aadacd3195fc3b08f2a5d582f2401c60d9f1598574acfcfb6228de25db29" dependencies = [ "async-trait", "base64 0.22.1", - "bon", "bytes", "google-cloud-gax", "http 1.3.1", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustc_version", "rustls 0.23.35", - "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "thiserror 2.0.17", @@ -3337,9 +3311,9 @@ dependencies = [ [[package]] name = "google-cloud-gax" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324fb97d35103787e80a33ed41ccc43d947c376d2ece68ca53e860f5844dbe24" +checksum = "b218292363f2e2d6ab8d6da4118acf91cc044439c442d2d6809b581e0728b377" dependencies = [ "base64 0.22.1", "bytes", @@ -3357,13 +3331,13 @@ dependencies = [ [[package]] name = "google-cloud-gax-internal" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b75b810886ae872aca68a35ad1d4d5e8f2be39e40238116d8aff9d778f04b38" +checksum = "78125fa0347492177131d30c010e57ddce9bba1504c33be135f5853a9105c277" dependencies = [ "bytes", "futures", - "google-cloud-auth 1.3.0", + "google-cloud-auth 1.4.0", "google-cloud-gax", "google-cloud-rpc", "google-cloud-wkt", @@ -3376,7 +3350,7 @@ dependencies = [ "pin-project", "prost 0.14.3", "prost-types", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustc_version", "serde", "serde_json", @@ -3402,7 +3376,7 @@ dependencies = [ "google-cloud-type", "google-cloud-wkt", "lazy_static", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "serde_with", @@ -3422,7 +3396,7 @@ dependencies = [ "google-cloud-rpc", "google-cloud-wkt", "lazy_static", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "serde_with", @@ -3449,7 +3423,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d901aeb453fd80e51d64df4ee005014f6cf39f2d736dd64f7239c132d9d39a6a" dependencies = [ - "reqwest 0.12.24", + "reqwest 0.12.28", "thiserror 1.0.69", "tokio", ] @@ -3469,16 +3443,17 @@ dependencies = [ [[package]] name = "google-cloud-storage" -version = "1.5.0" +version = "1.6.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "043be824d1b105bfdce786c720e45cae04e66436f8e5d0168e98ca8e5715ce9f" +checksum = "6abde5d51a4728f47b8f7781d7bf86ab51e310b42ec7c7c96578f1d03da938e4" dependencies = [ "async-trait", "base64 0.22.1", "bytes", + "chrono", "crc32c", "futures", - "google-cloud-auth 1.3.0", + "google-cloud-auth 1.4.0", "google-cloud-gax", "google-cloud-gax-internal", "google-cloud-iam-v1", @@ -3487,6 +3462,7 @@ dependencies = [ "google-cloud-rpc", "google-cloud-type", "google-cloud-wkt", + "hex", "http 1.3.1", "http-body 1.0.1", "hyper 1.7.0", @@ -3497,7 +3473,7 @@ dependencies = [ "pin-project", "prost 0.14.3", "prost-types", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "serde_with", @@ -3507,6 +3483,7 @@ dependencies = [ "tokio-stream", "tonic 0.14.2", "tracing", + "url", "uuid", ] @@ -3824,7 +3801,7 @@ dependencies = [ "num_cpus", "rand 0.8.5", "regex", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "sha1 0.10.6", @@ -4583,7 +4560,7 @@ dependencies = [ "pkcs8", "portmapper", "rand 0.9.2", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustls 0.23.35", "rustls-pki-types", "rustls-platform-verifier 0.5.3", @@ -4840,7 +4817,7 @@ dependencies = [ "pkarr", "postcard", "rand 0.9.2", - "reqwest 0.12.24", + "reqwest 0.12.28", "rustls 0.23.35", "rustls-pki-types", "serde", @@ -6344,7 +6321,7 @@ dependencies = [ "bytes", "http 1.3.1", "opentelemetry 0.28.0", - "reqwest 0.12.24", + "reqwest 0.12.28", "tracing", ] @@ -6362,7 +6339,7 @@ dependencies = [ "opentelemetry-proto", "opentelemetry_sdk", "prost 0.13.5", - "reqwest 0.12.24", + "reqwest 0.12.28", "thiserror 2.0.17", "tracing", ] @@ -6632,7 +6609,7 @@ dependencies = [ "log", "lru 0.13.0", "ntimestamp", - "reqwest 0.12.24", + "reqwest 0.12.28", "self_cell", "serde", "sha1_smol", @@ -6885,16 +6862,6 @@ dependencies = [ "yansi", ] -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn 2.0.106", -] - [[package]] name = "preview-lr" version = "0.1.0" @@ -7264,7 +7231,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rayon", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "static-web-server", @@ -8190,9 +8157,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", @@ -8455,9 +8422,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.12.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -8790,15 +8757,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -12585,9 
+12552,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags 2.9.4", "bytes", @@ -13082,7 +13049,7 @@ dependencies = [ "env_logger 0.11.8", "graphql_client", "impl_from_tuple", - "reqwest 0.12.24", + "reqwest 0.12.28", "serde", "serde_json", "thiserror 1.0.69", @@ -14223,6 +14190,12 @@ dependencies = [ "zstd 0.11.2+zstd.1.5.2", ] +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" + [[package]] name = "zstd" version = "0.11.2+zstd.1.5.2" diff --git a/Cargo.toml b/Cargo.toml index b59b169f1..244f2b92d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ indicatif = "0.17.5" tokenizers = { version = "0.20.0", default-features = false, features = [ "onig", ] } -google-cloud-storage = "1.5.0" +google-cloud-storage = "1.6.0" tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } torch-sys = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } pyo3-tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "11d1ca2ef6dbd3f1e5b0986fab0a90fbb6734496" } From 089ca65c69f6413ef698ed6569e205072954f8a2 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 13:34:55 -0800 Subject: [PATCH 56/72] Fix extra args in script --- scripts/train-solana-test.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index f226d448e..d436fee40 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -3,10 +3,13 @@ set -eo pipefail CHECKPOINT=false +EXTRA_ARGS=() # Parse command line arguments for arg in "$@"; do if [[ "$arg" == "--checkpoint" ]]; then CHECKPOINT=true + else + EXTRA_ARGS+=("$arg") fi done @@ -28,8 +31,8 @@ elif [[ -z "${WALLET_FILE:-}" ]]; then trap "echo 'Cleaning up ephemeral wallet file...'; rm -f '${WALLET_FILE}'" EXIT fi -RPC=${RPC:-"http://127.0.0.1:8899"} -WS_RPC=${WS_RPC:-"ws://127.0.0.1:8900"} +RPC=${RPC:-"http://7da1cfaf-50:8899"} +WS_RPC=${WS_RPC:-"ws://7da1cfaf-50:8900"} RUN_ID=${RUN_ID:-"test"} AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"} @@ -64,7 +67,7 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then --authorizer ${AUTHORIZER} \ --logs "console" \ $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \ - "$@" + "${EXTRA_ARGS[@]}" else HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ @@ -80,5 +83,5 @@ else --oltp-metrics-url "http://localhost:4318/v1/metrics" \ --oltp-logs-url "http://localhost:4318/v1/logs" \ $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \ - "$@" + "${EXTRA_ARGS[@]}" fi From 2402b4e180a6b3017b62082c4f40bd48edb11730 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Fri, 16 Jan 2026 13:41:55 -0800 Subject: [PATCH 57/72] Remove sanity check for permissions --- .../decentralized/solana-client/src/app.rs | 142 +++++++++--------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 
0f4e91895..7a074904b 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -238,77 +238,77 @@ impl App { let mut ever_joined_run = false; // sanity checks - let Model::LLM(LLM { checkpoint, .. }) = start_coordinator_state.model; - if !self.no_checkpoint { - let upload_info = match checkpoint { - model::Checkpoint::Hub(HubRepo { repo_id, revision }) - | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { - Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: (&repo_id).into(), - hub_token: (&revision.unwrap_or_default()).into(), - })) - } - model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) - | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { - Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: (&bucket).into(), - gcs_prefix: Some((&prefix.unwrap_or_default()).into()), - })) - } - _ => None, - }; - match upload_info { - Some(UploadInfo::Hub(hub_info)) => { - let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_info.hub_token)) - .build()?; - let repo_api = api.repo(Repo::new( - hub_info.hub_repo.clone(), - hf_hub::RepoType::Model, - )); - if !repo_api.is_writable().await { - anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - hub_info.hub_repo - ) - } - } - Some(UploadInfo::Gcs(gcs_info)) => { - let config = ClientConfig::default().with_auth().await?; - let client = GcsClient::new(config); - - // Test if we have the required permissions - let permissions_to_test = vec![ - "storage.objects.create".to_string(), - "storage.objects.delete".to_string(), - "storage.objects.get".to_string(), - "storage.objects.list".to_string(), - "storage.objects.update".to_string(), - ]; - - let result = client - .test_iam_permissions(&TestIamPermissionsRequest { - resource: format!("projects/_/buckets/{}", gcs_info.gcs_bucket), - permissions: permissions_to_test.clone(), - }) - .await?; - - let correct_permissions = permissions_to_test - .iter() - .all(|p| result.permissions.contains(p)); - if !correct_permissions { - anyhow::bail!( - "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", - gcs_info.gcs_bucket - ) - } - } - Some(UploadInfo::Dummy()) => { - // In test mode, we skip upload checks - } - None => {} - } - } + // let Model::LLM(LLM { checkpoint, .. 
}) = start_coordinator_state.model; + // if !self.no_checkpoint { + // let upload_info = match checkpoint { + // model::Checkpoint::Hub(HubRepo { repo_id, revision }) + // | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { + // Some(UploadInfo::Hub(HubUploadInfo { + // hub_repo: (&repo_id).into(), + // hub_token: (&revision.unwrap_or_default()).into(), + // })) + // } + // model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) + // | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + // Some(UploadInfo::Gcs(GcsUploadInfo { + // gcs_bucket: (&bucket).into(), + // gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + // })) + // } + // _ => None, + // }; + // match upload_info { + // Some(UploadInfo::Hub(hub_info)) => { + // let api = hf_hub::api::tokio::ApiBuilder::new() + // .with_token(Some(hub_info.hub_token)) + // .build()?; + // let repo_api = api.repo(Repo::new( + // hub_info.hub_repo.clone(), + // hf_hub::RepoType::Model, + // )); + // if !repo_api.is_writable().await { + // anyhow::bail!( + // "Checkpoint upload repo {} is not writable with the passed API key.", + // hub_info.hub_repo + // ) + // } + // } + // Some(UploadInfo::Gcs(gcs_info)) => { + // let config = ClientConfig::default().with_auth().await?; + // let client = GcsClient::new(config); + + // // Test if we have the required permissions + // let permissions_to_test = vec![ + // "storage.objects.create".to_string(), + // "storage.objects.delete".to_string(), + // "storage.objects.get".to_string(), + // "storage.objects.list".to_string(), + // "storage.objects.update".to_string(), + // ]; + + // let result = client + // .test_iam_permissions(&TestIamPermissionsRequest { + // resource: format!("projects/_/buckets/{}", gcs_info.gcs_bucket), + // permissions: permissions_to_test.clone(), + // }) + // .await?; + + // let correct_permissions = permissions_to_test + // .iter() + // .all(|p| result.permissions.contains(p)); + // if !correct_permissions { + // anyhow::bail!( + // "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", + // gcs_info.gcs_bucket + // ) + // } + // } + // Some(UploadInfo::Dummy()) => { + // // In test mode, we skip upload checks + // } + // None => {} + // } + // } // if we're already in "WaitingForMembers" we won't get an update saying that // (subscription is on change), so check if it's in that state right at boot From 6983523a9ccdd42a667af320f16c99578f98ad68 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 19 Jan 2026 10:53:32 -0300 Subject: [PATCH 58/72] Check bucket and repo permissions before joining run --- .../centralized/testing/src/test_utils.rs | 1 - .../decentralized/solana-client/src/app.rs | 142 +++++++++--------- config/solana-test/light-config.toml | 2 +- config/solana-test/nano-config.toml | 2 +- docker/test/client_test_entrypoint.sh | 2 - scripts/train-solana-test.sh | 2 - shared/client/src/cli.rs | 14 -- shared/coordinator/src/model.rs | 2 +- 8 files changed, 72 insertions(+), 95 deletions(-) diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index 25390b224..ddaf9fab2 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -142,7 +142,6 @@ pub fn dummy_client_app_params_with_training_delay( "--max-concurrent-parameter-requests", "10", "--hub-max-concurrent-downloads", "1", "--dummy-training-delay-secs", 
training_delay_secs.to_string().as_str(), - "--test-mode", ]) .train_args, } diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 7a074904b..6f9639227 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -61,7 +61,6 @@ pub struct App { allowlist: allowlist::AllowDynamic, p2p: NC, state_options: RunInitConfig, - no_checkpoint: bool, } pub struct AppParams { @@ -162,7 +161,6 @@ pub async fn build_app( metrics, p2p, state_options, - no_checkpoint: p.test_mode, }; Ok(app) } @@ -238,77 +236,75 @@ impl App { let mut ever_joined_run = false; // sanity checks - // let Model::LLM(LLM { checkpoint, .. }) = start_coordinator_state.model; - // if !self.no_checkpoint { - // let upload_info = match checkpoint { - // model::Checkpoint::Hub(HubRepo { repo_id, revision }) - // | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { - // Some(UploadInfo::Hub(HubUploadInfo { - // hub_repo: (&repo_id).into(), - // hub_token: (&revision.unwrap_or_default()).into(), - // })) - // } - // model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) - // | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { - // Some(UploadInfo::Gcs(GcsUploadInfo { - // gcs_bucket: (&bucket).into(), - // gcs_prefix: Some((&prefix.unwrap_or_default()).into()), - // })) - // } - // _ => None, - // }; - // match upload_info { - // Some(UploadInfo::Hub(hub_info)) => { - // let api = hf_hub::api::tokio::ApiBuilder::new() - // .with_token(Some(hub_info.hub_token)) - // .build()?; - // let repo_api = api.repo(Repo::new( - // hub_info.hub_repo.clone(), - // hf_hub::RepoType::Model, - // )); - // if !repo_api.is_writable().await { - // anyhow::bail!( - // "Checkpoint upload repo {} is not writable with the passed API key.", - // hub_info.hub_repo - // ) - // } - // } - // Some(UploadInfo::Gcs(gcs_info)) => { - // let config = ClientConfig::default().with_auth().await?; - // let client = GcsClient::new(config); - - // // Test if we have the required permissions - // let permissions_to_test = vec![ - // "storage.objects.create".to_string(), - // "storage.objects.delete".to_string(), - // "storage.objects.get".to_string(), - // "storage.objects.list".to_string(), - // "storage.objects.update".to_string(), - // ]; - - // let result = client - // .test_iam_permissions(&TestIamPermissionsRequest { - // resource: format!("projects/_/buckets/{}", gcs_info.gcs_bucket), - // permissions: permissions_to_test.clone(), - // }) - // .await?; - - // let correct_permissions = permissions_to_test - // .iter() - // .all(|p| result.permissions.contains(p)); - // if !correct_permissions { - // anyhow::bail!( - // "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", - // gcs_info.gcs_bucket - // ) - // } - // } - // Some(UploadInfo::Dummy()) => { - // // In test mode, we skip upload checks - // } - // None => {} - // } - // } + let Model::LLM(LLM { checkpoint, .. 
}) = start_coordinator_state.model; + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { repo_id, revision }) + | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: (&revision.unwrap_or_default()).into(), + })) + } + model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: Some((&prefix.unwrap_or_default()).into()), + })) + } + _ => None, + }; + match upload_info { + Some(UploadInfo::Hub(hub_info)) => { + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_token(Some(hub_info.hub_token)) + .build()?; + let repo_api = api.repo(Repo::new( + hub_info.hub_repo.clone(), + hf_hub::RepoType::Model, + )); + if !repo_api.is_writable().await { + anyhow::bail!( + "Checkpoint upload repo {} is not writable with the passed API key.", + hub_info.hub_repo + ) + } + } + Some(UploadInfo::Gcs(gcs_info)) => { + let config = ClientConfig::default().with_auth().await?; + let client = GcsClient::new(config); + + // Test if we have the required permissions + let permissions_to_test = vec![ + "storage.objects.create".to_string(), + "storage.objects.delete".to_string(), + "storage.objects.get".to_string(), + "storage.objects.list".to_string(), + "storage.objects.update".to_string(), + ]; + + let result = client + .test_iam_permissions(&TestIamPermissionsRequest { + resource: format!("projects/_/buckets/{}", gcs_info.gcs_bucket), + permissions: permissions_to_test.clone(), + }) + .await?; + + let correct_permissions = permissions_to_test + .iter() + .all(|p| result.permissions.contains(p)); + if !correct_permissions { + anyhow::bail!( + "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", + gcs_info.gcs_bucket + ) + } + } + Some(UploadInfo::Dummy()) => { + // In test mode, we skip upload checks + } + None => {} + } // if we're already in "WaitingForMembers" we won't get an update saying that // (subscription is on change), so check if it's in that state right at boot diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index eab015342..71228622f 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.Hub] +[model.LLM.checkpoint.Dummy] repo_id = "emozilla/llama2-20m-init" [model.LLM.data_location.Http] diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index c275feea3..19f7b8acc 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 64 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.Hub] +[model.LLM.checkpoint.Dummy] repo_id = "pefontana/Nano-Llama" revision = "cf48eac4944f6e954a3d9c9c30e8c865e64e7d03" diff --git a/docker/test/client_test_entrypoint.sh b/docker/test/client_test_entrypoint.sh index 3b24ca44f..64e739841 100644 --- a/docker/test/client_test_entrypoint.sh +++ b/docker/test/client_test_entrypoint.sh @@ -18,7 +18,6 @@ if [ "${PYTHON_ENABLED}" = "true" ]; then --run-id "${RUN_ID}" \ --data-parallelism 8 \ --sidecar-port "${SIDECAR_PORT}" \ - --test-mode \ --logs "json" else echo "Starting client without Python features" @@ -27,6 
+26,5 @@ else
     --rpc "${RPC}" \
     --ws-rpc "${WS_RPC}" \
     --run-id "${RUN_ID}" \
-    --test-mode \
     --logs "json"
 fi
diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh
index d436fee40..dc6c95327 100755
--- a/scripts/train-solana-test.sh
+++ b/scripts/train-solana-test.sh
@@ -66,7 +66,6 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then
         --micro-batch-size ${BATCH_SIZE} \
         --authorizer ${AUTHORIZER} \
         --logs "console" \
-        $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \
         "${EXTRA_ARGS[@]}"
 else
     HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \
@@ -82,6 +81,5 @@ else
         --authorizer ${AUTHORIZER} \
         --oltp-metrics-url "http://localhost:4318/v1/metrics" \
         --oltp-logs-url "http://localhost:4318/v1/logs" \
-        $( [[ "$CHECKPOINT" == false ]] && echo "--test-mode" ) \
         "${EXTRA_ARGS[@]}"
 fi
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index c01be3553..7c74ed641 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -190,9 +190,6 @@ pub struct TrainArgs {
 
     #[clap(long, default_value_t = 3, env)]
     pub keep_steps: u32,
-
-    #[clap(long, default_value_t = false, env, hide = true)]
-    pub test_mode: bool,
 }
 
 impl TrainArgs {
@@ -222,18 +219,7 @@ impl TrainArgs {
     }
 
     pub fn checkpoint_config(&self) -> Result<CheckpointConfig> {
-        if self.test_mode {
-            return Ok(CheckpointConfig::dummy());
-        }
-
         let hub_token = std::env::var("HF_TOKEN").ok();
-        let google_application_credentials = std::env::var("GOOGLE_APPLICATION_CREDENTIALS").ok();
-
-        if hub_token.is_none() && google_application_credentials.is_none() {
-            return Err(anyhow!(
-                "Either HF_TOKEN or GOOGLE_APPLICATION_CREDENTIALS environment variable must be set for checkpoint uploads"
-            ));
-        }
 
         if self.keep_steps == 0 {
             bail!(
diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 538a8f395..ff7f6ea93 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -292,7 +292,7 @@ impl GcsRepo {
 #[repr(C)]
 pub enum Checkpoint {
     Ephemeral,
-    Dummy(HubRepo),
+    Dummy(HubRepo), // Used for testing
     Hub(HubRepo),
     P2P(HubRepo),
     Gcs(GcsRepo),
From de7e19d57cdec1e97cd3f868d14df3411fcc8c43 Mon Sep 17 00:00:00 2001
From: IAvecilla
Date: Mon, 19 Jan 2026 11:38:32 -0300
Subject: [PATCH 59/72] Fix centralized config

---
 config/llama2-20m-dolma-noverify-no-checkpointer/state.toml | 2 +-
 shared/coordinator/src/coordinator.rs                       | 3 ---
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml b/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml
index 71ab21789..1b4f00c8f 100644
--- a/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml
+++ b/config/llama2-20m-dolma-noverify-no-checkpointer/state.toml
@@ -24,7 +24,7 @@ max_seq_len = 2048
 cold_start_warmup_steps = 0
 [model.LLM.data_location]
 Server = "127.0.0.1:20001"
-[model.LLM.checkpoint.Hub]
+[model.LLM.checkpoint.Dummy]
 repo_id = "emozilla/llama2-20m-init"
 [model.LLM.lr_schedule.Cosine]
 base_lr = 4.0e-4
diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index f0084e0e7..bf342ba8d 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -630,9 +630,6 @@ impl Coordinator {
             return Err(CoordinatorError::InvalidCommitteeProof);
         }
 
-        // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint
-        // with the last checkpointed one. We could instead have a vector of checkpoints to have
We could instead have a vector of checkpoints to have - // more download options. let Model::LLM(llm) = &mut self.model; match (&llm.checkpoint, checkpoint_repo) { // If current is P2P, wrap the new checkpoint in P2P From 95c8a83ab40678d8efc0d13b2813bd5dcc233d97 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 19 Jan 2026 12:09:11 -0800 Subject: [PATCH 60/72] Add new NoUpload to avoid checkpointing in tests --- architectures/centralized/server/src/app.rs | 5 ++-- architectures/decentralized/justfile | 10 ++----- .../decentralized/solana-client/src/main.rs | 5 +++- .../decentralized/testing/src/docker_setup.rs | 30 +++++++++---------- config/solana-test/light-config-gcs.toml | 2 +- config/solana-test/light-config.toml | 2 +- config/solana-test/nano-config.toml | 4 +-- scripts/train-solana-test.sh | 15 ++-------- shared/client/src/state/init.rs | 10 +++++-- shared/coordinator/src/coordinator.rs | 10 ++----- shared/coordinator/src/model.rs | 29 +++++++++++++----- .../src/commands/run/update_config.rs | 1 + 12 files changed, 62 insertions(+), 61 deletions(-) diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 4df2ca662..3c2ce2e6c 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -184,7 +184,8 @@ impl App { }) => { if let LLMTrainingDataLocation::Server(url) = data_location { match checkpoint { - Checkpoint::Hub(hub_repo) => { + Checkpoint::Hub(hub_repo) + | Checkpoint::NoUploadHubRepo(hub_repo) => { let repo_id = String::from(&hub_repo.repo_id); let revision = hub_repo.revision.map(|bytes| (&bytes).into()); if revision.is_some() @@ -205,7 +206,7 @@ impl App { Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => { bail!("Can't start up a run with a P2P checkpoint.") } - Checkpoint::Gcs(gcs_repo) => { + Checkpoint::Gcs(gcs_repo) | Checkpoint::NoUploadGcs(gcs_repo) => { let bucket: String = (&gcs_repo.bucket).into(); let prefix: Option = gcs_repo.prefix.map(|p| (&p).into()); diff --git a/architectures/decentralized/justfile b/architectures/decentralized/justfile index 92b272f31..eabde7209 100644 --- a/architectures/decentralized/justfile +++ b/architectures/decentralized/justfile @@ -39,16 +39,10 @@ setup-solana-localnet-permissioned-light-test-run-treasurer run_id="test" *args= RUN_ID={{ run_id }} CONFIG_FILE=./config/solana-test/light-config.toml ./scripts/deploy-solana-test.sh --treasurer {{ args }} start-training-localnet-client run_id="test" *args='': - AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} start-training-localnet-light-client run_id="test" *args='': - AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} - -start-training-localnet-light-client-checkpoint run_id="test" *args='': - HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh --checkpoint {{ args }} - -start-training-localnet-client-checkpoint run_id="test" *args='': - HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh --checkpoint {{ args }} + HF_TOKEN={{ HF_TOKEN }} 
GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} OTLP_METRICS_URL := "http://localhost:4318/v1/metrics" OTLP_LOGS_URL := "http://localhost:4318/v1/logs" diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs index 4d8b3bb5d..8a4d606e4 100644 --- a/architectures/decentralized/solana-client/src/main.rs +++ b/architectures/decentralized/solana-client/src/main.rs @@ -289,6 +289,7 @@ async fn async_main() -> Result<()> { } Checkpoint::Dummy(hub_repo) | Checkpoint::Hub(hub_repo) + | Checkpoint::NoUploadHubRepo(hub_repo) | Checkpoint::P2P(hub_repo) => { let repo_id = hub_repo.repo_id.to_string(); let revision = hub_repo.revision.map(|s| s.to_string()); @@ -310,7 +311,9 @@ async fn async_main() -> Result<()> { ) .await?; } - Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => { + Checkpoint::Gcs(gcs_repo) + | Checkpoint::NoUploadGcs(gcs_repo) + | Checkpoint::P2PGcs(gcs_repo) => { let bucket = gcs_repo.bucket.to_string(); let prefix: Option = gcs_repo.prefix.map(|p| p.to_string()); println!( diff --git a/architectures/decentralized/testing/src/docker_setup.rs b/architectures/decentralized/testing/src/docker_setup.rs index 4f5a80dd9..c978aadc5 100644 --- a/architectures/decentralized/testing/src/docker_setup.rs +++ b/architectures/decentralized/testing/src/docker_setup.rs @@ -39,21 +39,21 @@ pub const VALIDATOR_CONTAINER_PREFIX: &str = "test-psyche-solana-test-validator" pub const NGINX_PROXY_PREFIX: &str = "nginx-proxy"; pub struct DockerTestCleanup; -impl Drop for DockerTestCleanup { - fn drop(&mut self) { - println!("\nStopping containers..."); - let output = Command::new("just") - .args(["stop_test_infra"]) - .stdout(Stdio::inherit()) - .stderr(Stdio::inherit()) - .output() - .expect("Failed stop docker compose instances"); - - if !output.status.success() { - panic!("Error: {}", String::from_utf8_lossy(&output.stderr)); - } - } -} +// impl Drop for DockerTestCleanup { +// fn drop(&mut self) { +// println!("\nStopping containers..."); +// let output = Command::new("just") +// .args(["stop_test_infra"]) +// .stdout(Stdio::inherit()) +// .stderr(Stdio::inherit()) +// .output() +// .expect("Failed stop docker compose instances"); + +// if !output.status.success() { +// panic!("Error: {}", String::from_utf8_lossy(&output.stderr)); +// } +// } +// } /// FIXME: The config path must be relative to the compose file for now. 
pub async fn e2e_testing_setup( diff --git a/config/solana-test/light-config-gcs.toml b/config/solana-test/light-config-gcs.toml index d96fd468f..65201ebe5 100644 --- a/config/solana-test/light-config-gcs.toml +++ b/config/solana-test/light-config-gcs.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.Gcs] +[model.LLM.checkpoint.NoUploadGcs] bucket = "llama220minit" [model.LLM.data_location.Http] diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index 71228622f..d6aff5df7 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.Dummy] +[model.LLM.checkpoint.NoUploadHubRepo] repo_id = "emozilla/llama2-20m-init" [model.LLM.data_location.Http] diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index 19f7b8acc..2a0b02802 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -1,7 +1,7 @@ [config] warmup_time = 50 -cooldown_time = 30 epoch_time = 60 +cooldown_time = 30 max_round_train_time = 15 round_witness_time = 1 min_clients = 1 @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 64 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.Dummy] +[model.LLM.checkpoint.NoUploadHubRepo] repo_id = "pefontana/Nano-Llama" revision = "cf48eac4944f6e954a3d9c9c30e8c865e64e7d03" diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index dc6c95327..9f06cd1d1 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -2,17 +2,6 @@ set -eo pipefail -CHECKPOINT=false -EXTRA_ARGS=() -# Parse command line arguments -for arg in "$@"; do - if [[ "$arg" == "--checkpoint" ]]; then - CHECKPOINT=true - else - EXTRA_ARGS+=("$arg") - fi -done - # use the agenix provided wallet if you have it if [[ -n "${devnet__keypair__wallet_PATH}" && -f "${devnet__keypair__wallet_PATH}" ]]; then WALLET_FILE="${devnet__keypair__wallet_PATH}" @@ -66,7 +55,7 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then --micro-batch-size ${BATCH_SIZE} \ --authorizer ${AUTHORIZER} \ --logs "console" \ - "${EXTRA_ARGS[@]}" + "$@" else HF_TOKEN=${HF_TOKEN} GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ train \ @@ -81,5 +70,5 @@ else --authorizer ${AUTHORIZER} \ --oltp-metrics-url "http://localhost:4318/v1/metrics" \ --oltp-logs-url "http://localhost:4318/v1/logs" \ - "${EXTRA_ARGS[@]}" + "$@" fi diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index d2c3a042e..9f79a2dcf 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -316,11 +316,14 @@ impl RunInitConfigAndIO { + | model::Checkpoint::Gcs(_) + | model::Checkpoint::NoUploadHubRepo(_) + | model::Checkpoint::NoUploadGcs(_) => { let checkpoint = llm.checkpoint; tokio::spawn(async move { let (source, tokenizer, checkpoint_extra_files) = match checkpoint { - model::Checkpoint::Hub(hub_repo) => { + model::Checkpoint::Hub(hub_repo) + | model::Checkpoint::NoUploadHubRepo(hub_repo) => { let repo_id: String = (&hub_repo.repo_id).into(); let potential_local_path = PathBuf::from(repo_id.clone()); let revision = hub_repo.revision.map(|bytes| (&bytes).into()); @@ -432,7 +435,8 @@ impl RunInitConfigAndIO { + model::Checkpoint::Gcs(gcs_repo) + | model::Checkpoint::NoUploadGcs(gcs_repo) => { let bucket: String = 
(&gcs_repo.bucket).into(); let prefix: Option = gcs_repo.prefix.map(|p| (&p).into()); diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index bf342ba8d..07cd3c526 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -633,10 +633,10 @@ impl Coordinator { let Model::LLM(llm) = &mut self.model; match (&llm.checkpoint, checkpoint_repo) { // If current is P2P, wrap the new checkpoint in P2P - (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => { + (Checkpoint::P2P(_) | Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { llm.checkpoint = Checkpoint::P2P(hub_repo); } - (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { + (Checkpoint::P2P(_) | Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); } // If current is Hub, only accept Hub updates @@ -647,12 +647,6 @@ impl Coordinator { (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => { llm.checkpoint = Checkpoint::Gcs(gcs_repo); } - (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { - llm.checkpoint = Checkpoint::P2P(hub_repo); - } - (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => { - llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); - } // Ignore other combinations _ => {} } diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index ff7f6ea93..426f34e08 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -294,9 +294,11 @@ pub enum Checkpoint { Ephemeral, Dummy(HubRepo), // Used for testing Hub(HubRepo), - P2P(HubRepo), Gcs(GcsRepo), + P2P(HubRepo), P2PGcs(GcsRepo), + NoUploadHubRepo(HubRepo), // Load from Hub, save locally, skip upload + NoUploadGcs(GcsRepo), // Load from GCS, save locally, skip upload } impl std::fmt::Display for Checkpoint { @@ -305,12 +307,23 @@ impl std::fmt::Display for Checkpoint { Checkpoint::Dummy(_hub_repo) => write!(f, "Dummy"), Checkpoint::Ephemeral => write!(f, "Ephemeral"), Checkpoint::Hub(hub_repo) => write!(f, "{}", &hub_repo.repo_id), + Checkpoint::Gcs(gcs_repo) => match &gcs_repo.prefix { + Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix), + None => write!(f, "gs://{}", &gcs_repo.bucket), + }, Checkpoint::P2P(hub_repo) => { write!(f, "P2P - Hub repo: {}", &hub_repo.repo_id) } - Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix { - Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix), - None => write!(f, "gs://{}", &gcs_repo.bucket), + Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix { + Some(prefix) => write!(f, "P2P - gs://{}/{}", &gcs_repo.bucket, prefix), + None => write!(f, "P2P - gs://{}", &gcs_repo.bucket), + }, + Checkpoint::NoUploadHubRepo(hub_repo) => { + write!(f, "NoUpload - Hub repo: {}", &hub_repo.repo_id) + } + Checkpoint::NoUploadGcs(gcs_repo) => match &gcs_repo.prefix { + Some(prefix) => write!(f, "NoUpload - gs://{}/{}", &gcs_repo.bucket, prefix), + None => write!(f, "NoUpload - gs://{}", &gcs_repo.bucket), }, } } @@ -350,9 +363,11 @@ impl Model { let bad_checkpoint = match llm.checkpoint { Checkpoint::Dummy(_hub_repo) => false, Checkpoint::Ephemeral => true, - Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(), - Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(), - Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => { + Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => true, // P2P is internal state, not configurable + Checkpoint::Hub(hub_repo) | Checkpoint::NoUploadHubRepo(hub_repo) => { + 
hub_repo.repo_id.is_empty() + } + Checkpoint::Gcs(gcs_repo) | Checkpoint::NoUploadGcs(gcs_repo) => { gcs_repo.bucket.is_empty() } }; diff --git a/tools/rust-tools/run-manager/src/commands/run/update_config.rs b/tools/rust-tools/run-manager/src/commands/run/update_config.rs index 641577307..80abc1fac 100644 --- a/tools/rust-tools/run-manager/src/commands/run/update_config.rs +++ b/tools/rust-tools/run-manager/src/commands/run/update_config.rs @@ -95,6 +95,7 @@ impl Command for CommandUpdateConfig { Checkpoint::P2P(hub_repo) | Checkpoint::Dummy(hub_repo) => { llm.checkpoint = Checkpoint::Hub(hub_repo) } + Checkpoint::P2PGcs(gcs_repo) => llm.checkpoint = Checkpoint::Gcs(gcs_repo), _ => {} } Some(Model::LLM(llm)) From 49669ffc49a9915728c4f810eeb1997ce6710db7 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 19 Jan 2026 12:43:59 -0800 Subject: [PATCH 61/72] Fix train solana test script --- architectures/decentralized/solana-client/src/app.rs | 2 +- scripts/train-solana-test.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index dd02659e3..98e55d500 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -294,7 +294,7 @@ impl App { .all(|p| response.permissions.contains(&p.to_string())); if !correct_permissions { anyhow::bail!( - "GCS bucket {} does not have the required permissions for checkpoint upload make sure to set GOOGLE_APPLICATION_CREDENTIALS environment variable correctly.", + "GCS bucket {} does not have the required permissions for checkpoint upload. Make sure the GOOGLE_APPLICATION_CREDENTIALS environment variable is set correctly and that the credentials grant the required permissions on the bucket.", gcs_info.gcs_bucket ) } diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index 9f06cd1d1..c600e501d 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -20,8 +20,8 @@ elif [[ -z "${WALLET_FILE:-}" ]]; then trap "echo 'Cleaning up ephemeral wallet file...'; rm -f '${WALLET_FILE}'" EXIT fi -RPC=${RPC:-"http://7da1cfaf-50:8899"} -WS_RPC=${WS_RPC:-"ws://7da1cfaf-50:8900"} +RPC=${RPC:-"http://localhost:8899"} +WS_RPC=${WS_RPC:-"ws://localhost:8900"} RUN_ID=${RUN_ID:-"test"} AUTHORIZER=${AUTHORIZER:-"11111111111111111111111111111111"} From e048b255e934b21835dc670dc1c17c94f5b10528 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Mon, 19 Jan 2026 12:01:39 -0800 Subject: [PATCH 62/72] CooldownStep.checkpoint_complete --- shared/client/src/state/cooldown.rs | 14 +++++++++++++- shared/client/src/state/steps.rs | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ee4c9d9cb..a7c3865d3 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -16,7 +16,10 @@ use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap}, path::PathBuf, - sync::Arc, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, }; use tch::Tensor; use thiserror::Error; @@ -151,10 +154,12 @@ impl CooldownStepMetadata { let is_checkpointer = checkpointer_selection .is_checkpointer(client_index, state.epoch_state.clients.len() as u64); let cancellation_token = tokio_util::sync::CancellationToken::new(); + let checkpoint_completed = Arc::new(AtomicBool::new(false)); let checkpointing_and_evals: JoinHandle> = tokio::task::spawn({ let cancellation_token =
cancellation_token.clone(); + let checkpoint_completed = checkpoint_completed.clone(); async move { info!("Extracting full model..."); let (variables, trainer) = @@ -253,6 +258,7 @@ impl CooldownStepMetadata { ) .await; + checkpoint_completed.store(true, Ordering::SeqCst); Ok(evals) } .instrument(info_span!("checkpointing")) @@ -261,6 +267,7 @@ impl CooldownStepMetadata { Ok(CooldownStep { checkpointing_and_evals, cancellation_token, + checkpoint_completed, }) } } @@ -316,6 +323,7 @@ async fn upload_checkpoint( pub struct CooldownStep { checkpointing_and_evals: JoinHandle>, cancellation_token: tokio_util::sync::CancellationToken, + checkpoint_completed: Arc, } impl CooldownStep { @@ -335,4 +343,8 @@ impl CooldownStep { pub fn is_finished(&self) -> bool { self.checkpointing_and_evals.is_finished() } + + pub fn checkpoint_complete(&self) -> bool { + self.checkpoint_completed.load(Ordering::SeqCst) + } } diff --git a/shared/client/src/state/steps.rs b/shared/client/src/state/steps.rs index c425a1756..154a73e8b 100644 --- a/shared/client/src/state/steps.rs +++ b/shared/client/src/state/steps.rs @@ -343,7 +343,7 @@ impl StepStateMachine Date: Tue, 20 Jan 2026 06:14:00 -0800 Subject: [PATCH 63/72] Fix crash after credential errors --- shared/client/src/state/cooldown.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index a7c3865d3..ccef633ba 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -244,8 +244,13 @@ impl CooldownStepMetadata { epoch, run_id: run_id.clone(), }; - upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) - .await?; + let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) + .await; + if let Err(err) = result { + error!("Error uploading checkpoint: {}", err); + } else { + checkpoint_completed.store(true, Ordering::SeqCst); + } } cleanup_dirs( @@ -258,7 +263,6 @@ impl CooldownStepMetadata { ) .await; - checkpoint_completed.store(true, Ordering::SeqCst); Ok(evals) } .instrument(info_span!("checkpointing")) From 5ff31848f6888976394c94c637bb7698ad0bccbd Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 08:11:14 -0800 Subject: [PATCH 64/72] Update only safetensors for checkpointing --- architectures/decentralized/solana-client/src/app.rs | 2 -- shared/client/src/cli.rs | 10 +++++++++- shared/data-provider/src/gcs.rs | 5 ++++- shared/data-provider/src/hub.rs | 7 +++++-- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 98e55d500..e4e384286 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -271,8 +271,6 @@ impl App { let client = StorageControl::builder().build().await?; let permissions_to_test = vec![ - "storage.buckets.get", - "storage.buckets.getIamPolicy", "storage.objects.list", "storage.objects.get", "storage.objects.create", diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 7c74ed641..5d61591b0 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -7,6 +7,7 @@ use psyche_modeling::Devices; use psyche_network::{DiscoveryMode, RelayKind, SecretKey}; use psyche_tui::LogOutput; use std::{path::PathBuf, time::Duration}; +use tracing::info; pub fn
read_identity_secret_key( identity_secret_key_path: Option<&PathBuf>, @@ -139,7 +140,7 @@ pub struct TrainArgs { pub prompt_task: bool, /// If provided, every model parameters update will be saved in this directory after each epoch. - #[clap(long, env, default_value = "~/.cache/psyche/checkpoints")] + #[clap(long, env, default_value_os_t = default_checkpoint_dir())] pub checkpoint_dir: PathBuf, #[clap(long, env, default_value_t = 3)] @@ -263,6 +264,13 @@ impl TrainArgs { } } +fn default_checkpoint_dir() -> PathBuf { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + let final_dir = PathBuf::from(home).join(".cache/psyche/local_checkpoints"); + info!("Default checkpoint directory set to {:?}", final_dir); + final_dir +} + pub fn prepare_environment() { psyche_modeling::set_suggested_env_vars(); diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index 4345f64a5..01e9fec08 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -347,7 +347,10 @@ pub async fn upload_to_gcs( files: Vec::new(), }; - for path in local { + for path in local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + { if cancellation_token.is_cancelled() { info!("Upload cancelled before uploading {}", path.display()); return Ok(()); diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 8d67a552c..a572a2ed7 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -215,7 +215,10 @@ pub async fn upload_to_hub( let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - for path in local { + for path in local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + { if cancellation_token.is_cancelled() { info!(repo = hub_repo, "Upload to HuggingFace cancelled"); return Ok(()); @@ -229,7 +232,7 @@ pub async fn upload_to_hub( .to_string(); let upload_future = api_repo.upload_files( - vec![(path.into(), file_name.clone())], + vec![(path.clone().into(), file_name.clone())], Some(format!("step {step}")), None, false, From a8444b2a794245e7f1374e9b0201f5f41b703892 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 14:48:45 -0300 Subject: [PATCH 65/72] Refactor hub repo upload --- shared/data-provider/src/hub.rs | 116 ++++++++++++++-------------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index a572a2ed7..1cc15f2ab 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,4 +1,5 @@ use crate::errors::UploadError; +use futures::future::try_join_all; use hf_hub::{ Cache, Repo, RepoType, api::{Siblings, tokio::ApiError}, @@ -51,26 +52,22 @@ async fn download_repo_async( .collect::>(); let mut ret: Vec = Vec::new(); for chunk in siblings.chunks(max_concurrent_downloads.unwrap_or(siblings.len())) { - let futures = chunk - .iter() - .map(|x| async { - let start_time = Instant::now(); - tracing::debug!(filename = x.rfilename, "Starting file download from hub"); - let res = api.get(&x.rfilename).await; - if res.is_ok() { - let duration_secs = (Instant::now() - start_time).as_secs_f32(); - tracing::info!( - filename = x.rfilename, - duration_secs = duration_secs, - "Finished downloading file from hub" - ); - } - res - }) - .collect::>(); - for future in futures { - ret.push(future.await?); - } + let futures = chunk.iter().map(|x| async { + let start_time = Instant::now(); + tracing::debug!(filename = x.rfilename,
"Starting file download from hub"); + let res = api.get(&x.rfilename).await; + if res.is_ok() { + let duration_secs = (Instant::now() - start_time).as_secs_f32(); + tracing::info!( + filename = x.rfilename, + duration_secs = duration_secs, + "Finished downloading file from hub" + ); + } + res + }); + let chunk_results = try_join_all(futures).await?; + ret.extend(chunk_results); } Ok(ret) } @@ -207,7 +204,35 @@ pub async fn upload_to_hub( return Ok(()); } - info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); + // Collect all safetensors files to upload in a single commit + let files_to_upload: Vec<_> = local + .iter() + .filter(|p| p.extension() == Some("safetensors".as_ref())) + .map(|path| -> Result<_, UploadError> { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? + .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))? + .to_string(); + Ok((path.clone().into(), file_name)) + }) + .collect::, _>>()?; + + if files_to_upload.is_empty() { + info!(repo = hub_repo, "No safetensors files to upload"); + return Ok(()); + } + + let file_names: Vec<_> = files_to_upload + .iter() + .map(|(_, name)| name.clone()) + .collect(); + info!( + repo = hub_repo, + file_count = files_to_upload.len(), + "Uploading checkpoint to HuggingFace" + ); let api = hf_hub::api::tokio::ApiBuilder::new() .with_token(Some(hub_token)) @@ -215,48 +240,25 @@ pub async fn upload_to_hub( let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - for path in local - .iter() - .filter(|p| p.extension() == Some("safetensors".as_ref())) - { - if cancellation_token.is_cancelled() { + let upload_future = + api_repo.upload_files(files_to_upload, Some(format!("step {step}")), None, false); + + tokio::select! { + biased; + + _ = cancellation_token.cancelled() => { info!(repo = hub_repo, "Upload to HuggingFace cancelled"); return Ok(()); } - - let file_name = path - .file_name() - .ok_or_else(|| UploadError::NotAFile(path.clone()))? - .to_str() - .ok_or_else(|| UploadError::InvalidFilename(path.clone()))? - .to_string(); - - let upload_future = api_repo.upload_files( - vec![(path.clone().into(), file_name.clone())], - Some(format!("step {step}")), - None, - false, - ); - - tokio::select! 
{ - biased; - - _ = cancellation_token.cancelled() => { - info!(repo = hub_repo, file = file_name, "Upload cancelled"); - return Ok(()); - } - result = upload_future => { - result.map_err(|e| { - error!(repo = hub_repo, error = ?e, "Failed to upload file"); - e - })?; - } + result = upload_future => { + result.map_err(|e| { + error!(repo = hub_repo, error = ?e, "Failed to upload files"); + e + })?; } - - info!(repo = hub_repo, file = file_name, "Uploaded file"); } - info!(repo = hub_repo, "Upload to HuggingFace complete"); + info!(repo = hub_repo, files = ?file_names, "Upload to HuggingFace complete"); Ok(()) } From b6802b57dd45844e3916b6d509657657af9444b8 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 09:51:34 -0800 Subject: [PATCH 66/72] Uncomment docker cleanup --- .../decentralized/testing/src/docker_setup.rs | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/architectures/decentralized/testing/src/docker_setup.rs b/architectures/decentralized/testing/src/docker_setup.rs index c978aadc5..4f5a80dd9 100644 --- a/architectures/decentralized/testing/src/docker_setup.rs +++ b/architectures/decentralized/testing/src/docker_setup.rs @@ -39,21 +39,21 @@ pub const VALIDATOR_CONTAINER_PREFIX: &str = "test-psyche-solana-test-validator" pub const NGINX_PROXY_PREFIX: &str = "nginx-proxy"; pub struct DockerTestCleanup; -// impl Drop for DockerTestCleanup { -// fn drop(&mut self) { -// println!("\nStopping containers..."); -// let output = Command::new("just") -// .args(["stop_test_infra"]) -// .stdout(Stdio::inherit()) -// .stderr(Stdio::inherit()) -// .output() -// .expect("Failed stop docker compose instances"); - -// if !output.status.success() { -// panic!("Error: {}", String::from_utf8_lossy(&output.stderr)); -// } -// } -// } +impl Drop for DockerTestCleanup { + fn drop(&mut self) { + println!("\nStopping containers..."); + let output = Command::new("just") + .args(["stop_test_infra"]) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .output() + .expect("Failed stop docker compose instances"); + + if !output.status.success() { + panic!("Error: {}", String::from_utf8_lossy(&output.stderr)); + } + } +} /// FIXME: The config path must be relative to the compose file for now. 
pub async fn e2e_testing_setup( From da2fab662265fc9877ab1ace9236ec218c0bdd8e Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 16:52:35 -0300 Subject: [PATCH 67/72] Remove NoCheckpoint variant and use a client flag instead --- architectures/centralized/server/src/app.rs | 5 +- .../centralized/testing/src/test_utils.rs | 1 + architectures/decentralized/justfile | 10 +++- .../decentralized/solana-client/src/main.rs | 5 +- config/solana-test/light-config-gcs.toml | 2 +- config/solana-test/light-config.toml | 2 +- config/solana-test/nano-config.toml | 2 +- scripts/train-solana-test.sh | 2 + shared/client/src/cli.rs | 5 ++ shared/client/src/state/cooldown.rs | 54 ++++++++++--------- shared/client/src/state/init.rs | 10 ++-- shared/client/src/state/types.rs | 3 ++ shared/coordinator/src/model.rs | 17 +----- 13 files changed, 60 insertions(+), 58 deletions(-) diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 3c2ce2e6c..4df2ca662 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -184,8 +184,7 @@ impl App { }) => { if let LLMTrainingDataLocation::Server(url) = data_location { match checkpoint { - Checkpoint::Hub(hub_repo) - | Checkpoint::NoUploadHubRepo(hub_repo) => { + Checkpoint::Hub(hub_repo) => { let repo_id = String::from(&hub_repo.repo_id); let revision = hub_repo.revision.map(|bytes| (&bytes).into()); if revision.is_some() @@ -206,7 +205,7 @@ impl App { Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => { bail!("Can't start up a run with a P2P checkpoint.") } - Checkpoint::Gcs(gcs_repo) | Checkpoint::NoUploadGcs(gcs_repo) => { + Checkpoint::Gcs(gcs_repo) => { let bucket: String = (&gcs_repo.bucket).into(); let prefix: Option = gcs_repo.prefix.map(|p| (&p).into()); diff --git a/architectures/centralized/testing/src/test_utils.rs b/architectures/centralized/testing/src/test_utils.rs index ddaf9fab2..ae835da36 100644 --- a/architectures/centralized/testing/src/test_utils.rs +++ b/architectures/centralized/testing/src/test_utils.rs @@ -142,6 +142,7 @@ pub fn dummy_client_app_params_with_training_delay( "--max-concurrent-parameter-requests", "10", "--hub-max-concurrent-downloads", "1", "--dummy-training-delay-secs", training_delay_secs.to_string().as_str(), + "--skip-checkpoint-upload", ]) .train_args, } diff --git a/architectures/decentralized/justfile b/architectures/decentralized/justfile index eabde7209..ad0b89a14 100644 --- a/architectures/decentralized/justfile +++ b/architectures/decentralized/justfile @@ -39,10 +39,16 @@ setup-solana-localnet-permissioned-light-test-run-treasurer run_id="test" *args= RUN_ID={{ run_id }} CONFIG_FILE=./config/solana-test/light-config.toml ./scripts/deploy-solana-test.sh --treasurer {{ args }} start-training-localnet-client run_id="test" *args='': - HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="false" RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} start-training-localnet-light-client run_id="test" *args='': - HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + HF_TOKEN={{ HF_TOKEN }} 
GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="false" RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + +start-training-localnet-light-client-checkpoint run_id="test" *args='': + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="true" RUN_ID={{ run_id }} BATCH_SIZE=1 DP=1 ./scripts/train-solana-test.sh {{ args }} + +start-training-localnet-client-checkpoint run_id="test" *args='': + HF_TOKEN={{ HF_TOKEN }} GOOGLE_APPLICATION_CREDENTIALS={{ GOOGLE_APPLICATION_CREDENTIALS }} AUTHORIZER={{ AUTHORIZER }} CHECKPOINT="true" RUN_ID={{ run_id }} ./scripts/train-solana-test.sh {{ args }} OTLP_METRICS_URL := "http://localhost:4318/v1/metrics" OTLP_LOGS_URL := "http://localhost:4318/v1/logs" diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs index 8a4d606e4..4d8b3bb5d 100644 --- a/architectures/decentralized/solana-client/src/main.rs +++ b/architectures/decentralized/solana-client/src/main.rs @@ -289,7 +289,6 @@ async fn async_main() -> Result<()> { } Checkpoint::Dummy(hub_repo) | Checkpoint::Hub(hub_repo) - | Checkpoint::NoUploadHubRepo(hub_repo) | Checkpoint::P2P(hub_repo) => { let repo_id = hub_repo.repo_id.to_string(); let revision = hub_repo.revision.map(|s| s.to_string()); @@ -311,9 +310,7 @@ async fn async_main() -> Result<()> { ) .await?; } - Checkpoint::Gcs(gcs_repo) - | Checkpoint::NoUploadGcs(gcs_repo) - | Checkpoint::P2PGcs(gcs_repo) => { + Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => { let bucket = gcs_repo.bucket.to_string(); let prefix: Option = gcs_repo.prefix.map(|p| p.to_string()); println!( diff --git a/config/solana-test/light-config-gcs.toml b/config/solana-test/light-config-gcs.toml index 65201ebe5..d96fd468f 100644 --- a/config/solana-test/light-config-gcs.toml +++ b/config/solana-test/light-config-gcs.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.NoUploadGcs] +[model.LLM.checkpoint.Gcs] bucket = "llama220minit" [model.LLM.data_location.Http] diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index d6aff5df7..eab015342 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.NoUploadHubRepo] +[model.LLM.checkpoint.Hub] repo_id = "emozilla/llama2-20m-init" [model.LLM.data_location.Http] diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index 2a0b02802..0daeab279 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -20,7 +20,7 @@ data_type = "Pretraining" max_seq_len = 64 cold_start_warmup_steps = 0 -[model.LLM.checkpoint.NoUploadHubRepo] +[model.LLM.checkpoint.Hub] repo_id = "pefontana/Nano-Llama" revision = "cf48eac4944f6e954a3d9c9c30e8c865e64e7d03" diff --git a/scripts/train-solana-test.sh b/scripts/train-solana-test.sh index c600e501d..41c1654be 100755 --- a/scripts/train-solana-test.sh +++ b/scripts/train-solana-test.sh @@ -55,6 +55,7 @@ if [[ "$OTLP_METRICS_URL" == "" ]]; then --micro-batch-size ${BATCH_SIZE} \ --authorizer ${AUTHORIZER} \ --logs "console" \ + $( [[ "$CHECKPOINT" == "false" ]] && echo "--skip-checkpoint-upload" ) \ "$@" else HF_TOKEN=${HF_TOKEN}
GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS} cargo run --release --bin psyche-solana-client -- \ @@ -70,5 +71,6 @@ else --authorizer ${AUTHORIZER} \ --oltp-metrics-url "http://localhost:4318/v1/metrics" \ --oltp-logs-url "http://localhost:4318/v1/logs" \ + $( [[ "$CHECKPOINT" == "false" ]] && echo "--skip-checkpoint-upload" ) \ "$@" fi diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 5d61591b0..69e8aa271 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -191,6 +191,10 @@ pub struct TrainArgs { #[clap(long, default_value_t = 3, env)] pub keep_steps: u32, + + /// Skip uploading checkpoints to Hub/GCS (for testing). Checkpoints are still saved locally. + #[clap(long, default_value_t = false, env, hide = true)] + pub skip_checkpoint_upload: bool, } impl TrainArgs { @@ -234,6 +238,7 @@ impl TrainArgs { delete_old_steps: self.delete_old_steps, keep_steps: self.keep_steps, hub_token, + skip_upload: self.skip_checkpoint_upload, }) } diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index ccef633ba..cb736d352 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -204,35 +204,41 @@ impl CooldownStepMetadata { delete_old_steps, keep_steps, hub_token, + skip_upload, } = checkpoint_info; - let upload_info = match checkpoint { - model::Checkpoint::Hub(HubRepo { - repo_id, - revision: _, - }) - | model::Checkpoint::P2P(HubRepo { - repo_id, - revision: _, - }) => { - if let Some(token) = hub_token { - Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: (&repo_id).into(), - hub_token: token, + let upload_info = if skip_upload { + info!("Skipping checkpoint upload (skip_upload flag is set)"); + None + } else { + match checkpoint { + model::Checkpoint::Hub(HubRepo { + repo_id, + revision: _, + }) + | model::Checkpoint::P2P(HubRepo { + repo_id, + revision: _, + }) => { + if let Some(token) = hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: token, + })) + } else { + warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); + None + } + } + model::Checkpoint::Gcs(model::GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: prefix.as_ref().map(|p| p.into()), })) - } else { - warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); - None - } + } + _ => None, } - model::Checkpoint::Gcs(model::GcsRepo { bucket, prefix }) - | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { - Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: (&bucket).into(), - gcs_prefix: prefix.as_ref().map(|p| p.into()), - })) - } - _ => None, }; let path = checkpoint_dir.join(format!("{run_id}-step{step}")); diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 9f79a2dcf..d2c3a042e 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -316,14 +316,11 @@ impl RunInitConfigAndIO { + | model::Checkpoint::Gcs(_) => { let checkpoint = llm.checkpoint; tokio::spawn(async move { let (source, tokenizer, checkpoint_extra_files) = match checkpoint { - model::Checkpoint::Hub(hub_repo) - | model::Checkpoint::NoUploadHubRepo(hub_repo) => { + model::Checkpoint::Hub(hub_repo) => { let repo_id: String = (&hub_repo.repo_id).into(); let potential_local_path = PathBuf::from(repo_id.clone()); let revision = hub_repo.revision.map(|bytes| 
(&bytes).into()); @@ -435,8 +432,7 @@ impl RunInitConfigAndIO { + model::Checkpoint::Gcs(gcs_repo) => { let bucket: String = (&gcs_repo.bucket).into(); let prefix: Option = gcs_repo.prefix.map(|p| (&p).into()); diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 3cdd98c3e..48aee7495 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -22,6 +22,8 @@ pub struct CheckpointConfig { pub delete_old_steps: bool, pub keep_steps: u32, pub hub_token: Option, + /// Skip uploading checkpoints (for testing). Checkpoints are still saved locally. + pub skip_upload: bool, } impl CheckpointConfig { @@ -31,6 +33,7 @@ impl CheckpointConfig { delete_old_steps: false, keep_steps: 1, hub_token: None, + skip_upload: false, } } } diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 426f34e08..7617f2508 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -297,8 +297,6 @@ pub enum Checkpoint { Gcs(GcsRepo), P2P(HubRepo), P2PGcs(GcsRepo), - NoUploadHubRepo(HubRepo), // Load from Hub, save locally, skip upload - NoUploadGcs(GcsRepo), // Load from GCS, save locally, skip upload } impl std::fmt::Display for Checkpoint { @@ -318,13 +316,6 @@ impl std::fmt::Display for Checkpoint { Some(prefix) => write!(f, "P2P - gs://{}/{}", &gcs_repo.bucket, prefix), None => write!(f, "P2P - gs://{}", &gcs_repo.bucket), }, - Checkpoint::NoUploadHubRepo(hub_repo) => { - write!(f, "NoUpload - Hub repo: {}", &hub_repo.repo_id) - } - Checkpoint::NoUploadGcs(gcs_repo) => match &gcs_repo.prefix { - Some(prefix) => write!(f, "NoUpload - gs://{}/{}", &gcs_repo.bucket, prefix), - None => write!(f, "NoUpload - gs://{}", &gcs_repo.bucket), - }, } } } @@ -364,12 +355,8 @@ impl Model { Checkpoint::Dummy(_hub_repo) => false, Checkpoint::Ephemeral => true, Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => true, // P2P is internal state, not configurable - Checkpoint::Hub(hub_repo) | Checkpoint::NoUploadHubRepo(hub_repo) => { - hub_repo.repo_id.is_empty() - } - Checkpoint::Gcs(gcs_repo) | Checkpoint::NoUploadGcs(gcs_repo) => { - gcs_repo.bucket.is_empty() - } + Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(), + Checkpoint::Gcs(gcs_repo) => gcs_repo.bucket.is_empty(), }; if bad_checkpoint { From d6a1201480e02d22e649affff8164f6658f7f1d1 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 17:14:15 -0300 Subject: [PATCH 68/72] Send cooldown witness on skip checkpoint --- shared/client/src/state/cooldown.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index cb736d352..db0728fe5 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -257,6 +257,10 @@ impl CooldownStepMetadata { } else { checkpoint_completed.store(true, Ordering::SeqCst); } + } else { + // No upload needed (skip_upload or unsupported checkpoint type) + // Mark checkpoint as complete so cooldown witness can be sent + checkpoint_completed.store(true, Ordering::SeqCst); } cleanup_dirs( From 816ba5b423275b65cc3d7b33e487b185616395f9 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Tue, 20 Jan 2026 17:38:19 -0300 Subject: [PATCH 69/72] Skip local save with skip upload flag --- shared/client/src/state/cooldown.rs | 66 ++++++++++++++--------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index db0728fe5..0bc2c125f 100644 --- 
a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -207,11 +207,12 @@ impl CooldownStepMetadata { skip_upload, } = checkpoint_info; - let upload_info = if skip_upload { - info!("Skipping checkpoint upload (skip_upload flag is set)"); - None + // When skip_upload is true (testing), skip all checkpoint saving + if skip_upload { + info!("Skipping checkpoint save and upload (skip_upload flag is set)"); + checkpoint_completed.store(true, Ordering::SeqCst); } else { - match checkpoint { + let upload_info = match checkpoint { model::Checkpoint::Hub(HubRepo { repo_id, revision: _, @@ -238,40 +239,39 @@ impl CooldownStepMetadata { })) } _ => None, - } - }; - - let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - let local = - save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; - - if let Some(upload_info) = upload_info { - let manifest_metadata = GcsManifestMetadata { - epoch, - run_id: run_id.clone(), }; - let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) - .await; - if let Err(err) = result { - error!("Error uploading checkpoint: {}", err); + + let path = checkpoint_dir.join(format!("{run_id}-step{step}")); + let local = + save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; + + if let Some(upload_info) = upload_info { + let manifest_metadata = GcsManifestMetadata { + epoch, + run_id: run_id.clone(), + }; + let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) + .await; + if let Err(err) = result { + error!("Error uploading checkpoint: {}", err); + } else { + checkpoint_completed.store(true, Ordering::SeqCst); + } } else { + // No upload configured, but local save succeeded checkpoint_completed.store(true, Ordering::SeqCst); } - } else { - // No upload needed (skip_upload or unsupported checkpoint type) - // Mark checkpoint as complete so cooldown witness can be sent - checkpoint_completed.store(true, Ordering::SeqCst); - } - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; + cleanup_dirs( + delete_queue, + keep_steps, + run_id, + delete_old_steps, + step, + checkpoint_dir, + ) + .await; + } Ok(evals) } From c6016b754aeff1b563dbbd12a31188901974c598 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 26 Jan 2026 11:15:49 -0300 Subject: [PATCH 70/72] Remove comment and add early return on skip upload check --- shared/client/src/state/cooldown.rs | 109 ++++++++++++++-------------- shared/data-provider/src/errors.rs | 4 - 2 files changed, 55 insertions(+), 58 deletions(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index 0bc2c125f..cb04bf025 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -211,68 +211,69 @@ impl CooldownStepMetadata { if skip_upload { info!("Skipping checkpoint save and upload (skip_upload flag is set)"); checkpoint_completed.store(true, Ordering::SeqCst); - } else { - let upload_info = match checkpoint { - model::Checkpoint::Hub(HubRepo { - repo_id, - revision: _, - }) - | model::Checkpoint::P2P(HubRepo { - repo_id, - revision: _, - }) => { - if let Some(token) = hub_token { - Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: (&repo_id).into(), - hub_token: token, - })) - } else { - warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); - None - } - } - model::Checkpoint::Gcs(model::GcsRepo { bucket, 
prefix }) - | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { - Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: (&bucket).into(), - gcs_prefix: prefix.as_ref().map(|p| p.into()), - })) - } - _ => None, - }; + return Ok(evals); + } - let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - let local = - save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; - - if let Some(upload_info) = upload_info { - let manifest_metadata = GcsManifestMetadata { - epoch, - run_id: run_id.clone(), - }; - let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) - .await; - if let Err(err) = result { - error!("Error uploading checkpoint: {}", err); + let upload_info = match checkpoint { + model::Checkpoint::Hub(HubRepo { + repo_id, + revision: _, + }) + | model::Checkpoint::P2P(HubRepo { + repo_id, + revision: _, + }) => { + if let Some(token) = hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: (&repo_id).into(), + hub_token: token, + })) } else { - checkpoint_completed.store(true, Ordering::SeqCst); + warn!("HF_TOKEN env not provided, skipping upload to HuggingFace Hub"); + None } + } + model::Checkpoint::Gcs(model::GcsRepo { bucket, prefix }) + | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { + Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: (&bucket).into(), + gcs_prefix: prefix.as_ref().map(|p| p.into()), + })) + } + _ => None, + }; + + let path = checkpoint_dir.join(format!("{run_id}-step{step}")); + let local = + save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; + + if let Some(upload_info) = upload_info { + let manifest_metadata = GcsManifestMetadata { + epoch, + run_id: run_id.clone(), + }; + let result = upload_checkpoint(upload_info, manifest_metadata, local.clone(), step as u64, cancellation_token.clone()) + .await; + if let Err(err) = result { + error!("Error uploading checkpoint: {}", err); } else { - // No upload configured, but local save succeeded checkpoint_completed.store(true, Ordering::SeqCst); } - - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; + } else { + // No upload configured, but local save succeeded + checkpoint_completed.store(true, Ordering::SeqCst); } + cleanup_dirs( + delete_queue, + keep_steps, + run_id, + delete_old_steps, + step, + checkpoint_dir, + ) + .await; + Ok(evals) } .instrument(info_span!("checkpointing")) diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index c99185aa0..9de7bb76b 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -13,8 +13,6 @@ pub enum UploadError { #[error("GCS authentication failed: {0}")] GcsAuth(String), - //#[error("GCS operation failed: {0}")] - //GcsStorage(#[from] google_cloud_storage::client::Error), #[error("IO error: {0}")] Io(#[from] std::io::Error), @@ -36,8 +34,6 @@ pub enum DownloadError { #[error("GCS authentication failed: {0}")] GcsAuth(String), - //#[error("GCS operation failed: {0}")] - //GcsStorage(#[from] google_cloud_storage::client::Error), #[error("IO error: {0}")] Io(#[from] std::io::Error), From 98c1abe244028f21eb7dfcd62622511c363ac46c Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 26 Jan 2026 12:10:43 -0300 Subject: [PATCH 71/72] Fix centralized permission check to upload --- architectures/centralized/client/src/app.rs | 105 +++++++------------- architectures/centralized/server/src/app.rs | 2 +- 2 files 
changed, 36 insertions(+), 71 deletions(-) diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 3af8322f7..cf4f00d69 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -1,14 +1,12 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; -use bytes::Bytes; use google_cloud_storage::client::{Storage, StorageControl}; -use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; use psyche_client::{GcsUploadInfo, HubUploadInfo, UploadInfo}; -use psyche_coordinator::model::{self, Checkpoint, GcsRepo, HubRepo, LLM, Model}; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::{Coordinator, HealthChecks}; use psyche_metrics::ClientMetrics; use psyche_network::{ @@ -178,92 +176,59 @@ impl App { p2p: NC, state_options: RunInitConfig, ) -> Result<()> { - // sanity checks - let Model::LLM(LLM { checkpoint, .. }) = &self.coordinator_state.model; + // Sanity checks using the checkpoint config from state_options, not the zeroed coordinator state. + // The coordinator_state is only populated after receiving the first ServerToClientMessage::Coordinator. if !self.skip_upload_check { - let upload_info = match checkpoint { - model::Checkpoint::Hub(HubRepo { repo_id, revision }) - | model::Checkpoint::P2P(HubRepo { repo_id, revision }) => { - Some(UploadInfo::Hub(HubUploadInfo { - hub_repo: (repo_id).into(), - hub_token: (&revision.unwrap_or_default()).into(), - })) - } - model::Checkpoint::Gcs(GcsRepo { bucket, prefix }) - | model::Checkpoint::P2PGcs(model::GcsRepo { bucket, prefix }) => { - Some(UploadInfo::Gcs(GcsUploadInfo { - gcs_bucket: (bucket).into(), - gcs_prefix: Some((&prefix.unwrap_or_default()).into()), - })) + let upload_info = match &state_options.checkpoint_config { + config if config.skip_upload => Some(UploadInfo::Dummy()), + config => { + // Use HF_TOKEN from checkpoint_config for Hub uploads + if let Some(ref hub_token) = config.hub_token { + Some(UploadInfo::Hub(HubUploadInfo { + hub_repo: String::new(), // Will be validated when actual checkpoint is received + hub_token: hub_token.clone(), + })) + } else { + // Check if GCS credentials are available by attempting to create a client + match Storage::builder().build().await { + Ok(_) => Some(UploadInfo::Gcs(GcsUploadInfo { + gcs_bucket: String::new(), // Will be validated when actual checkpoint is received + gcs_prefix: None, + })), + Err(_) => None, + } + } } - _ => None, }; match upload_info { Some(UploadInfo::Hub(HubUploadInfo { - hub_repo, + hub_repo: _, hub_token, })) => { - let api = hf_hub::api::tokio::ApiBuilder::new() + let _api = hf_hub::api::tokio::ApiBuilder::new() .with_token(Some(hub_token.clone())) .build()?; - let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model)); - if !repo_api.is_writable().await { - anyhow::bail!( - "Checkpoint upload repo {} is not writable with the passed API key.", - hub_repo - ) - } } - Some(UploadInfo::Gcs(gcs_info)) => { - let storage = Storage::builder() + Some(UploadInfo::Gcs(_gcs_info)) => { + let _storage = Storage::builder() .build() .await .map_err(|e| anyhow::anyhow!("Failed to create GCS client: {}", e))?; - let storage_control = StorageControl::builder().build().await.map_err(|e| { - anyhow::anyhow!("Failed to create GCS control client: {}", e) - })?; - - // Test write access by 
attempting to upload a small test object - let test_key = format!( - "{}/.write_test", - gcs_info.gcs_prefix.clone().unwrap_or_default() - ); - - let bucket_resource_name = - format!("projects/_/buckets/{}", gcs_info.gcs_bucket); - let test_data = Bytes::from(vec![]); - - let upload_result = storage - .write_object(&bucket_resource_name, &test_key, test_data) - .send_unbuffered() - .await; - match upload_result { - Ok(_) => { - // Test upload succeeded, the bucket is writable. Now we delete the test file - let _ = storage_control - .delete_object() - .set_bucket(bucket_resource_name.clone()) - .set_object(test_key) - .send() - .await; - } - Err(e) => { - anyhow::bail!( - "GCS bucket gs://{}/{} is not writable: {}", - gcs_info.gcs_bucket, - gcs_info.gcs_prefix.clone().unwrap_or_default(), - e - ) - } - } + let _storage_control = + StorageControl::builder().build().await.map_err(|e| { + anyhow::anyhow!("Failed to create GCS control client: {}", e) + })?; + // GCS credentials are valid - actual bucket writability will be checked during checkpoint } Some(UploadInfo::Dummy()) => { - // In test mode, we skip upload checks + // In test mode or skip_upload mode, we skip upload checks } None => { - anyhow::bail!("No upload info found for checkpointing"); + anyhow::bail!( + "No upload credentials found for checkpointing. Set HF_TOKEN for HuggingFace Hub or configure GCS credentials." + ); } } } diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 4df2ca662..e203189b1 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -403,7 +403,7 @@ impl App { rand::rng().next_u64(), ), OpportunisticData::CooldownStep(witness) => { - self.coordinator.cooldown_witness(witness) + self.coordinator.cooldown_witness(&from, witness) } } { warn!("Error when processing witness: {error}"); From 7d8bf2cdca393ece440fa66904c7c9f2d0450086 Mon Sep 17 00:00:00 2001 From: IAvecilla Date: Mon, 26 Jan 2026 12:14:48 -0300 Subject: [PATCH 72/72] Fix cooldown checks and update comment on test flag --- .../solana-coordinator/src/instance_state.rs | 10 ++++-- .../programs/solana-coordinator/src/lib.rs | 15 ++++---- shared/client/src/cli.rs | 2 +- shared/client/src/state/types.rs | 2 +- shared/coordinator/src/coordinator.rs | 36 +++++++++++-------- 5 files changed, 41 insertions(+), 24 deletions(-) diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 6751fdf0c..79eadddef 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -233,9 +233,15 @@ impl CoordinatorInstanceState { self.tick() } - pub fn cooldown_witness(&mut self, witness: Witness) -> Result<()> { + pub fn cooldown_witness( + &mut self, + payer: &Pubkey, + witness: Witness, + ) -> Result<()> { + let id = self.clients_state.find_signer(payer)?; + self.coordinator - .cooldown_witness(witness) + .cooldown_witness(id, witness) .map_err(|err| anchor_lang::error!(ProgramError::from(err)))?; self.tick() diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 71ee7e949..e7fd2bee9 100644 --- 
a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -297,12 +297,15 @@ pub mod psyche_solana_coordinator { ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; account.increment_nonce(); - account.state.cooldown_witness(Witness { - proof, - participant_bloom, - broadcast_bloom, - broadcast_merkle, - }) + account.state.cooldown_witness( + ctx.accounts.user.key, + Witness { + proof, + participant_bloom, + broadcast_bloom, + broadcast_merkle, + }, + ) } pub fn health_check( diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs index 69e8aa271..c79bdd5c6 100644 --- a/shared/client/src/cli.rs +++ b/shared/client/src/cli.rs @@ -192,7 +192,7 @@ pub struct TrainArgs { #[clap(long, default_value_t = 3, env)] pub keep_steps: u32, - /// Skip uploading checkpoints to Hub/GCS (for testing). Checkpoints are still saved locally. + /// Skip saving and uploading checkpoints (for testing). #[clap(long, default_value_t = false, env, hide = true)] pub skip_checkpoint_upload: bool, } diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 48aee7495..085211cb7 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -22,7 +22,7 @@ pub struct CheckpointConfig { pub delete_old_steps: bool, pub keep_steps: u32, pub hub_token: Option, - /// Skip uploading checkpoints (for testing). Checkpoints are still saved locally. + /// Skip saving and uploading checkpoints (for testing). pub skip_upload: bool, } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 07cd3c526..744eafba6 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -512,6 +512,7 @@ impl Coordinator { pub fn cooldown_witness( &mut self, + from: &T, witness: Witness, ) -> std::result::Result<(), CoordinatorError> { if self.halted() { @@ -522,6 +523,12 @@ impl Coordinator { return Ok(()); } + // Verify the sender matches the witness index to prevent spoofing + let index = witness.proof.index as usize; + if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { + return Err(CoordinatorError::InvalidWitness); + } + let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; if !checkpointer_selection .is_checkpointer(witness.proof.index, self.epoch_state.clients.len() as u64) @@ -630,6 +637,21 @@ impl Coordinator { return Err(CoordinatorError::InvalidCommitteeProof); } + if self.halted() { + return Err(CoordinatorError::Halted); + } + + if !matches!(self.run_state, RunState::Cooldown) { + return Err(CoordinatorError::InvalidRunState); + } + + let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; + if !checkpointer_selection + .is_checkpointer(index as u64, self.epoch_state.clients.len() as u64) + { + return Err(CoordinatorError::InvalidWitness); + } + let Model::LLM(llm) = &mut self.model; match (&llm.checkpoint, checkpoint_repo) { // If current is P2P, wrap the new checkpoint in P2P @@ -651,20 +673,6 @@ impl Coordinator { _ => {} } - if self.halted() { - return Err(CoordinatorError::Halted); - } - - if !matches!(self.run_state, RunState::Cooldown) { - return Err(CoordinatorError::InvalidRunState); - } - let checkpointer_selection = CheckpointerSelection::from_coordinator(self, 0)?; - if !checkpointer_selection - .is_checkpointer(index as u64, 
self.epoch_state.clients.len() as u64) - { - return Err(CoordinatorError::InvalidWitness); - } - self.epoch_state.checkpointed = true; Ok(())
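
The guard added in this final patch is the crux of the cooldown-witness hardening: a witness is only honored when the run is in Cooldown and the claimed client index both exists and maps back to the identity that actually signed the message. A minimal, self-contained sketch of that index-vs-sender check, using simplified stand-in types rather than the real Coordinator, Witness, and CoordinatorError definitions from shared/coordinator/src/coordinator.rs:

    // Simplified stand-ins for illustration; not the real coordinator types.
    struct ClientSlot<T> {
        id: T,
    }

    #[derive(Debug, PartialEq)]
    enum WitnessError {
        InvalidWitness,
    }

    // Accept a witness only if the claimed slot index is in range and the
    // identity registered at that slot equals the signing account.
    fn verify_witness_sender<T: PartialEq>(
        clients: &[ClientSlot<T>],
        from: &T,
        claimed_index: usize,
    ) -> Result<(), WitnessError> {
        match clients.get(claimed_index) {
            Some(slot) if slot.id == *from => Ok(()),
            _ => Err(WitnessError::InvalidWitness),
        }
    }

    fn main() {
        let clients = vec![ClientSlot { id: "alice" }, ClientSlot { id: "bob" }];
        // "bob" witnessing from his own slot is accepted...
        assert_eq!(verify_witness_sender(&clients, &"bob", 1), Ok(()));
        // ...but claiming someone else's slot, or an out-of-range one, is rejected.
        assert_eq!(verify_witness_sender(&clients, &"bob", 0), Err(WitnessError::InvalidWitness));
        assert_eq!(verify_witness_sender(&clients, &"bob", 7), Err(WitnessError::InvalidWitness));
    }

This mirrors the `index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from` rejection above: out-of-range and mismatched indices are treated identically, so a malicious client cannot claim another checkpointer's slot to forge a cooldown witness.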