diff --git a/psyche-book/src/development/models.md b/psyche-book/src/development/models.md
index 08c033aa7..9166412cb 100644
--- a/psyche-book/src/development/models.md
+++ b/psyche-book/src/development/models.md
@@ -9,15 +9,15 @@ The `train` example, documented below, is useful to test how your model trains u

 ## Running

 ```bash
-cargo run --example train -- ---help
+cargo run --example train -- --help
 ```

-You'll need a pre-tokenized dataset downloaded to your disk for training.
-
-> A PR is welcome to add an option to the trainer to use the HTTP data provider! You can refer to the http example in the data-provider crate for a sample implementation.
+You'll need a pre-tokenized dataset for training. The `train` example supports multiple data sources: local files, HTTP URLs, GCP buckets, and weighted configurations.

 For a Llama 2 model, a pre-tokenized dataset to test with is available at [https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/](https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/tree/main).

-Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can download just one for smaller tests.
+Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can use just one for smaller tests.
+
+### Local data

 If you've downloaded part or all of the above dataset into a folder `data/fineweb-10bt` inside the Psyche repo, you can start a simple training run on a 20m parameter Llama 2 model:
@@ -29,6 +29,56 @@ cargo run --example train -- \
   --micro-batch 1
 ```

+### HTTP
+
+You can stream data directly from HTTP URLs without downloading the dataset first. There are several ways to specify HTTP data sources:
+
+#### URL template
+
+Use a template with a `{}` placeholder that is replaced with zero-padded indices:
+
+```bash
+cargo run --example train -- \
+  --model emozilla/llama2-20m-init \
+  --total-batch 2 \
+  --micro-batch 1 \
+  http-template \
+  --template "https://example.com/data/{}.ds" \
+  --start 0 \
+  --end 10 \
+  --left-pad-zeros 5
+```
+
+This would load ten files, from `https://example.com/data/00000.ds` through `https://example.com/data/00009.ds` (`--end` is exclusive).
+
+#### Explicit URLs
+
+Provide a list of URLs directly:
+
+```bash
+cargo run --example train -- \
+  --model emozilla/llama2-20m-init \
+  --total-batch 2 \
+  --micro-batch 1 \
+  urls \
+  https://example.com/data/file1.ds \
+  https://example.com/data/file2.ds
+```
+
+#### GCP bucket
+
+Load all `.ds` files from a Google Cloud Storage bucket:
+
+```bash
+cargo run --example train -- \
+  --model emozilla/llama2-20m-init \
+  --total-batch 2 \
+  --micro-batch 1 \
+  gcp \
+  --bucket-name my-bucket \
+  --directory data/tokenized
+```
+
 ## Adding a new model type

 The `train` example currently assumes your model is a Llama or Deepseek v2/v3 model, and instantiates it via `(LlamaForCausalLM|DeepseekForCausalLM)::from_pretrained`.
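The Rust changes below also add a `weighted-config` data source that the documentation above does not yet cover. A minimal sketch of the corresponding invocation, assuming only the CLI surface defined in `train.rs` (subcommand `weighted-config`, flag `--config`); the file name `mixture.json` is illustrative:

```bash
# Hypothetical example: "mixture.json" is a placeholder for a
# WeightedHttpProvidersConfig JSON file describing weighted HTTP sources.
cargo run --example train -- \
  --model emozilla/llama2-20m-init \
  --total-batch 2 \
  --micro-batch 1 \
  weighted-config \
  --config mixture.json
```

Per the code below, `--config` also accepts an `http://` or `https://` URL, in which case the configuration is fetched rather than read from disk.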
diff --git a/shared/modeling/examples/train.rs b/shared/modeling/examples/train.rs
index b67a17f46..9aa752197 100644
--- a/shared/modeling/examples/train.rs
+++ b/shared/modeling/examples/train.rs
@@ -5,7 +5,9 @@ use psyche_core::{
 };
 use psyche_data_provider::{
     DataProvider, LengthKnownDataProvider, LocalDataProvider, PreprocessedDataProvider, Split,
-    TokenizedDataProvider, download_model_repo_sync,
+    TokenizedDataProvider, WeightedDataProvider, WeightedHttpProvidersConfig,
+    download_model_repo_sync,
+    http::{FileURLs, HttpDataProvider},
 };
 use psyche_modeling::{
     AttentionImplementation, Batch, BatchData, BatchDataCPU, CausalLM, CommunicatorId,
@@ -91,14 +93,63 @@ enum Commands {
     },
 }

+#[derive(Subcommand, Debug, Clone)]
+enum DataSource {
+    /// Local directory (default behavior)
+    Local {
+        #[arg(long, default_value = "data")]
+        path: String,
+    },
+    /// HTTP with URL template
+    HttpTemplate {
+        /// URL template with a {} placeholder (e.g., "http://example.com/{}.ds")
+        #[arg(long)]
+        template: String,
+        /// Start index
+        #[arg(long, default_value = "0")]
+        start: u32,
+        /// End index (exclusive)
+        #[arg(long)]
+        end: u32,
+        /// Zero-pad each index to this width
+        #[arg(long, default_value = "0")]
+        left_pad_zeros: u8,
+    },
+    /// HTTP with explicit URLs
+    Urls {
+        /// List of data URLs
+        urls: Vec<String>,
+    },
+    /// HTTP from GCP bucket
+    Gcp {
+        /// The name of the GCP bucket
+        #[arg(long)]
+        bucket_name: String,
+        /// An optional directory to filter by
+        #[arg(long)]
+        directory: Option<String>,
+    },
+    /// Weighted HTTP config (JSON file or URL)
+    WeightedConfig {
+        /// Path or URL to a WeightedHttpProvidersConfig JSON file
+        #[arg(long)]
+        config: String,
+    },
+}
+
 #[derive(Args, Debug, Clone)]
 struct RunArgs {
     #[arg(long, default_value = "emozilla/llama2-215m-init")]
     model: String,

+    /// Path to local data directory (backwards compatibility; prefer the `local` subcommand)
     #[arg(long, default_value = "data")]
     data_path: String,

+    /// Data source subcommand (optional; defaults to local data_path)
+    #[command(subcommand)]
+    data_source: Option<DataSource>,
+
     #[arg(long, default_value_t = 2048)]
     sequence_length: usize,

@@ -233,38 +284,95 @@ async fn main() -> Result<()> {
         None => Shuffle::DontShuffle,
     };

-    let mut dataset: DataProvider = match LocalDataProvider::new_from_directory(
-        &args.data_path,
-        args.token_size.try_into()?,
-        args.sequence_length,
-        shuffle,
-    )
-    .with_context(|| "Failed to load data with local data provider.")
-    {
-        Ok(dataset) => {
-            info!(
-                "Loaded local dataset with {} samples",
-                dataset.num_sequences()
-            );
-            DataProvider::Local(dataset)
-        }
-        Err(err) => {
-            println!(
-                "Failed to load with local data provider. {err:?} Trying preprocessed data provider instead"
-            );
-            let dataset = PreprocessedDataProvider::new_from_directory(
-                &args.data_path,
+    let token_size = args.token_size.try_into()?;
+
+    let mut dataset: DataProvider = match &args.data_source {
+        None | Some(DataSource::Local { .. }) => {
+            // Use the path from the `local` subcommand, or fall back to args.data_path
+            let data_path = match &args.data_source {
+                Some(DataSource::Local { path }) => path,
+                _ => &args.data_path,
+            };
+            match LocalDataProvider::new_from_directory(
+                data_path,
+                token_size,
                 args.sequence_length,
                 shuffle,
-                Some(Split::Train),
-                None,
             )
-            .with_context(|| "Failed to load preprocessed data")?;
-            info!(
-                "Loaded preprocessed dataset with {} samples",
-                dataset.num_sequences()
-            );
-            DataProvider::Preprocessed(dataset)
+            .with_context(|| "Failed to load data with local data provider.")
+            {
+                Ok(dataset) => {
+                    info!(
+                        "Loaded local dataset with {} samples",
+                        dataset.num_sequences()
+                    );
+                    DataProvider::Local(dataset)
+                }
+                Err(err) => {
+                    println!(
+                        "Failed to load with local data provider. {err:?} Trying preprocessed data provider instead"
+                    );
+                    let dataset = PreprocessedDataProvider::new_from_directory(
+                        data_path,
+                        args.sequence_length,
+                        shuffle,
+                        Some(Split::Train),
+                        None,
+                    )
+                    .with_context(|| "Failed to load preprocessed data")?;
+                    info!(
+                        "Loaded preprocessed dataset with {} samples",
+                        dataset.num_sequences()
+                    );
+                    DataProvider::Preprocessed(dataset)
+                }
+            }
+        }
+        Some(DataSource::HttpTemplate {
+            template,
+            start,
+            end,
+            left_pad_zeros,
+        }) => {
+            let urls =
+                FileURLs::from_template(template, *start, *left_pad_zeros, end - start).await?;
+            let provider =
+                HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
+            info!("Loaded HTTP template dataset");
+            DataProvider::Http(provider)
+        }
+        Some(DataSource::Urls { urls }) => {
+            if urls.is_empty() {
+                anyhow::bail!("At least one URL must be provided");
+            }
+            let urls = FileURLs::from_list(urls).await?;
+            let provider =
+                HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
+            info!("Loaded HTTP URLs dataset");
+            DataProvider::Http(provider)
+        }
+        Some(DataSource::Gcp {
+            bucket_name,
+            directory,
+        }) => {
+            let urls = FileURLs::from_gcp_bucket(bucket_name, directory.clone()).await?;
+            let provider =
+                HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
+            info!("Loaded GCP bucket dataset");
+            DataProvider::Http(provider)
+        }
+        Some(DataSource::WeightedConfig { config }) => {
+            let provider = if config.starts_with("http://") || config.starts_with("https://") {
+                WeightedDataProvider::from_config_url(config, args.sequence_length as u32).await?
+            } else {
+                let content = std::fs::read_to_string(config)
+                    .with_context(|| format!("Failed to read config file: {}", config))?;
+                let cfg: WeightedHttpProvidersConfig = serde_json::from_str(&content)
+                    .with_context(|| format!("Failed to parse config JSON: {}", config))?;
+                WeightedDataProvider::from_config(cfg, args.sequence_length as u32).await?
+            };
+            info!("Loaded weighted HTTP dataset");
+            DataProvider::WeightedHttp(provider)
         }
     };
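To make the `http-template` semantics concrete, here is a minimal, self-contained sketch of the URL expansion the subcommand is documented to perform, assuming `--end` is exclusive (`count = end - start`) and `--left-pad-zeros` is the total padded width, as the `00000.ds`/`00009.ds` example implies. `expand_template` is an illustrative helper, not an API of `psyche_data_provider`:

```rust
/// Illustrative helper (not part of psyche_data_provider): expand a URL
/// template the way the `http-template` subcommand is documented to,
/// assuming `end` is exclusive and `left_pad_zeros` is the padded width.
fn expand_template(template: &str, start: u32, end: u32, left_pad_zeros: u8) -> Vec<String> {
    (start..end)
        .map(|i| {
            // Zero-pad the index, then substitute it into the `{}` placeholder.
            let padded = format!("{i:0>width$}", width = left_pad_zeros as usize);
            template.replacen("{}", &padded, 1)
        })
        .collect()
}

fn main() {
    // Mirrors: --template "https://example.com/data/{}.ds" --start 0 --end 10 --left-pad-zeros 5
    let urls = expand_template("https://example.com/data/{}.ds", 0, 10, 5);
    assert_eq!(urls[0], "https://example.com/data/00000.ds");
    assert_eq!(urls[9], "https://example.com/data/00009.ds");
}
```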