60 changes: 55 additions & 5 deletions psyche-book/src/development/models.md
@@ -9,15 +9,15 @@ The `train` example, documented below, is useful to test how your model trains u
## Running

```bash
cargo run --example train -- ---help
cargo run --example train -- --help
```

You'll need a pre-tokenized dataset downloaded to your disk for training.

> A PR is welcome to add an option to the trainer to use the HTTP data provider! You can refer to the http example in the data-provider crate for a sample implementation.
You'll need a pre-tokenized dataset for training. The `train` example supports multiple data sources: local files, HTTP URLs, GCP buckets, and weighted configurations.

For a Llama 2 model, a pre-tokenized dataset to test with is available at [https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/](https://huggingface.co/datasets/emozilla/fineweb-10bt-tokenized-datatrove-llama2/tree/main).
Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can download just one for smaller tests.
Psyche only needs the `.ds` files, and will load any/all `.ds` files in the specified folder - you can use just one for smaller tests.

### Local data

If you've downloaded part or all of the above dataset into a folder `data/fineweb-10bt` inside the Psyche repo, you can start a simple training run on a 20m parameter Llama 2 model:

@@ -29,6 +29,56 @@ cargo run --example train -- \
--micro-batch 1
```

### HTTP

You can stream data directly from HTTP URLs without downloading the dataset first. There are several ways to specify HTTP data sources:

#### URL template

Use a template with a `{}` placeholder that is replaced with zero-padded file indices:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
http-template \
--template "https://example.com/data/{}.ds" \
--start 0 \
--end 10 \
--left-pad-zeros 5
```

This would load files from `https://example.com/data/00000.ds` through `https://example.com/data/00009.ds`.

#### Explicit URLs

Provide a list of URLs directly:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
urls \
https://example.com/data/file1.ds \
https://example.com/data/file2.ds
```

#### GCP bucket

Load all `.ds` files from a Google Cloud Storage bucket:

```bash
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
gcp \
--bucket-name my-bucket \
--directory data/tokenized
```
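
#### Weighted config

The `train` example can also mix several HTTP sources according to a weighted configuration. The invocation below is a sketch inferred from the `WeightedConfig` subcommand added in this PR rather than from existing docs: `weighted-config` takes a `--config` argument pointing at a `WeightedHttpProvidersConfig` JSON file; `weights.json` is a hypothetical file name, and the JSON schema itself is defined in the `psyche-data-provider` crate and not reproduced here.

```bash
# Sketch only: subcommand and flag names are inferred from the clap definitions
# in this PR; weights.json is a hypothetical WeightedHttpProvidersConfig file.
cargo run --example train -- \
--model emozilla/llama2-20m-init \
--total-batch 2 \
--micro-batch 1 \
weighted-config \
--config weights.json
```

A remote config should work the same way: pass an `http(s)://` URL to `--config` and it is fetched via `WeightedDataProvider::from_config_url`.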

## Adding a new model type

The `train` example currently assumes your model is a Llama or Deepseek v2/v3 model, and instantiates it via `(LlamaForCausalLM|DeepseekForCausalLM)::from_pretrained`.

168 changes: 138 additions & 30 deletions shared/modeling/examples/train.rs
@@ -5,7 +5,9 @@ use psyche_core::{
};
use psyche_data_provider::{
DataProvider, LengthKnownDataProvider, LocalDataProvider, PreprocessedDataProvider, Split,
TokenizedDataProvider, download_model_repo_sync,
TokenizedDataProvider, WeightedDataProvider, WeightedHttpProvidersConfig,
download_model_repo_sync,
http::{FileURLs, HttpDataProvider},
};
use psyche_modeling::{
AttentionImplementation, Batch, BatchData, BatchDataCPU, CausalLM, CommunicatorId,
@@ -91,14 +93,63 @@ enum Commands {
},
}

#[derive(Subcommand, Debug, Clone)]
enum DataSource {
/// Local directory (default behavior)
Local {
#[arg(long, default_value = "data")]
path: String,
},
/// HTTP with URL template
HttpTemplate {
/// URL template with {} placeholder (e.g., "http://example.com/{}.ds")
#[arg(long)]
template: String,
/// Start index
#[arg(long, default_value = "0")]
start: u32,
/// End index (exclusive)
#[arg(long)]
end: u32,
/// Width to zero-pad indices to (e.g. 5 yields `00000`)
#[arg(long, default_value = "0")]
left_pad_zeros: u8,
},
/// HTTP with explicit URLs
Urls {
/// List of data URLs
urls: Vec<String>,
},
/// HTTP from GCP bucket
Gcp {
/// The name of the GCP bucket
#[arg(long)]
bucket_name: String,
/// An optional directory to filter by
#[arg(long)]
directory: Option<String>,
},
/// Weighted HTTP config (JSON file or URL)
WeightedConfig {
/// Path or URL to WeightedHttpProvidersConfig JSON file
#[arg(long)]
config: String,
},
}

#[derive(Args, Debug, Clone)]
struct RunArgs {
#[arg(long, default_value = "emozilla/llama2-215m-init")]
model: String,

/// Path to local data directory (backwards compatibility, use subcommand instead)
#[arg(long, default_value = "data")]
data_path: String,

/// Data source subcommand (optional, defaults to local data_path)
#[command(subcommand)]
data_source: Option<DataSource>,

#[arg(long, default_value_t = 2048)]
sequence_length: usize,

@@ -233,38 +284,95 @@ async fn main() -> Result<()> {
None => Shuffle::DontShuffle,
};

let mut dataset: DataProvider<DummyNodeIdentity> = match LocalDataProvider::new_from_directory(
&args.data_path,
args.token_size.try_into()?,
args.sequence_length,
shuffle,
)
.with_context(|| "Failed to load data with local data provider.")
{
Ok(dataset) => {
info!(
"Loaded local dataset with {} samples",
dataset.num_sequences()
);
DataProvider::Local(dataset)
}
Err(err) => {
println!(
"Failed to load with local data provider. {err:?} Trying preprocessed data provider instead"
);
let dataset = PreprocessedDataProvider::new_from_directory(
&args.data_path,
let token_size = args.token_size.try_into()?;

let mut dataset: DataProvider<DummyNodeIdentity> = match &args.data_source {
None | Some(DataSource::Local { .. }) => {
// Use data_path from the Local subcommand, or fall back to args.data_path
let data_path = match &args.data_source {
Some(DataSource::Local { path }) => path,
_ => &args.data_path,
};
match LocalDataProvider::new_from_directory(
data_path,
token_size,
args.sequence_length,
shuffle,
Some(Split::Train),
None,
)
.with_context(|| "Failed to load preprocessed data")?;
info!(
"Loaded preprocessed dataset with {} samples",
dataset.num_sequences()
);
DataProvider::Preprocessed(dataset)
.with_context(|| "Failed to load data with local data provider.")
{
Ok(dataset) => {
info!(
"Loaded local dataset with {} samples",
dataset.num_sequences()
);
DataProvider::Local(dataset)
}
Err(err) => {
println!(
"Failed to load with local data provider. {err:?} Trying preprocessed data provider instead"

[Review comment, Author] In main, train.rs falls back to preprocessed data if loading local data fails; this change keeps that behaviour, but it might be clearer with an explicit option.

[Review comment, Contributor] Yeah, I think it's a good idea to separate the possibility of using Preprocessed data. We might even end up splitting it further into Preprocessed Local (this behavior) and HTTP Preprocessed, which is the one I mentioned in the issue and whose implementation is here: #506. That should cover all the different data provider possibilities we have at the moment.

);
let dataset = PreprocessedDataProvider::new_from_directory(
data_path,
args.sequence_length,
shuffle,
Some(Split::Train),
None,
)
.with_context(|| "Failed to load preprocessed data")?;
info!(
"Loaded preprocessed dataset with {} samples",
dataset.num_sequences()
);
DataProvider::Preprocessed(dataset)
}
}
}
Some(DataSource::HttpTemplate {
template,
start,
end,
left_pad_zeros,
}) => {
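// Per the docs example, `end` is exclusive: `end - start` files are fetched.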
let urls =
FileURLs::from_template(template, *start, *left_pad_zeros, end - start).await?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
info!("Loaded HTTP template dataset");
DataProvider::Http(provider)
}
Some(DataSource::Urls { urls }) => {
if urls.is_empty() {
anyhow::bail!("At least one URL must be provided");
}
let urls = FileURLs::from_list(urls).await?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
info!("Loaded HTTP URLs dataset");
DataProvider::Http(provider)
}
Some(DataSource::Gcp {
bucket_name,
directory,
}) => {
let urls = FileURLs::from_gcp_bucket(bucket_name, directory.clone()).await?;
let provider =
HttpDataProvider::new(urls, token_size, args.sequence_length as u32, shuffle)?;
info!("Loaded GCP bucket dataset");
DataProvider::Http(provider)
}
Some(DataSource::WeightedConfig { config }) => {
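// The config may be given as an http(s) URL or as a path to a local JSON file.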
let provider = if config.starts_with("http://") || config.starts_with("https://") {
WeightedDataProvider::from_config_url(config, args.sequence_length as u32).await?
} else {
let content = std::fs::read_to_string(config)
.with_context(|| format!("Failed to read config file: {}", config))?;
let cfg: WeightedHttpProvidersConfig = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse config JSON: {}", config))?;
WeightedDataProvider::from_config(cfg, args.sequence_length as u32).await?
};
info!("Loaded weighted HTTP dataset");
DataProvider::WeightedHttp(provider)
}
};
