diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index bc034d1db..75a3827eb 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -177,11 +177,22 @@ impl App {
         let training_data_server = match &coordinator.model {
             Model::LLM(LLM {
-                data_location,
                 checkpoint,
                 ..
             }) => {
-                if let LLMTrainingDataLocation::Server(url) = data_location {
+                let data_locations = &coordinator.data_locations;
+                let data_location_server_urls: Vec<_> = data_locations
+                    .iter()
+                    .filter_map(|l| match l {
+                        LLMTrainingDataLocation::Server(url) => Some(url.to_string()),
+                        _ => None,
+                    })
+                    .collect();
+
+                if data_location_server_urls.is_empty() {
+                    None
+                } else {
+                    if data_location_server_urls.len() > 1 {
+                        bail!("More than one LLMTrainingDataLocation::Server configured, but we only support hosting a single one.");
+                    }
+
+                    // We know there's a single URL, and it's the one that includes the port we want to host on.
+                    let url = data_location_server_urls.first().unwrap();
+
                     match checkpoint {
                         Checkpoint::Hub(hub_repo) => {
                             let repo_id = String::from(&hub_repo.repo_id);
@@ -206,7 +217,7 @@ impl App {
                         }
                     }
 
-                    let server_addr: SocketAddr = String::from(url).parse().map_err(|e| {
+                    let server_addr: SocketAddr = url.parse().map_err(|e| {
                         anyhow!("Failed to parse training data server URL {:?}: {}", url, e)
                     })?;
                     let data_server_port = server_addr.port();
@@ -231,8 +242,6 @@ impl App {
                     DataProviderTcpServer::start(local_data_provider, backend, data_server_port)
                         .await?;
                     Some((tx, data_server))
-                } else {
-                    None
                 }
             }
         };
diff --git a/architectures/centralized/testing/src/server.rs b/architectures/centralized/testing/src/server.rs
index 0c0ec60a8..59f9686ec 100644
--- a/architectures/centralized/testing/src/server.rs
+++ b/architectures/centralized/testing/src/server.rs
@@ -3,6 +3,7 @@ use crate::{MAX_ROUND_TRAIN_TIME, ROUND_WITNESS_TIME, WARMUP_TIME};
 use bytemuck::Zeroable;
 use psyche_centralized_server::app::App as ServerApp;
 use psyche_centralized_shared::ClientId;
+use psyche_coordinator::model::LLMDataLocations;
 use psyche_coordinator::{Client, Round};
 use psyche_coordinator::{
     Coordinator, CoordinatorConfig, CoordinatorEpochState, RunState, SOLANA_MAX_NUM_CLIENTS,
@@ -94,6 +95,7 @@ impl CoordinatorServer {
             model: Model::LLM(LLM::dummy()),
             config: coordinator_config,
             epoch_state,
+            data_locations: LLMDataLocations::dummy(),
             ..Coordinator::<ClientId>::zeroed()
         };
diff --git a/architectures/decentralized/solana-client/src/command/json_dump_run.rs b/architectures/decentralized/solana-client/src/command/json_dump_run.rs
index 8d5fa0fbd..36567637f 100644
--- a/architectures/decentralized/solana-client/src/command/json_dump_run.rs
+++ b/architectures/decentralized/solana-client/src/command/json_dump_run.rs
@@ -107,6 +107,7 @@ pub async fn command_json_dump_run_execute(
             "client_version": coordinator_account_state.state.client_version,
             "metadata": coordinator_account_state.state.metadata,
             "model": coordinator_account_state.state.coordinator.model,
+            "data_locations": coordinator_account_state.state.coordinator.data_locations,
             "config": coordinator_account_state.state.coordinator.config,
         },
         "status": {
diff --git a/architectures/decentralized/solana-client/src/command/set_future_epoch_rates.rs b/architectures/decentralized/solana-client/src/command/set_future_epoch_rates.rs
index 75ae70c48..643344521 100644
--- a/architectures/decentralized/solana-client/src/command/set_future_epoch_rates.rs
+++
b/architectures/decentralized/solana-client/src/command/set_future_epoch_rates.rs @@ -72,6 +72,7 @@ pub async fn command_set_future_epoch_rates_execute( .map(|amount| ui_amount_to_native_amount(amount, mint_decimals)), paused: None, client_version: None, + data_location: None, }, ); diff --git a/architectures/decentralized/solana-client/src/command/set_paused.rs b/architectures/decentralized/solana-client/src/command/set_paused.rs index 1479e2d27..1c97faf1d 100644 --- a/architectures/decentralized/solana-client/src/command/set_paused.rs +++ b/architectures/decentralized/solana-client/src/command/set_paused.rs @@ -52,6 +52,7 @@ pub async fn command_set_paused_execute( epoch_slashing_rate_per_client: None, paused: Some(paused), client_version: None, + data_location: None, }, ) } else { diff --git a/architectures/decentralized/solana-client/src/command/update_config.rs b/architectures/decentralized/solana-client/src/command/update_config.rs index 17f41be28..322cdc095 100644 --- a/architectures/decentralized/solana-client/src/command/update_config.rs +++ b/architectures/decentralized/solana-client/src/command/update_config.rs @@ -1,13 +1,13 @@ -use std::path::PathBuf; - use anyhow::{Context, Result, bail}; use clap::Args; use psyche_coordinator::{ CoordinatorConfig, CoordinatorProgress, get_data_index_for_step, - model::{Checkpoint, Model}, + model::{Checkpoint, LLMDataLocations, LLMTrainingDataLocation, Model}, }; +use psyche_core::FixedVec; use psyche_solana_treasurer::logic::RunUpdateParams; use serde::{Deserialize, Serialize}; +use std::path::PathBuf; use crate::{SolanaBackend, instructions}; @@ -68,22 +68,58 @@ pub async fn command_update_config_execute( .get_coordinator_account(&coordinator_account) .await?; - let (config, mut model) = match config_path { + let (config, mut model, data_locations) = match config_path { Some(config_path) => { + #[derive(Serialize, Deserialize)] + struct ModelWrapper { + #[serde(flatten)] + pub model: Model, + } + #[derive(Serialize, Deserialize)] struct State { pub config: CoordinatorConfig, - pub model: Model, + pub model: ModelWrapper, } + + // First, parse without data_locations to get the Model enum let state: State = toml::from_str(std::str::from_utf8( &std::fs::read(&config_path) .with_context(|| format!("failed to read config toml file {config_path:?}"))?, )?) 
.with_context(|| format!("failed to parse config toml file {config_path:?}"))?; - (Some(state.config), Some(state.model)) + // Then parse just the data_locations separately + #[derive(Serialize, Deserialize)] + struct DataLocationsWrapper { + pub data_locations: Vec, + } + + #[derive(Serialize, Deserialize)] + struct LLMSection { + #[serde(rename = "LLM")] + pub llm: DataLocationsWrapper, + } + + #[derive(Serialize, Deserialize)] + struct ModelSection { + pub model: LLMSection, + } + + let data_section: ModelSection = toml::from_str(std::str::from_utf8( + &std::fs::read(&config_path) + .with_context(|| format!("failed to read config toml file {config_path:?}"))?, + )?)?; + + let data_locs = LLMDataLocations { + data_locations: FixedVec::from_iter( + data_section.model.llm.data_locations.into_iter(), + ), + }; + + (Some(state.config), Some(state.model.model), Some(data_locs)) } - None => (None, None), + None => (None, None, None), }; model = if switch_to_hub { @@ -133,6 +169,10 @@ pub async fn command_update_config_execute( coordinator_account_state.state.coordinator.model = model; } + if let Some(data_locations) = data_locations { + coordinator_account_state.state.coordinator.data_locations = data_locations; + } + let progress = restart_from_step.map(|step| CoordinatorProgress { epoch: coordinator_account_state.state.coordinator.progress.epoch, step, @@ -148,11 +188,14 @@ pub async fn command_update_config_execute( bail!("this invocation would not update anything, bailing.") } - let instructions = if let Some(treasurer_index) = backend + let (instructions, data_location_instr) = if let Some(treasurer_index) = backend .resolve_treasurer_index(&run_id, treasurer_index) .await? { - vec![instructions::treasurer_run_update( + let mut instructions = Vec::new(); + let mut data_location_instr = Vec::new(); + + instructions.push(instructions::treasurer_run_update( &run_id, treasurer_index, &coordinator_account, @@ -166,10 +209,35 @@ pub async fn command_update_config_execute( epoch_slashing_rate_per_client: None, paused: None, client_version: client_version.clone(), + data_location: None, }, - )] + )); + if let Some(data_locations) = data_locations { + for dl in data_locations.data_locations.iter() { + data_location_instr.push(instructions::treasurer_run_update( + &run_id, + treasurer_index, + &coordinator_account, + &main_authority, + RunUpdateParams { + metadata: None, + config: None, + model: None, + progress: None, + epoch_earning_rate_total_shared: None, + epoch_slashing_rate_per_client: None, + paused: None, + client_version: None, + data_location: Some(*dl), + }, + )); + } + } + (instructions, data_location_instr) } else { let mut instructions = Vec::new(); + let mut data_location_instr = Vec::new(); + let data_locations_iter = data_locations.unwrap().iter().cloned().collect::>(); if coordinator_update { instructions.push(instructions::coordinator_update( @@ -181,6 +249,19 @@ pub async fn command_update_config_execute( model, progress, )); + data_location_instr.push(instructions::clear_data_locations( + &run_id, + &coordinator_account, + &main_authority, + )); + for dl in data_locations_iter.iter() { + data_location_instr.push(instructions::coordinator_update_data_locations( + &run_id, + &coordinator_account, + &main_authority, + Some(*dl), + )); + } } if let Some(client_version) = client_version.clone() { @@ -192,16 +273,21 @@ pub async fn command_update_config_execute( )); } - instructions + (instructions, data_location_instr) }; let signature = backend .send_and_retry("Update config", 
&instructions, &[]) .await?; println!("Updated config of {run_id} with transaction {signature}"); + let signature = backend + .send_and_retry("Update data locations", &data_location_instr, &[]) + .await?; + println!(" - Metadata: {metadata:#?}"); println!(" - Config: {config:#?}"); println!(" - Model: {model:#?}"); + println!(" - Data locations: {data_locations:#?}"); println!(" - Progress: {progress:#?}"); println!(" - Client version: {client_version:#?}"); diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs index a8bec54a0..c42246693 100644 --- a/architectures/decentralized/solana-client/src/instructions.rs +++ b/architectures/decentralized/solana-client/src/instructions.rs @@ -53,6 +53,41 @@ pub fn coordinator_close_run( ) } +pub fn clear_data_locations( + run_id: &str, + coordinator_account: &Pubkey, + main_authority: &Pubkey, +) -> Instruction { + let coordinator_instance = psyche_solana_coordinator::find_coordinator_instance(run_id); + anchor_instruction( + psyche_solana_coordinator::ID, + psyche_solana_coordinator::accounts::OwnerCoordinatorAccounts { + authority: *main_authority, + coordinator_instance, + coordinator_account: *coordinator_account, + }, + psyche_solana_coordinator::instruction::ClearDataLocations {}, + ) +} + +pub fn coordinator_update_data_locations( + run_id: &str, + coordinator_account: &Pubkey, + main_authority: &Pubkey, + data_location: Option, +) -> Instruction { + let coordinator_instance = psyche_solana_coordinator::find_coordinator_instance(run_id); + anchor_instruction( + psyche_solana_coordinator::ID, + psyche_solana_coordinator::accounts::OwnerCoordinatorAccounts { + authority: *main_authority, + coordinator_instance, + coordinator_account: *coordinator_account, + }, + psyche_solana_coordinator::instruction::UpdateDataLocations { data_location }, + ) +} + pub fn coordinator_update( run_id: &str, coordinator_account: &Pubkey, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 029b40fad..d5674106b 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -10,9 +10,13 @@ use psyche_coordinator::RunState; use psyche_coordinator::SOLANA_MAX_STRING_LEN; use psyche_coordinator::TickResult; use psyche_coordinator::Witness; +use psyche_coordinator::model::HttpLLMTrainingDataLocation; +use psyche_coordinator::model::HttpTrainingDataLocation; use psyche_coordinator::model::HubRepo; +use psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::Model; use psyche_core::FixedString; +use psyche_core::FixedVec; use psyche_core::SmallBoolean; use psyche_core::sha256v; use serde::Deserialize; @@ -192,6 +196,42 @@ impl CoordinatorInstanceState { return err!(ProgramError::ModelSanityCheckFailed); } + for data_location in self.coordinator.data_locations.iter() { + let bad_data_location = match data_location { + LLMTrainingDataLocation::Dummy(_) => false, + LLMTrainingDataLocation::Server(url) => url.is_empty(), + LLMTrainingDataLocation::Local(_) => false, + LLMTrainingDataLocation::Http( + HttpLLMTrainingDataLocation { location, .. 
}, + ) => match location { + HttpTrainingDataLocation::SingleUrl(url) => { + url.is_empty() + }, + HttpTrainingDataLocation::NumberedFiles { + url_template, + num_files, + .. + } => url_template.is_empty() || *num_files == 0, + HttpTrainingDataLocation::Gcp { + bucket_name, + .. + } => bucket_name.is_empty(), + }, + LLMTrainingDataLocation::WeightedHttp(url) => { + url.is_empty() + }, + LLMTrainingDataLocation::Preprocessed(url) => { + url.is_empty() + }, + }; + if bad_data_location { + msg!( + "model check failed: bad LLM training data location." + ); + return err!(ProgramError::ModelSanityCheckFailed); + } + } + if self.coordinator.run_state == RunState::Uninitialized { // this is the only way to get out of RunState::Uninitialized // by doing this we force the sanity checks on the config and model @@ -275,6 +315,44 @@ impl CoordinatorInstanceState { Ok(()) } + pub fn update_data_locations( + &mut self, + data_location: Option< + psyche_coordinator::model::LLMTrainingDataLocation, + >, + ) -> Result<()> { + if self.coordinator.run_state == RunState::Finished { + return err!(ProgramError::UpdateConfigFinished); + } else if !self.coordinator.halted() && data_location.is_some() { + return err!(ProgramError::UpdateConfigNotHalted); + } + + let mut data_locations = self.coordinator.data_locations; + if let Some(dl) = data_location { + let _ = data_locations.data_locations.push(dl); + let _ = std::mem::replace( + &mut self.coordinator.data_locations, + data_locations, + ); + } + + Ok(()) + } + + pub fn clear_data_locations(&mut self) -> Result<()> { + if self.coordinator.run_state == RunState::Finished { + return err!(ProgramError::UpdateConfigFinished); + } else if !self.coordinator.halted() { + return err!(ProgramError::UpdateConfigNotHalted); + } + let _ = std::mem::replace( + &mut self.coordinator.data_locations.data_locations, + FixedVec::new(), + ); + + Ok(()) + } + pub fn update( &mut self, metadata: Option, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 657424195..2f1feb19e 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -164,9 +164,10 @@ impl CoordinatorInstance { #[program] pub mod psyche_solana_coordinator { - use super::*; use psyche_core::FixedString; + use super::*; + pub fn init_coordinator( context: Context, params: InitCoordinatorParams, @@ -181,6 +182,25 @@ pub mod psyche_solana_coordinator { free_coordinator_processor(context, params) } + pub fn update_data_locations( + ctx: Context, + data_location: Option< + psyche_coordinator::model::LLMTrainingDataLocation, + >, + ) -> Result<()> { + let mut account = ctx.accounts.coordinator_account.load_mut()?; + account.increment_nonce(); + account.state.update_data_locations(data_location) + } + + pub fn clear_data_locations( + ctx: Context, + ) -> Result<()> { + let mut account = ctx.accounts.coordinator_account.load_mut()?; + account.increment_nonce(); + account.state.clear_data_locations() + } + pub fn update( ctx: Context, metadata: Option, diff --git a/architectures/decentralized/solana-coordinator/target/deploy/psyche_solana_coordinator-keypair.json b/architectures/decentralized/solana-coordinator/target/deploy/psyche_solana_coordinator-keypair.json index e7a29e9b4..49e92701b 100644 --- 
a/architectures/decentralized/solana-coordinator/target/deploy/psyche_solana_coordinator-keypair.json +++ b/architectures/decentralized/solana-coordinator/target/deploy/psyche_solana_coordinator-keypair.json @@ -1 +1 @@ -[64,238,5,158,112,133,38,180,4,62,68,219,46,236,189,68,44,131,70,134,229,152,44,218,72,233,162,120,147,52,99,51,51,13,179,3,249,169,215,84,254,219,157,144,170,99,145,211,144,51,17,103,241,3,92,148,244,17,156,198,157,197,61,26] \ No newline at end of file +[64,238,5,158,112,133,38,180,4,62,68,219,46,236,189,68,44,131,70,134,229,152,44,218,72,233,162,120,147,52,99,51,51,13,179,3,249,169,215,84,254,219,157,144,170,99,145,211,144,51,17,103,241,3,92,148,244,17,156,198,157,197,61,26] diff --git a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs index 90489607e..d73b08730 100644 --- a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs +++ b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs @@ -3,6 +3,7 @@ use anchor_lang::ToAccountMetas; use anyhow::Result; use psyche_coordinator::CoordinatorConfig; use psyche_coordinator::CoordinatorProgress; +use psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::Model; use psyche_solana_coordinator::ClientId; use psyche_solana_coordinator::RunMetadata; @@ -19,6 +20,7 @@ use psyche_solana_coordinator::instruction::SetFutureEpochRates; use psyche_solana_coordinator::instruction::SetPaused; use psyche_solana_coordinator::instruction::Tick; use psyche_solana_coordinator::instruction::Update; +use psyche_solana_coordinator::instruction::UpdateDataLocations; use psyche_solana_coordinator::instruction::Witness; use psyche_solana_coordinator::logic::FreeCoordinatorParams; use psyche_solana_coordinator::logic::InitCoordinatorParams; @@ -114,6 +116,30 @@ pub async fn process_update( Ok(()) } +pub async fn process_data_locations_update( + endpoint: &mut ToolboxEndpoint, + payer: &Keypair, + authority: &Keypair, + coordinator_instance: &Pubkey, + coordinator_account: &Pubkey, + data_location: Option, +) -> Result<()> { + let accounts = OwnerCoordinatorAccounts { + authority: authority.pubkey(), + coordinator_instance: *coordinator_instance, + coordinator_account: *coordinator_account, + }; + let instruction = Instruction { + accounts: accounts.to_account_metas(None), + data: UpdateDataLocations { data_location }.data(), + program_id: psyche_solana_coordinator::ID, + }; + endpoint + .process_instruction_with_signers(payer, instruction, &[authority]) + .await?; + Ok(()) +} + pub async fn process_coordinator_join_run( endpoint: &mut ToolboxEndpoint, payer: &Keypair, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index fc03202cf..3545e5686 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -3,13 +3,16 @@ use psyche_coordinator::RunState; use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::WitnessProof; use psyche_coordinator::model::Checkpoint; +use psyche_coordinator::model::DummyType; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; use psyche_coordinator::model::LLMArchitecture; use 
psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::LLMTrainingDataType; +use psyche_coordinator::model::MAX_DATA_LOCATIONS; use psyche_coordinator::model::Model; use psyche_core::ConstantLR; +use psyche_core::FixedVec; use psyche_core::LearningRateSchedule; use psyche_core::OptimizerDefinition; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; @@ -27,6 +30,7 @@ use psyche_solana_tooling::process_coordinator_instructions::process_coordinator use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_set_paused; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_tick; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_witness; +use psyche_solana_tooling::process_coordinator_instructions::process_data_locations_update; use psyche_solana_tooling::process_coordinator_instructions::process_update; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; @@ -115,7 +119,6 @@ pub async fn run() { checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), optimizer: OptimizerDefinition::Distro { clip_grad_norm: None, @@ -132,6 +135,17 @@ pub async fn run() { .await .unwrap(); + process_data_locations_update( + &mut endpoint, + &payer, + &main_authority, + &coordinator_instance, + &coordinator_account, + Some(LLMTrainingDataLocation::default()), + ) + .await + .unwrap(); + // Coordinator's state should now have changed assert_eq!( get_coordinator_account_state(&mut endpoint, &coordinator_account) diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index f69bbf8c5..a08c7d0ed 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -28,6 +28,7 @@ use psyche_solana_tooling::process_coordinator_instructions::process_coordinator use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_set_paused; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_tick; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_witness; +use psyche_solana_tooling::process_coordinator_instructions::process_data_locations_update; use psyche_solana_tooling::process_coordinator_instructions::process_update; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; @@ -112,7 +113,6 @@ pub async fn run() { checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), optimizer: OptimizerDefinition::Distro { clip_grad_norm: None, @@ -129,6 +129,17 @@ pub async fn run() { .await .unwrap(); + process_data_locations_update( + &mut endpoint, + &payer, + &main_authority, + &coordinator_instance, + &coordinator_account, + Some(LLMTrainingDataLocation::default()), + ) + .await + .unwrap(); + // Set the reward rate for the epoch process_coordiantor_set_future_epoch_rates( &mut endpoint, diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs 
b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs index e51ced2dd..fe0724550 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs @@ -59,7 +59,6 @@ pub async fn run() { checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), optimizer: OptimizerDefinition::Distro { clip_grad_norm: None, @@ -76,6 +75,7 @@ pub async fn run() { epoch_slashing_rate_per_client: None, paused: Some(false), client_version: None, + data_location: Some(LLMTrainingDataLocation::default()), }; // Prepare the collateral mint diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 014772e32..dffafc5ff 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -1,17 +1,18 @@ -use std::vec; - use psyche_coordinator::CommitteeSelection; use psyche_coordinator::CoordinatorConfig; use psyche_coordinator::SOLANA_MAX_NUM_WITNESSES; use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::Checkpoint; +use psyche_coordinator::model::DummyType; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; use psyche_coordinator::model::LLMArchitecture; use psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::LLMTrainingDataType; +use psyche_coordinator::model::MAX_DATA_LOCATIONS; use psyche_coordinator::model::Model; use psyche_core::ConstantLR; +use psyche_core::FixedVec; use psyche_core::LearningRateSchedule; use psyche_core::OptimizerDefinition; use psyche_solana_authorizer::logic::AuthorizationGranteeUpdateParams; @@ -36,6 +37,7 @@ use psyche_solana_treasurer::logic::RunCreateParams; use psyche_solana_treasurer::logic::RunUpdateParams; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; +use std::vec; #[tokio::test] pub async fn run() { @@ -234,7 +236,6 @@ pub async fn run() { checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), lr_schedule: LearningRateSchedule::Constant( ConstantLR::default(), ), @@ -255,6 +256,7 @@ pub async fn run() { epoch_slashing_rate_per_client: None, paused: Some(false), client_version: None, + data_location: Some(LLMTrainingDataLocation::default()), }, ) .await diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs index b1f4e02bc..65e4dc5c0 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/run_update.rs @@ -1,6 +1,7 @@ use anchor_lang::prelude::*; use psyche_coordinator::CoordinatorConfig; use psyche_coordinator::CoordinatorProgress; +use psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::Model; use psyche_solana_coordinator::CoordinatorAccount; 
use psyche_solana_coordinator::CoordinatorInstance; @@ -47,6 +48,7 @@ pub struct RunUpdateParams { pub epoch_slashing_rate_per_client: Option, pub paused: Option, pub client_version: Option, + pub data_location: Option, } pub fn run_update_processor( @@ -109,6 +111,27 @@ pub fn run_update_processor( )?; } + if let Some(data_location) = params.data_location { + psyche_solana_coordinator::cpi::update_data_locations( + CpiContext::new( + context.accounts.coordinator_program.to_account_info(), + OwnerCoordinatorAccounts { + authority: context.accounts.run.to_account_info(), + coordinator_instance: context + .accounts + .coordinator_instance + .to_account_info(), + coordinator_account: context + .accounts + .coordinator_account + .to_account_info(), + }, + ) + .with_signer(run_signer_seeds), + Some(data_location), + )?; + } + if let Some(paused) = params.paused { set_paused( CpiContext::new( diff --git a/architectures/decentralized/testing/src/docker_setup.rs b/architectures/decentralized/testing/src/docker_setup.rs index 70ca92a23..cfc6d4b26 100644 --- a/architectures/decentralized/testing/src/docker_setup.rs +++ b/architectures/decentralized/testing/src/docker_setup.rs @@ -8,6 +8,7 @@ use bollard::{ secret::{ContainerSummary, HostConfig}, }; use psyche_client::IntegrationTestLogMarker; +use psyche_coordinator::model::LLMTrainingDataLocation; use std::process::{Command, Stdio}; use std::sync::Arc; use std::time::Duration; @@ -55,14 +56,21 @@ impl Drop for DockerTestCleanup { } } -/// FIXME: The config path must be relative to the compose file for now. pub async fn e2e_testing_setup( docker_client: Arc, init_num_clients: usize, +) -> DockerTestCleanup { + e2e_testing_setup_with_datasource(docker_client, init_num_clients, None).await +} + +pub async fn e2e_testing_setup_with_datasource( + docker_client: Arc, + init_num_clients: usize, + data_source: Option>, ) -> DockerTestCleanup { remove_old_client_containers(docker_client).await; - spawn_psyche_network(init_num_clients).unwrap(); + spawn_psyche_network(init_num_clients, data_source).unwrap(); spawn_ctrl_c_task(); @@ -72,15 +80,18 @@ pub async fn e2e_testing_setup( pub async fn e2e_testing_setup_subscription( docker_client: Arc, init_num_clients: usize, + data_source: Option>, ) -> DockerTestCleanup { remove_old_client_containers(docker_client).await; #[cfg(not(feature = "python"))] let config_file_path = ConfigBuilder::new() .with_num_clients(init_num_clients) + .with_data_source(data_source) .build(); #[cfg(feature = "python")] let config_file_path = ConfigBuilder::new() .with_num_clients(init_num_clients) + .with_data_source(data_source) .with_architecture("HfAuto") .with_batch_size(8 * init_num_clients as u32) .build(); @@ -233,14 +244,19 @@ pub async fn spawn_new_client_with_monitoring( } // Updated spawn function -pub fn spawn_psyche_network(init_num_clients: usize) -> Result<(), DockerWatcherError> { +pub fn spawn_psyche_network( + init_num_clients: usize, + data_source: Option>, +) -> Result<(), DockerWatcherError> { #[cfg(not(feature = "python"))] let config_file_path = ConfigBuilder::new() .with_num_clients(init_num_clients) + .with_data_source(data_source) .build(); #[cfg(feature = "python")] let config_file_path = ConfigBuilder::new() .with_num_clients(init_num_clients) + .with_data_source(data_source) .with_architecture("HfAuto") .with_batch_size(8 * init_num_clients as u32) .build(); diff --git a/architectures/decentralized/testing/src/docker_watcher.rs b/architectures/decentralized/testing/src/docker_watcher.rs index 
5905d1019..bb266cc88 100644 --- a/architectures/decentralized/testing/src/docker_watcher.rs +++ b/architectures/decentralized/testing/src/docker_watcher.rs @@ -29,6 +29,8 @@ pub enum Response { SolanaSubscription(String, String), WitnessElected(String), Error(ObservedErrorKind, String), + DataProviderFetchSuccess(u64), + DataProviderFetchError(u64), } #[derive(thiserror::Error, Debug)] @@ -310,6 +312,26 @@ impl DockerWatcher { println!("Probably the test ended so we drop the log sender"); } } + IntegrationTestLogMarker::DataProviderFetchSuccess => { + let provider_idx = parsed_log + .get("provider_idx") + .and_then(|v| v.as_u64()) + .unwrap(); + let response = Response::DataProviderFetchSuccess(provider_idx); + if log_sender.send(response).await.is_err() { + println!("Probably the test ended so we drop the log sender"); + } + } + IntegrationTestLogMarker::DataProviderFetchError => { + let provider_idx = parsed_log + .get("provider_idx") + .and_then(|v| v.as_u64()) + .unwrap(); + let response = Response::DataProviderFetchError(provider_idx); + if log_sender.send(response).await.is_err() { + println!("Probably the test ended so we drop the log sender"); + } + } } } Ok(()) diff --git a/architectures/decentralized/testing/src/utils.rs b/architectures/decentralized/testing/src/utils.rs index 631b05f52..af3649754 100644 --- a/architectures/decentralized/testing/src/utils.rs +++ b/architectures/decentralized/testing/src/utils.rs @@ -6,7 +6,7 @@ use anchor_client::{ }; use psyche_coordinator::{ NUM_STORED_ROUNDS, Round, RunState, - model::{Checkpoint, Model}, + model::{Checkpoint, LLMTrainingDataLocation, Model}, }; use psyche_core::FixedVec; use psyche_solana_coordinator::{ClientId, SOLANA_MAX_NUM_PENDING_CLIENTS}; @@ -130,6 +130,7 @@ pub struct ConfigBuilder { num_clients: usize, batch_size: u32, architecture: String, + data_source: Option>, } impl Default for ConfigBuilder { @@ -157,6 +158,7 @@ impl ConfigBuilder { num_clients: 1, batch_size: 4, architecture: String::from("HfLlama"), + data_source: None, } } @@ -175,6 +177,11 @@ impl ConfigBuilder { self } + pub fn with_data_source(mut self, source: Option>) -> Self { + self.data_source = source; + self + } + pub fn build(mut self) -> PathBuf { // Apply runtime overrides self.set_value("config.min_clients", self.num_clients as u32); @@ -187,6 +194,12 @@ impl ConfigBuilder { self.set_value("config.global_batch_size_start", self.batch_size); self.set_value("config.global_batch_size_end", self.batch_size); + if let Some(src) = self.data_source.clone() { + self.set_value( + "model.LLM.data_locations", + toml::Value::try_from(src).unwrap(), + ); + } #[cfg(feature = "python")] self.set_value("config.warmup_time", 100); diff --git a/architectures/decentralized/testing/tests/integration_tests.rs b/architectures/decentralized/testing/tests/integration_tests.rs index 77c31d62f..32cf29c75 100644 --- a/architectures/decentralized/testing/tests/integration_tests.rs +++ b/architectures/decentralized/testing/tests/integration_tests.rs @@ -10,8 +10,11 @@ use std::{sync::Arc, time::Duration}; use bollard::container::StartContainerOptions; use bollard::{Docker, container::KillContainerOptions}; use psyche_client::IntegrationTestLogMarker; +use psyche_coordinator::model::{DummyType, LLMTrainingDataLocation}; use psyche_coordinator::{RunState, model::Checkpoint}; -use psyche_decentralized_testing::docker_setup::e2e_testing_setup_subscription; +use psyche_decentralized_testing::docker_setup::{ + e2e_testing_setup_subscription, e2e_testing_setup_with_datasource, +}; 
 use psyche_decentralized_testing::{
     CLIENT_CONTAINER_PREFIX, NGINX_PROXY_PREFIX,
     chaos::{ChaosAction, ChaosScheduler},
 };
@@ -676,7 +679,7 @@ async fn test_solana_subscriptions() {
     let mut watcher = DockerWatcher::new(docker.clone());
 
     // Initialize a Solana run with 2 clients
-    let _cleanup = e2e_testing_setup_subscription(docker.clone(), 2).await;
+    let _cleanup = e2e_testing_setup_subscription(docker.clone(), 2, None).await;
 
     // Monitor the client containers
     let _monitor_client_1 = watcher
@@ -949,3 +952,93 @@ async fn test_lost_only_peer_go_back_to_hub_checkpoint() {
         }
     }
 }
+
+/// Spawn 2 clients and run for 3 epochs, but with the first configured data provider failing.
+/// This test checks that the logic for retrying the failing data provider and switching to the next one works.
+#[test_log::test(tokio::test(flavor = "multi_thread"))]
+#[serial]
+async fn test_backup_data_provider() {
+    let mut saw_provider_0_error = false;
+    let mut successful_fetches_after_error = 0;
+    let mut current_epoch = -1;
+
+    let docker = Arc::new(Docker::connect_with_socket_defaults().unwrap());
+    let mut watcher = DockerWatcher::new(docker.clone());
+
+    let _cleanup = e2e_testing_setup_with_datasource(
+        docker.clone(),
+        2,
+        Some(vec![LLMTrainingDataLocation::Dummy(DummyType::Failing)]),
+    )
+    .await;
+
+    let _monitor_client_1 = watcher
+        .monitor_container(
+            &format!("{CLIENT_CONTAINER_PREFIX}-1"),
+            vec![
+                IntegrationTestLogMarker::Loss,
+                IntegrationTestLogMarker::DataProviderFetchError,
+                IntegrationTestLogMarker::DataProviderFetchSuccess,
+            ],
+        )
+        .unwrap();
+    let _monitor_client_2 = watcher
+        .monitor_container(
+            &format!("{CLIENT_CONTAINER_PREFIX}-2"),
+            vec![
+                IntegrationTestLogMarker::Loss,
+                IntegrationTestLogMarker::DataProviderFetchError,
+                IntegrationTestLogMarker::DataProviderFetchSuccess,
+            ],
+        )
+        .unwrap();
+
+    let mut live_interval = time::interval(Duration::from_secs(10));
+    loop {
+        tokio::select!
{ + _ = live_interval.tick() => { + if let Err(e) = watcher.monitor_clients_health(2).await { + panic!("{}", e); + } + } + response = watcher.log_rx.recv() => { + match response { + Some(Response::DataProviderFetchError(provider_idx)) => { + println!("Data provider {} fetch error", provider_idx); + if provider_idx == 0 { + saw_provider_0_error = true; + } + } + Some(Response::DataProviderFetchSuccess(provider_idx)) => { + println!("Data provider {} fetch success", provider_idx); + if provider_idx == 1 && saw_provider_0_error { + successful_fetches_after_error += 1; + println!("Successful fetch {} after error", successful_fetches_after_error); + if successful_fetches_after_error >= 2 { + println!("Saw 2 successful fetches after error, test successful!"); + return; + } + } + } + Some(Response::Loss(client, epoch, step, loss)) => { + println!( + "client: {:?}, epoch: {}, step: {}, Loss: {:?}", + client, epoch, step, loss + ); + // assert that the loss decreases each epoch + if epoch as i64 > current_epoch { + current_epoch = epoch as i64; + + if epoch > 1 { + assert!(saw_provider_0_error, "Should have seen error from provider 0"); + assert!(successful_fetches_after_error >= 2, "Should have seen successful fetch after error"); + return; + } + } + } + _ => {} + } + } + } + } +} diff --git a/config/consilience/40b-devnet.toml b/config/consilience/40b-devnet.toml index 9fd1e8c34..49981bde9 100644 --- a/config/consilience/40b-devnet.toml +++ b/config/consilience/40b-devnet.toml @@ -18,7 +18,9 @@ waiting_for_members_extra_time = 10 architecture = "HfDeepseek" data_type = "Pretraining" max_seq_len = 2048 -data_location = { WeightedHttp = "https://storage.googleapis.com/nous-pretraining-public-us/consilience-stage1-mix.json" } +data_locations = [ + { WeightedHttp = "https://storage.googleapis.com/nous-pretraining-public-us/consilience-stage1-mix.json" }, +] cold_start_warmup_steps = 100 [model.LLM.checkpoint.Hub] diff --git a/config/solana-test/config.toml b/config/solana-test/config.toml index d40bdf77a..8468e272c 100644 --- a/config/solana-test/config.toml +++ b/config/solana-test/config.toml @@ -19,17 +19,10 @@ architecture = "HfLlama" data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 - -[model.LLM.checkpoint.Hub] -repo_id = "emozilla/llama2-1.1b-gqa-init" - -[model.LLM.data_location.Http] -token_size_in_bytes = "TwoBytes" -shuffle = "DontShuffle" - -[model.LLM.data_location.Http.location.Gcp] -bucket_name = "nous-pretraining-public-us" -filter_directory = "fineweb-edu-tokenized-llama2" +checkpoint = { Hub = { repo_id = "emozilla/llama2-1.1b-gqa-init" } } +data_locations = [ + { Http = { location = { Gcp = { bucket_name = "nous-pretraining-public-us", filter_directory = "fineweb-edu-tokenized-llama2" } }, token_size_in_bytes = "TwoBytes", shuffle = "DontShuffle" } }, +] [model.LLM.lr_schedule.Cosine] base_lr = 4.0e-4 diff --git a/config/solana-test/light-config-dummy-failing.toml b/config/solana-test/light-config-dummy-failing.toml new file mode 100644 index 000000000..d8c48ef89 --- /dev/null +++ b/config/solana-test/light-config-dummy-failing.toml @@ -0,0 +1,41 @@ +[config] +warmup_time = 30 +cooldown_time = 30 +rounds_per_epoch = 20 +max_round_train_time = 30 +round_witness_time = 1 +min_clients = 1 +init_min_clients = 1 +verification_percent = 0 +witness_nodes = 1 +global_batch_size_start = 8 +global_batch_size_end = 8 +global_batch_size_warmup_tokens = 0 +total_steps = 25000 + +[model.LLM] +architecture = "HfLlama" +data_type = "Pretraining" +max_seq_len = 2048 
+cold_start_warmup_steps = 0 +data_locations = [ + { Dummy = "Failing" }, + { Http = { location = { Gcp = { bucket_name = "nous-pretraining-public-us", filter_directory = "fineweb-edu-tokenized-llama2" } }, token_size_in_bytes = "TwoBytes", shuffle = "DontShuffle" } }, +] + +[model.LLM.checkpoint.Hub] +repo_id = "emozilla/llama2-20m-init" + +[model.LLM.lr_schedule.Cosine] +base_lr = 4.0e-4 +warmup_steps = 250 +warmup_init_lr = 0.0 +total_steps = 25000 +final_lr = 4.0e-5 + +[model.LLM.optimizer.Distro] +clip_grad_norm = 1.0 +compression_decay = 0.999 +compression_chunk = 64 +compression_topk = 8 +quantize_1bit = true diff --git a/config/solana-test/light-config.toml b/config/solana-test/light-config.toml index 5de8ca8a4..283086467 100644 --- a/config/solana-test/light-config.toml +++ b/config/solana-test/light-config.toml @@ -19,17 +19,13 @@ architecture = "HfLlama" data_type = "Pretraining" max_seq_len = 2048 cold_start_warmup_steps = 0 +data_locations = [ + { Http = { location = { Gcp = { bucket_name = "nous-pretraining-public-us", filter_directory = "fineweb-edu-tokenized-llama2" } }, token_size_in_bytes = "TwoBytes", shuffle = "DontShuffle" } }, +] [model.LLM.checkpoint.Hub] repo_id = "emozilla/llama2-20m-init" -[model.LLM.data_location.Http] -token_size_in_bytes = "TwoBytes" -shuffle = "DontShuffle" -[model.LLM.data_location.Http.location.Gcp] -bucket_name = "nous-pretraining-public-us" -filter_directory = "fineweb-edu-tokenized-llama2" - [model.LLM.lr_schedule.Cosine] base_lr = 4.0e-4 warmup_steps = 250 diff --git a/config/solana-test/nano-config.toml b/config/solana-test/nano-config.toml index c275feea3..3e64d904c 100644 --- a/config/solana-test/nano-config.toml +++ b/config/solana-test/nano-config.toml @@ -19,18 +19,14 @@ architecture = "HfLlama" data_type = "Pretraining" max_seq_len = 64 cold_start_warmup_steps = 0 +data_locations = [ + { Http = { location = { SingleUrl = "https://huggingface.co/pefontana/Nano-Llama/resolve/main/tiny-ci-dataset/000_tiny-test.ds" }, token_size_in_bytes = "TwoBytes", shuffle = "DontShuffle" } }, +] [model.LLM.checkpoint.Hub] repo_id = "pefontana/Nano-Llama" revision = "cf48eac4944f6e954a3d9c9c30e8c865e64e7d03" -[model.LLM.data_location.Http] -token_size_in_bytes = "TwoBytes" -shuffle = "DontShuffle" - -[model.LLM.data_location.Http.location] -SingleUrl = "https://huggingface.co/pefontana/Nano-Llama/resolve/main/tiny-ci-dataset/000_tiny-test.ds" - [model.LLM.lr_schedule.Cosine] base_lr = 4.0e-4 warmup_steps = 250 diff --git a/config/test/state.toml b/config/test/state.toml index 94f35c970..352630043 100644 --- a/config/test/state.toml +++ b/config/test/state.toml @@ -22,7 +22,7 @@ data_type = "Pretraining" max_seq_len = 512 checkpoint = "Dummy" optimizer = "Dummy" -data_location = "Dummy" +data_locations = [{ Dummy = "Working" }] cold_start_warmup_steps = 0 diff --git a/nix/docker.nix b/nix/docker.nix index 6fa2e2774..34e153ea1 100644 --- a/nix/docker.nix +++ b/nix/docker.nix @@ -66,9 +66,11 @@ let '') (pkgs.runCommand "entrypoint" { } '' mkdir -p $out/bin + mkdir -p $out/architectures/decentralized/solana-authorizer/target/deploy cp ${../docker/test/client_test_entrypoint.sh} $out/bin/client_test_entrypoint.sh cp ${../docker/test/run_owner_entrypoint.sh} $out/bin/run_owner_entrypoint.sh cp ${../scripts/join-authorization-create.sh} $out/bin/join-authorization-create.sh + cp ${../architectures/decentralized/solana-authorizer/target/deploy/psyche_solana_authorizer-keypair.json} 
$out/architectures/decentralized/solana-authorizer/target/deploy/psyche_solana_authorizer-keypair.json
       chmod +x $out/bin/client_test_entrypoint.sh
       chmod +x $out/bin/run_owner_entrypoint.sh
       chmod +x $out/bin/join-authorization-create.sh
diff --git a/psyche-book/src/enduser/run-config.md b/psyche-book/src/enduser/run-config.md
index 299ab98ea..d225fc666 100644
--- a/psyche-book/src/enduser/run-config.md
+++ b/psyche-book/src/enduser/run-config.md
@@ -66,18 +66,14 @@ total_steps = 25000
 architecture = "HfLlama"
 data_type = "Pretraining"
 max_seq_len = 2048
+# You may define more than one data location; additional locations are used as backups.
+data_locations = [
+  { Http = { location = { Gcp = { bucket_name = "nous-pretraining-public-us", filter_directory = "fineweb-edu-tokenized-llama2" } }, token_size_in_bytes = "TwoBytes", shuffle = "DontShuffle" } }
+]
 
 [model.LLM.checkpoint.Hub]
 repo_id = "emozilla/llama2-20m-init"
 
-[model.LLM.data_location.Http]
-token_size_in_bytes = "TwoBytes"
-shuffle = "DontShuffle"
-
-[model.LLM.data_location.Http.location.Gcp]
-bucket_name = "nous-pretraining-public-us"
-filter_directory = "fineweb-edu-tokenized-llama2"
-
 [model.LLM.lr_schedule.Cosine]
 base_lr = 4.0e-4
 warmup_steps = 250
diff --git a/scripts/join-authorization-create.sh b/scripts/join-authorization-create.sh
index 7b2e521f2..fba7d8201 100644
--- a/scripts/join-authorization-create.sh
+++ b/scripts/join-authorization-create.sh
@@ -31,7 +31,7 @@ GRANTOR_PUBKEY=$(solana-keygen pubkey $GRANTOR_KEYPAIR_FILE)
 GRANTEE_PUBKEY="$1"
 shift
 
-PSYCHE_AUTHORIZER_ID="PsyAUmhpmiUouWsnJdNGFSX8vZ6rWjXjgDPHsgqPGyw"
+PSYCHE_AUTHORIZER_ID=$(solana-keygen pubkey ./architectures/decentralized/solana-authorizer/target/deploy/psyche_solana_authorizer-keypair.json)
 PSYCHE_AUTH_SCOPE="utf8:CoordinatorJoinRun"
 
 # Make sure all is good to go
diff --git a/shared/client/src/fetch_data.rs b/shared/client/src/fetch_data.rs
index 10f7d84bd..9ffdbd7e4 100644
--- a/shared/client/src/fetch_data.rs
+++ b/shared/client/src/fetch_data.rs
@@ -1,3 +1,4 @@
+use anyhow::{Result, bail};
 use psyche_coordinator::{Coordinator, get_batch_ids_for_node};
 use psyche_core::{BatchId, NodeIdentity};
 use psyche_data_provider::{DataProvider, TokenizedDataProvider};
@@ -14,29 +15,39 @@ use tokio::{
     task::JoinHandle,
     time::sleep,
 };
-use tracing::{Instrument, debug, error, trace, trace_span, warn};
+use tracing::{Instrument, debug, error, info, trace, trace_span, warn};
+
+use crate::IntegrationTestLogMarker;
 
 pub type BatchStep = u32;
 pub type BatchIdSet = HashSet<BatchId>;
 
-const MAX_RETRIES: u32 = 7;
-const BASE_DELAY_MS: u64 = 2000;
+const MAX_RETRIES: u32 = 4;
+const BASE_DELAY_MS: u64 = 500;
 
 pub struct DataFetcher<T: NodeIdentity> {
-    data_provider: Arc<Mutex<DataProvider<T>>>,
+    data_providers: Vec<Arc<Mutex<DataProvider<T>>>>,
     active_fetch_task: Option<(BatchStep, JoinHandle<()>)>,
     buffer_size: usize,
+    last_successful_provider_idx: Arc<Mutex<usize>>, // Store the index of the last successful provider
    _phantom: PhantomData<T>,
 }
 
 impl<T: NodeIdentity> DataFetcher<T> {
-    pub fn new(data_provider: DataProvider<T>, buffer_size: usize) -> Self {
-        Self {
-            data_provider: Arc::new(Mutex::new(data_provider)),
+    pub fn new(data_providers: Vec<DataProvider<T>>, buffer_size: usize) -> Result<Self> {
+        if data_providers.is_empty() {
+            bail!("Must provide at least one data provider");
+        }
+        Ok(Self {
+            data_providers: data_providers
+                .into_iter()
+                .map(|dp| Arc::new(Mutex::new(dp)))
+                .collect(),
             active_fetch_task: None,
             buffer_size,
+            last_successful_provider_idx: Arc::new(Mutex::new(0)), // Start with the first provider
             _phantom: Default::default(),
-        }
+        })
     }
 
     pub fn fetch_data(
@@ -69,38 +80,95 @@ impl<T: NodeIdentity> DataFetcher<T> {
             step,
             tokio::spawn({
                 trace!("New fetch task for step {step} has been spawned");
-                let data_provider = self.data_provider.clone(); // only one of these tasks will acquire the lock at once. once one dies, the lock is released for sure.
+                let data_providers = self.data_providers.clone();
+                let last_successful_provider_idx = self.last_successful_provider_idx.clone(); // Clone Arc for the task
                 async move {
+                    let num_providers = data_providers.len();
+                    if num_providers == 0 {
+                        error!("No data providers configured.");
+                        return;
+                    }
+
                     loop {
                         let batch_id = {
                             match assigned_batch_ids.pop() {
                                 Some(assigned) => assigned,
                                 None => {
-                                    // out of assigned data!
+                                    debug!("No more assigned batch IDs for step {step}.");
                                     return;
                                 }
                             }
                         };
-                        let mut retry_count = 0;
-                        let batch = loop {
-                            match data_provider.lock().await.get_samples(batch_id).await {
-                                Ok(batch) => break batch,
-                                Err(err) if retry_count < MAX_RETRIES => {
-                                    retry_count += 1;
-                                    let delay_ms = BASE_DELAY_MS * (retry_count as u64 - 1);
-                                    warn!(
-                                        "Data fetch error for batch_id={} (attempt {}/{}): \"{:#}\". Retrying in {}ms",
-                                        batch_id, retry_count, MAX_RETRIES, err, delay_ms
-                                    );
-                                    sleep(Duration::from_millis(delay_ms)).await;
-                                    continue;
-                                }
-                                Err(err) => {
-                                    error!("Data fetch failed for batch_id={} after {} attempts: {err:#}", batch_id, MAX_RETRIES);
-                                    return;
+                        let mut batch_option = None;
+                        let start_idx = *last_successful_provider_idx.lock().await; // Read the last successful index
+
+                        // Iterate through providers, starting from the last successful one and wrapping around
+                        for i in 0..num_providers {
+                            let provider_idx = (start_idx + i) % num_providers;
+                            let data_provider = &data_providers[provider_idx];
+
+                            info!(batch_id = %batch_id, provider_idx, "Attempting fetch with provider {}", provider_idx);
+                            let mut retry_count = 0;
+                            loop {
+                                match data_provider.lock().await.get_samples(batch_id).await {
+                                    Ok(batch) => {
+                                        info!(
+                                            integration_test_log_marker = %IntegrationTestLogMarker::DataProviderFetchSuccess,
+                                            batch_id = %batch_id,
+                                            provider_idx,
+                                            "Successfully fetched batch with provider",
+                                        );
+                                        batch_option = Some(batch);
+                                        // Update the last successful index
+                                        *last_successful_provider_idx.lock().await = provider_idx;
+                                        break; // Break retry loop, batch found
+                                    },
+                                    Err(err) if retry_count < MAX_RETRIES => {
+                                        retry_count += 1;
+                                        let delay_ms = BASE_DELAY_MS * 2u64.pow(retry_count - 1);
+                                        let delay_ms = Duration::from_millis(delay_ms / 2);
+
+                                        warn!(
+                                            batch_id = %batch_id,
+                                            provider_idx,
+                                            attempt = retry_count,
+                                            max_retries = MAX_RETRIES,
+                                            error = %err,
+                                            "Data fetch error for batch_id={} (attempt {}/{}) with provider {} \"{:#}\". Retrying in {}ms",
+                                            batch_id, retry_count, MAX_RETRIES, provider_idx, err, delay_ms.as_millis()
+                                        );
+                                        sleep(delay_ms).await;
+                                        continue; // Continue retry loop
+                                    }
+                                    Err(err) => {
+                                        error!(
+                                            integration_test_log_marker = %IntegrationTestLogMarker::DataProviderFetchError,
+                                            batch_id = %batch_id,
+                                            provider_idx,
+                                            error = %err,
+                                            "Data fetch failed permanently for provider {}",
+                                            provider_idx
+                                        );
+                                        break; // Break retry loop, provider failed permanently for this batch
+                                    }
                                 }
+                            } // End retry loop
+
+                            if batch_option.is_some() {
+                                break; // Break provider loop (for i in 0..num_providers), batch found
                             }
+                            // If batch_option is None here, it means the current provider failed permanently for this batch_id
+                            warn!(batch_id = %batch_id, provider_idx, "Provider {} failed permanently for this batch, trying next.", provider_idx);
+                        } // End provider loop
+
+                        // After trying all providers
+                        let batch = match batch_option {
+                            Some(b) => b,
+                            None => {
+                                error!(batch_id = %batch_id, "Failed to fetch batch {} after {} attempts on every data provider.", batch_id, MAX_RETRIES);
+                                continue; // Skip this batch and try the next assigned ID
                             }
                         };
@@ -119,12 +187,12 @@ impl<T: NodeIdentity> DataFetcher<T> {
                             .await
                             .is_err()
                         {
-                            debug!("Data loop finished");
-                            return;
+                            debug!("Data loop finished because receiver dropped (step {step}).");
+                            return; // Receiver is gone, stop the task
                         }
-                    }
+                    } // End main loop
                 }
-                .instrument(trace_span!("fetch_data"))
+                .instrument(trace_span!("fetch_data", step = step))
             }),
         ));
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 05afb13b5..90795611a 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -27,7 +27,7 @@ use tokio::{
     sync::{mpsc::UnboundedSender, oneshot},
     task::{JoinError, JoinHandle},
 };
-use tracing::{debug, error, info};
+use tracing::{debug, error, info, warn};
 
 use super::{
     CheckpointConfig, FinishedBroadcast, cooldown::CooldownStepMetadata, evals::ModelTaskRunner,
@@ -193,74 +193,144 @@ impl RunInitConfigAndIO
-                LLMTrainingDataLocation::Server(data_server) => DataProvider::Server(
-                    DataProviderTcpClient::connect(
-                        (&data_server).into(),
-                        init_config.network_identity,
-                        init_config.private_key,
-                    )
-                    .await?,
-                ),
-                LLMTrainingDataLocation::Local(_) => todo!(),
-                LLMTrainingDataLocation::Dummy => {
-                    DataProvider::Dummy(DummyDataProvider::new(TokenSize::TwoBytes, 2048, u64::MAX))
-                }
-                LLMTrainingDataLocation::Http(HttpLLMTrainingDataLocation {
-                    location,
-                    token_size_in_bytes,
-                    shuffle,
-                }) => {
-                    let file_urls = FileURLs::from_location(&location).await?;
-                    DataProvider::Http(HttpDataProvider::new(
-                        file_urls,
+            debug!("Setting up data providers from {:?}", data_locations);
+            let mut data_providers = Vec::new();
+
+            for data_location in data_locations.iter() {
+                let provider = match data_location {
+                    LLMTrainingDataLocation::Server(data_server) => {
+                        let client = match DataProviderTcpClient::connect(
+                            data_server.into(),
+                            init_config.network_identity.clone(),
+                            init_config.private_key.clone(),
+                        )
+                        .await
+                        {
+                            Ok(client) => client,
+                            Err(e) => {
+                                warn!("Failed to connect to data server at {}: {}", data_server, e);
+                                continue;
+                            }
+                        };
+                        Some(DataProvider::Server(client))
+                    }
+                    LLMTrainingDataLocation::Local(_) => todo!(),
+                    LLMTrainingDataLocation::Dummy(dummy_type) => Some(DataProvider::Dummy(
+                        DummyDataProvider::new(TokenSize::TwoBytes, 2048, u64::MAX, *dummy_type),
+                    )),
+                    LLMTrainingDataLocation::Http(HttpLLMTrainingDataLocation {
+                        location,
                         token_size_in_bytes,
-                        llm.max_seq_len,
                         shuffle,
-                    )?)
- } - LLMTrainingDataLocation::WeightedHttp(config_url) => DataProvider::WeightedHttp( - WeightedDataProvider::::from_config_url( - &String::from(&config_url), - llm.max_seq_len, - ) - .await?, - ), - LLMTrainingDataLocation::Preprocessed(url) => { - let url: String = (&url).into(); - let dir = if std::fs::exists(&url).unwrap_or_default() { - PathBuf::from(url) - } else { - download_dataset_repo_async( - url.clone(), - None, - None, - hub_read_token, - Some(hub_max_concurrent_downloads), - false, - ) - .await? - .first() - .ok_or(anyhow::anyhow!("No files downloaded for {url}"))? - .parent() - .unwrap() - .into() - }; - DataProvider::Preprocessed(PreprocessedDataProvider::new_from_directory( - dir, - llm.max_seq_len as usize, - Shuffle::DontShuffle, - Some(Split::Train), - None, - )?) + }) => { + if let Ok(file_urls) = FileURLs::from_location(location).await { + if let Ok(provider) = HttpDataProvider::new( + file_urls, + *token_size_in_bytes, + llm.max_seq_len, + *shuffle, + ) { + Some(DataProvider::Http(provider)) + } else { + warn!( + "Failed to create HTTP data provider for location: {:?}", + location + ); + None + } + } else { + warn!( + "Failed to create HTTP data provider for location: {:?}", + location + ); + None + } + } + LLMTrainingDataLocation::WeightedHttp(config_url) => { + if let Ok(provider) = + WeightedDataProvider::::from_config_url( + &String::from(config_url), + llm.max_seq_len, + ) + .await + { + Some(DataProvider::WeightedHttp(provider)) + } else { + warn!( + "Failed to create Weighted HTTP data provider for config URL: {}", + config_url + ); + None + } + } + + LLMTrainingDataLocation::Preprocessed(url) => { + let url: String = (url).into(); + let dir: anyhow::Result = + if std::fs::exists(&url).unwrap_or_default() { + Ok(PathBuf::from(url.clone())) + } else { + let dataset_download = download_dataset_repo_async( + url.clone(), + None, + None, + hub_read_token.clone(), + Some(hub_max_concurrent_downloads), + false, + ) + .await + .map_err(|e| { + anyhow::anyhow!("Failed to download repo for {url}: {e:?}") + }); + let file_in_repo = dataset_download.and_then(|r| { + r.into_iter() + .nth(0) + .ok_or(anyhow::anyhow!("No files downloaded for {url}")) + }); + file_in_repo.and_then(|f| { + f.parent() + .map(|p| p.to_owned()) + .ok_or(anyhow::anyhow!("Path has no parent")) + }) + }; + let provider = dir.and_then(|dir| { + PreprocessedDataProvider::new_from_directory( + dir, + llm.max_seq_len as usize, + Shuffle::DontShuffle, + Some(Split::Train), + None, + ) + }); + match provider { + Ok(provider) => Some(DataProvider::Preprocessed(provider)), + Err(err) => { + warn!( + "Failed to create Preprocessed data provider for URL: {}\n{:?}", + url, err + ); + None + } + } + } + }; + if let Some(provider) = provider { + data_providers.push(provider); } - }; - Ok(data_provider) + } + if data_providers.is_empty() { + Err(InitRunError::DataProviderConnect(anyhow::anyhow!( + "No valid data providers could be initialized." 
+ ))) + } else { + info!("Initialized {} data providers", data_providers.len()); + Ok(data_providers) + } }; let model_future: JoinHandle> = match &llm.architecture @@ -680,9 +750,10 @@ impl RunInitConfigAndIO::new(data_provider, init_config.data_parallelism * 2); + DataFetcher::::new(data_providers, init_config.data_parallelism * 2); let trainers: Vec = match models { RawLoadedModelType::ParallelNativeModels(models) => { @@ -808,7 +879,7 @@ impl RunInitConfigAndIO "solana_subscription", Self::WitnessElected => "witness_elected", Self::Error => "error", + Self::DataProviderFetchSuccess => "data_provider_fetch_success", + Self::DataProviderFetchError => "data_provider_fetch_error", } ) } @@ -44,6 +48,8 @@ impl FromStr for IntegrationTestLogMarker { "solana_subscription" => Self::SolanaSubscription, "witness_elected" => Self::WitnessElected, "error" => Self::Error, + "data_provider_fetch_success" => Self::DataProviderFetchSuccess, + "data_provider_fetch_error" => Self::DataProviderFetchError, _ => return Err(()), }) } diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index a9ac66c85..19ea58b04 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, HubRepo, Model}, + model::{Checkpoint, HubRepo, LLMDataLocations, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -304,6 +304,8 @@ pub struct Coordinator { pub model: Model, + pub data_locations: LLMDataLocations, + pub config: CoordinatorConfig, #[serde(default)] diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 083f3ce31..182b4ad64 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -86,10 +86,8 @@ pub enum LLMTrainingDataType { )] #[repr(C)] #[allow(clippy::large_enum_variant)] -#[derive(Default)] pub enum LLMTrainingDataLocation { - #[default] - Dummy, + Dummy(DummyType), Server(FixedString<{ SOLANA_MAX_STRING_LEN }>), Local(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), Http(HttpLLMTrainingDataLocation), @@ -98,6 +96,32 @@ pub enum LLMTrainingDataLocation { Preprocessed(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), } +impl Default for LLMTrainingDataLocation { + fn default() -> Self { + Self::Dummy(DummyType::Working) + } +} + +#[derive( + AnchorSerialize, + AnchorDeserialize, + InitSpace, + Serialize, + Deserialize, + Clone, + Debug, + Zeroable, + Copy, + TS, + PartialEq, + Eq, +)] +#[repr(C)] +pub enum DummyType { + Working, + Failing, +} + #[derive( AnchorSerialize, AnchorDeserialize, @@ -165,7 +189,6 @@ impl LLMTrainingDataLocationAndWeight { TS, )] #[repr(C)] -#[allow(clippy::large_enum_variant)] pub enum HttpTrainingDataLocation { SingleUrl(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>), NumberedFiles { @@ -182,6 +205,8 @@ pub enum HttpTrainingDataLocation { }, } +pub const MAX_DATA_LOCATIONS: usize = 3; + #[derive( AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS, )] @@ -192,17 +217,44 @@ pub struct LLM { pub architecture: LLMArchitecture, pub checkpoint: Checkpoint, pub data_type: LLMTrainingDataType, - pub data_location: LLMTrainingDataLocation, pub lr_schedule: LearningRateSchedule, pub optimizer: OptimizerDefinition, } +#[derive( + AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS, +)] +#[repr(C)] +pub struct LLMDataLocations { + pub 
diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 083f3ce31..182b4ad64 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -86,10 +86,8 @@ pub enum LLMTrainingDataType {
 )]
 #[repr(C)]
 #[allow(clippy::large_enum_variant)]
-#[derive(Default)]
 pub enum LLMTrainingDataLocation {
-    #[default]
-    Dummy,
+    Dummy(DummyType),
     Server(FixedString<{ SOLANA_MAX_STRING_LEN }>),
     Local(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>),
     Http(HttpLLMTrainingDataLocation),
@@ -98,6 +96,32 @@ pub enum LLMTrainingDataLocation {
     Preprocessed(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>),
 }
 
+impl Default for LLMTrainingDataLocation {
+    fn default() -> Self {
+        Self::Dummy(DummyType::Working)
+    }
+}
+
+#[derive(
+    AnchorSerialize,
+    AnchorDeserialize,
+    InitSpace,
+    Serialize,
+    Deserialize,
+    Clone,
+    Debug,
+    Zeroable,
+    Copy,
+    TS,
+    PartialEq,
+    Eq,
+)]
+#[repr(C)]
+pub enum DummyType {
+    Working,
+    Failing,
+}
+
 #[derive(
     AnchorSerialize,
     AnchorDeserialize,
@@ -165,7 +189,6 @@ impl LLMTrainingDataLocationAndWeight {
     TS,
 )]
 #[repr(C)]
-#[allow(clippy::large_enum_variant)]
 pub enum HttpTrainingDataLocation {
     SingleUrl(FixedString<{ SOLANA_MAX_URL_STRING_LEN }>),
     NumberedFiles {
@@ -182,6 +205,8 @@ pub enum HttpTrainingDataLocation {
     },
 }
 
+pub const MAX_DATA_LOCATIONS: usize = 3;
+
 #[derive(
     AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS,
 )]
 #[repr(C)]
@@ -192,17 +217,44 @@ pub struct LLM {
     pub architecture: LLMArchitecture,
     pub checkpoint: Checkpoint,
     pub data_type: LLMTrainingDataType,
-    pub data_location: LLMTrainingDataLocation,
     pub lr_schedule: LearningRateSchedule,
     pub optimizer: OptimizerDefinition,
 }
 
+#[derive(
+    AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS,
+)]
+#[repr(C)]
+pub struct LLMDataLocations {
+    pub data_locations: FixedVec<LLMTrainingDataLocation, MAX_DATA_LOCATIONS>,
+}
+
+impl LLMDataLocations {
+    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &LLMTrainingDataLocation> {
+        self.data_locations.iter()
+    }
+
+    pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut LLMTrainingDataLocation> {
+        self.data_locations.iter_mut()
+    }
+}
+
+impl LLMDataLocations {
+    pub fn dummy() -> Self {
+        let mut data_locations: FixedVec<LLMTrainingDataLocation, MAX_DATA_LOCATIONS> =
+            FixedVec::new();
+        data_locations
+            .push(LLMTrainingDataLocation::Dummy(DummyType::Working))
+            .unwrap();
+        Self { data_locations }
+    }
+}
+
 impl LLM {
     pub fn dummy() -> Self {
         Self {
             architecture: LLMArchitecture::HfLlama,
             checkpoint: Checkpoint::Dummy(HubRepo::dummy()),
-            data_location: LLMTrainingDataLocation::default(),
             data_type: LLMTrainingDataType::Pretraining,
             lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()),
             max_seq_len: 2048,
@@ -280,28 +332,6 @@ impl Model {
             return false;
         }
 
-        let bad_data_location = match llm.data_location {
-            LLMTrainingDataLocation::Dummy => false,
-            LLMTrainingDataLocation::Server(url) => url.is_empty(),
-            LLMTrainingDataLocation::Local(_) => false,
-            LLMTrainingDataLocation::Http(HttpLLMTrainingDataLocation {
-                location, ..
-            }) => match location {
-                HttpTrainingDataLocation::SingleUrl(url) => url.is_empty(),
-                HttpTrainingDataLocation::NumberedFiles {
-                    url_template,
-                    num_files,
-                    ..
-                } => url_template.is_empty() || num_files == 0,
-                HttpTrainingDataLocation::Gcp { bucket_name, .. } => bucket_name.is_empty(),
-            },
-            LLMTrainingDataLocation::WeightedHttp(url) => url.is_empty(),
-            LLMTrainingDataLocation::Preprocessed(url) => url.is_empty(),
-        };
-        if bad_data_location {
-            msg!("model check failed: bad LLM training data location.");
-            return false;
-        }
         let bad_checkpoint = match llm.checkpoint {
             Checkpoint::Dummy(_hub_repo) => false,
             Checkpoint::Ephemeral => true,
diff --git a/shared/data-provider/src/dummy.rs b/shared/data-provider/src/dummy.rs
index 60042436f..349c5ad03 100644
--- a/shared/data-provider/src/dummy.rs
+++ b/shared/data-provider/src/dummy.rs
@@ -1,11 +1,13 @@
 use crate::{LengthKnownDataProvider, TokenizedData, traits::TokenizedDataProvider};
 use anyhow::{Result, bail};
+use psyche_coordinator::model::DummyType;
 use psyche_core::{BatchId, TokenSize};
 
 pub struct DummyDataProvider {
     seq_len: usize,
     token_size_in_bytes: TokenSize,
     num_sequences: u64,
+    dummy_type: DummyType,
 }
 
 impl DummyDataProvider {
@@ -13,11 +15,13 @@ impl DummyDataProvider {
         token_size_in_bytes: TokenSize,
         num_tokens_per_sequence: usize, // num tokens per sequence
         num_sequences: u64,
+        dummy_type: DummyType,
     ) -> Self {
         Self {
             seq_len: num_tokens_per_sequence,
             token_size_in_bytes,
             num_sequences,
+            dummy_type,
         }
     }
 
@@ -45,6 +49,9 @@ impl DummyDataProvider {
 impl TokenizedDataProvider for DummyDataProvider {
     async fn get_samples(&mut self, data_ids: BatchId) -> Result<Vec<TokenizedData>> {
+        if self.dummy_type == DummyType::Failing {
+            return Err(anyhow::anyhow!("DummyDataProvider dummy error"));
+        }
         for id in data_ids.iter() {
             if id > self.num_sequences {
                 bail!("id {id} > self.num_sequences {}", self.num_sequences)
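The `DummyType::Failing` variant gives tests a provider whose `get_samples` errors on every call, which is the hook for exercising the `data_provider_fetch_error` path end to end. A hypothetical test in the spirit of the weighted-provider tests below; the `BatchId`/`ClosedInterval` construction is an assumption inferred from the imports in weighted.rs, and plain `#[tokio::test]` stands in for the repo's `#[test(tokio::test)]`:

    use psyche_coordinator::model::DummyType;
    use psyche_core::{BatchId, ClosedInterval, TokenSize};
    use psyche_data_provider::{DummyDataProvider, TokenizedDataProvider};

    #[tokio::test]
    async fn failing_dummy_provider_always_errors() {
        // 10 tokens per sequence, 50 sequences, but configured to fail
        let mut provider =
            DummyDataProvider::new(TokenSize::TwoBytes, 10, 50, DummyType::Failing);
        // assumed construction: BatchId wrapping an inclusive id range
        let batch = BatchId(ClosedInterval { start: 0, end: 3 });
        assert!(provider.get_samples(batch).await.is_err());
    }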
diff --git a/shared/data-provider/tests/weighted.rs b/shared/data-provider/tests/weighted.rs
index 9730d6370..183000ee2 100644
--- a/shared/data-provider/tests/weighted.rs
+++ b/shared/data-provider/tests/weighted.rs
@@ -1,4 +1,5 @@
 use anyhow::Result;
+use psyche_coordinator::model::DummyType;
 use psyche_core::{BatchId, ClosedInterval, Shuffle, TokenSize};
 use psyche_data_provider::{
     DummyDataProvider, LengthKnownDataProvider, TokenizedData, TokenizedDataProvider,
@@ -185,8 +186,8 @@ async fn test_weighted_data_provider_consistency() -> Result<()> {
 
 #[test(tokio::test)]
 async fn test_weighted_data_provider_with_dummy_provider() -> Result<()> {
-    let dummy1 = DummyDataProvider::new(TokenSize::TwoBytes, 10, 50); // 10 tokens per sequence
-    let dummy2 = DummyDataProvider::new(TokenSize::TwoBytes, 10, 50);
+    let dummy1 = DummyDataProvider::new(TokenSize::TwoBytes, 10, 50, DummyType::Working); // 10 tokens per sequence
+    let dummy2 = DummyDataProvider::new(TokenSize::TwoBytes, 10, 50, DummyType::Working);
 
     let mut weighted_provider = WeightedDataProvider::new(
         vec![(dummy1, 0.5), (dummy2, 0.5)],
diff --git a/website/backend/src/coordinatorChainLoop.ts b/website/backend/src/coordinatorChainLoop.ts
index 21aa04dc2..2076d8651 100644
--- a/website/backend/src/coordinatorChainLoop.ts
+++ b/website/backend/src/coordinatorChainLoop.ts
@@ -353,6 +353,28 @@ export async function startWatchCoordinatorChainLoop(
           })
           break
         }
+        case 'clear_data_locations': {
+          const runPdaAddr = i.accounts[1].toString()
+          const coordinatorAddr = i.accounts[2].toString()
+          runUpdates.getAndTouchCurrentRun({
+            runPdaAddr,
+            coordinatorAddr,
+            decoded,
+            tx,
+          })
+          break
+        }
+        case 'update_data_locations': {
+          const runPdaAddr = i.accounts[1].toString()
+          const coordinatorAddr = i.accounts[2].toString()
+          runUpdates.getAndTouchCurrentRun({
+            runPdaAddr,
+            coordinatorAddr,
+            decoded,
+            tx,
+          })
+          break
+        }
         default: {
           const _missed_tx: never = decoded
           throw new Error(