Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 12 additions & 9 deletions architectures/centralized/client/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use anyhow::{Error, Result};
use bytemuck::Zeroable;
use hf_hub::Repo;
use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
use psyche_client::HubUploadInfo;
use psyche_client::UploadInfo;
use psyche_client::{
Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
};
Expand Down Expand Up @@ -29,7 +31,7 @@ pub type TabsData = <Tabs as CustomWidget>::Data;
pub enum ToSend {
Witness(Box<OpportunisticData>),
HealthCheck(HealthChecks<ClientId>),
Checkpoint(model::HubRepo),
Checkpoint(model::Checkpoint),
}

struct Backend {
Expand Down Expand Up @@ -67,7 +69,7 @@ impl WatcherBackend<ClientId> for Backend {
Ok(())
}

async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> {
async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
self.tx.send(ToSend::Checkpoint(checkpoint))?;
Ok(())
}
Expand Down Expand Up @@ -173,18 +175,19 @@ impl App {
) -> Result<()> {
// sanity checks
if let Some(checkpoint_config) = &state_options.checkpoint_config {
if let Some(hub_upload) = &checkpoint_config.hub_upload {
if let Some(UploadInfo::Hub(HubUploadInfo {
hub_repo,
hub_token,
})) = &checkpoint_config.upload_info
{
let api = hf_hub::api::tokio::ApiBuilder::new()
.with_token(Some(hub_upload.hub_token.clone()))
.with_token(Some(hub_token.clone()))
.build()?;
let repo_api = api.repo(Repo::new(
hub_upload.hub_repo.clone(),
hf_hub::RepoType::Model,
));
let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model));
if !repo_api.is_writable().await {
anyhow::bail!(
"Checkpoint upload repo {} is not writable with the passed API key.",
hub_upload.hub_repo
hub_repo
)
}
}
Expand Down
2 changes: 1 addition & 1 deletion architectures/centralized/server/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl psyche_watcher::Backend<ClientId> for ChannelCoordinatorBackend {
bail!("Server does not send health checks");
}

async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> {
async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> {
bail!("Server does not send checkpoints");
}
}
Expand Down
2 changes: 1 addition & 1 deletion architectures/centralized/shared/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub enum ClientToServerMessage {
Join { run_id: String },
Witness(Box<OpportunisticData>),
HealthCheck(HealthChecks<ClientId>),
Checkpoint(model::HubRepo),
Checkpoint(model::Checkpoint),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
7 changes: 4 additions & 3 deletions architectures/decentralized/solana-client/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use anchor_client::{
use anyhow::{Context, Result, anyhow};
use futures_util::StreamExt;
use psyche_client::IntegrationTestLogMarker;
use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
use psyche_coordinator::model::{self, Checkpoint};
use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks};
use psyche_watcher::{Backend as WatcherBackend, OpportunisticData};
use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding};
use solana_transaction_status_client_types::UiTransactionEncoding;
Expand Down Expand Up @@ -333,7 +334,7 @@ impl SolanaBackend {
&self,
coordinator_instance: Pubkey,
coordinator_account: Pubkey,
repo: HubRepo,
repo: Checkpoint,
) {
let user = self.get_payer();
let instruction = instructions::coordinator_checkpoint(
Expand Down Expand Up @@ -603,7 +604,7 @@ impl WatcherBackend<psyche_solana_coordinator::ClientId> for SolanaBackendRunner
Ok(())
}

async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> {
async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
self.backend
.send_checkpoint(self.instance, self.account, checkpoint);
Ok(())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use clap::Args;
use psyche_coordinator::model::Checkpoint;
use psyche_coordinator::model::HubRepo;
use psyche_core::FixedString;

Expand Down Expand Up @@ -45,7 +46,7 @@ pub async fn command_checkpoint_execute(
&coordinator_instance,
&coordinator_account,
&user,
repo,
Checkpoint::Hub(repo),
);
let signature = backend
.send_and_retry("Checkpoint", &[instruction], &[])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ pub fn coordinator_checkpoint(
coordinator_instance: &Pubkey,
coordinator_account: &Pubkey,
user: &Pubkey,
repo: psyche_coordinator::model::HubRepo,
repo: psyche_coordinator::model::Checkpoint,
) -> Instruction {
anchor_instruction(
psyche_solana_coordinator::ID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use psyche_coordinator::RunState;
use psyche_coordinator::SOLANA_MAX_STRING_LEN;
use psyche_coordinator::TickResult;
use psyche_coordinator::Witness;
use psyche_coordinator::model::HubRepo;
use psyche_coordinator::model::Checkpoint;
use psyche_coordinator::model::Model;
use psyche_core::FixedString;
use psyche_core::SmallBoolean;
Expand Down Expand Up @@ -389,7 +389,11 @@ impl CoordinatorInstanceState {
self.tick()
}

pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> {
pub fn checkpoint(
&mut self,
payer: &Pubkey,
repo: Checkpoint,
) -> Result<()> {
// O(n) on clients, reconsider
let id = self.clients_state.find_signer(payer)?;
let index = self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use psyche_coordinator::Witness;
use psyche_coordinator::WitnessBloom;
use psyche_coordinator::WitnessMetadata;
use psyche_coordinator::WitnessProof;
use psyche_coordinator::model::{HubRepo, Model};
use psyche_coordinator::model::Model;
use psyche_core::MerkleRoot;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -313,7 +313,7 @@ pub mod psyche_solana_coordinator {

pub fn checkpoint(
ctx: Context<PermissionlessCoordinatorAccounts>,
repo: HubRepo,
repo: psyche_coordinator::model::Checkpoint,
) -> Result<()> {
let mut account = ctx.accounts.coordinator_account.load_mut()?;
account.increment_nonce();
Expand Down
1 change: 1 addition & 0 deletions nix/lib.nix
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ let
python312
pkg-config
perl
cargo-nextest
];

buildInputs =
Expand Down
107 changes: 73 additions & 34 deletions shared/client/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};
use crate::{CheckpointConfig, WandBInfo};

use crate::UploadInfo;
use anyhow::{Result, anyhow, bail};
use clap::Args;
use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
use psyche_eval::tasktype_from_name;
use psyche_modeling::Devices;
use psyche_network::{DiscoveryMode, RelayKind, SecretKey};
Expand Down Expand Up @@ -146,6 +148,14 @@ pub struct TrainArgs {
#[clap(long, env)]
pub hub_repo: Option<String>,

/// Name of the GCS bucket containing model data and configuration.
#[clap(long, env)]
pub gcs_bucket: Option<String>,

/// Prefix within the GCS bucket for model data and configuration.
#[clap(long, env)]
pub gcs_prefix: Option<String>,

#[clap(long, env, default_value_t = 3)]
pub hub_max_concurrent_downloads: usize,

Expand Down Expand Up @@ -224,43 +234,72 @@ impl TrainArgs {

pub fn checkpoint_config(&self) -> Result<Option<CheckpointConfig>> {
let hub_read_token = std::env::var("HF_TOKEN").ok();
let checkpoint_upload_info = match (
&hub_read_token,
self.hub_repo.clone(),
self.checkpoint_dir.clone(),
self.delete_old_steps,
self.keep_steps,
) {
(Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => {
if keep_steps == 0 {
bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")

if self.hub_repo.is_some() && self.gcs_bucket.is_some() {
bail!("Use either GCS or HF hub for checkpoint uploads, not both.");
}

let checkpoint_dir = match &self.checkpoint_dir {
Some(dir) => dir,
None => {
if self.hub_repo.is_some() || self.gcs_bucket.is_some() {
bail!(
"--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!"
);
}
Some(CheckpointConfig {
checkpoint_dir: dir,
hub_upload: Some(HubUploadInfo {
hub_repo: repo,
hub_token: token.to_string(),
}),
delete_old_steps,
keep_steps,
})
}
(None, Some(_), Some(_), _, _) => {
bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
}
(_, Some(_), None, _, _) => {
bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
return Ok(None);
}
(_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
checkpoint_dir: dir,
hub_upload: None,
delete_old_steps,
keep_steps,
}),
(_, None, _, _, _) => None,
};

Ok(checkpoint_upload_info)
let upload_info = self.build_upload_info(&hub_read_token)?;

if upload_info.is_some() && self.keep_steps == 0 {
bail!(
"keep_steps must be >= 1 for checkpoint uploads (got {})",
self.keep_steps
);
}

Ok(Some(CheckpointConfig {
checkpoint_dir: checkpoint_dir.clone(),
upload_info,
delete_old_steps: self.delete_old_steps,
keep_steps: self.keep_steps,
}))
}

/// Picks the checkpoint upload destination from the CLI arguments.
///
/// `hub_repo` is consulted first: when it is set the Hub variant is built
/// (which needs `hub_token`). Otherwise, when `gcs_bucket` is set the GCS
/// variant is built. With neither set, no upload is configured and
/// `Ok(None)` is returned.
///
/// # Errors
/// Propagates any error produced by the per-destination builder.
fn build_upload_info(&self, hub_token: &Option<String>) -> Result<Option<UploadInfo>> {
    match (self.hub_repo.as_deref(), self.gcs_bucket.as_deref()) {
        // A Hub repo wins regardless of whether a bucket is also present.
        (Some(repo), _) => self.build_hub_upload_info(repo, hub_token),
        (None, Some(bucket)) => self.build_gcs_upload_info(bucket),
        (None, None) => Ok(None),
    }
}

/// Builds the Hub upload descriptor for `repo`.
///
/// Both `repo` and the token are copied into the returned `HubUploadInfo`,
/// wrapped as `UploadInfo::Hub`.
///
/// # Errors
/// Returns an error when `token` is `None`, since a Hub upload cannot be
/// configured without a token.
fn build_hub_upload_info(
    &self,
    repo: &str,
    token: &Option<String>,
) -> Result<Option<UploadInfo>> {
    // Guard clause: bail out early when no token was supplied.
    let Some(hub_token) = token.as_deref() else {
        return Err(anyhow::anyhow!(
            "hub-repo and checkpoint-dir set, but no HF_TOKEN env variable."
        ));
    };

    let hub_info = HubUploadInfo {
        hub_repo: repo.to_owned(),
        hub_token: hub_token.to_owned(),
    };
    Ok(Some(UploadInfo::Hub(hub_info)))
}

/// Builds the GCS upload descriptor for `bucket`.
///
/// The bucket name is copied, and the optional `gcs_prefix` argument is
/// cloned as-is (it stays `None` when no prefix was given). This builder
/// itself never fails; it returns `Result` to match its Hub counterpart.
fn build_gcs_upload_info(&self, bucket: &str) -> Result<Option<UploadInfo>> {
    let gcs_info = GcsUploadInfo {
        gcs_bucket: bucket.to_owned(),
        gcs_prefix: self.gcs_prefix.clone(),
    };
    Ok(Some(UploadInfo::Gcs(gcs_info)))
}

pub fn eval_tasks(&self) -> Result<Vec<psyche_eval::Task>> {
Expand Down
3 changes: 2 additions & 1 deletion shared/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity
pub use client::Client;
pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult};
pub use state::{
CheckpointConfig, HubUploadInfo, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO,
CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig,
RunInitConfigAndIO, UploadInfo,
};
pub use testing::IntegrationTestLogMarker;
pub use tui::{ClientTUI, ClientTUIState};
Expand Down
Loading
Loading