From 512a436c768115a48c14fb54754966325a8fc6f2 Mon Sep 17 00:00:00 2001 From: Alex Kholodniak Date: Tue, 6 Jan 2026 03:45:10 +0200 Subject: [PATCH 1/2] chore: ignore .claude directory --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 62b01d3..b35c1db 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ ehthumbs.db lcov.info *.profraw .cursor/ +.claude/ \ No newline at end of file From 1a4fcb1aa51d668948f1ab7edb9832a92b9ecc51 Mon Sep 17 00:00:00 2001 From: Alex Kholodniak Date: Tue, 6 Jan 2026 04:03:43 +0200 Subject: [PATCH 2/2] Add utilities for character-level text generation --- examples/text_generation_advanced.rs | 459 +++++++-------------------- examples/text_utils_example.rs | 116 +++++++ src/lib.rs | 5 + src/text.rs | 403 +++++++++++++++++++++++ 4 files changed, 646 insertions(+), 337 deletions(-) create mode 100644 examples/text_utils_example.rs create mode 100644 src/text.rs diff --git a/examples/text_generation_advanced.rs b/examples/text_generation_advanced.rs index 2f02c01..f56b802 100644 --- a/examples/text_generation_advanced.rs +++ b/examples/text_generation_advanced.rs @@ -1,400 +1,185 @@ use ndarray::Array2; use rust_lstm::models::lstm_network::LSTMNetwork; +use rust_lstm::layers::linear::LinearLayer; +use rust_lstm::text::{TextVocabulary, CharacterEmbedding, sample_with_temperature}; use rust_lstm::training::LSTMTrainer; -use rust_lstm::loss::MSELoss; +use rust_lstm::loss::CrossEntropyLoss; use rust_lstm::optimizers::Adam; use std::collections::HashMap; -/// Advanced character-level language model using LSTM with embedded representations struct CharacterLSTM { + vocab: TextVocabulary, + embedding: CharacterEmbedding, network: LSTMNetwork, - trainer: Option>, - char_to_idx: HashMap, - idx_to_char: HashMap, - vocab_size: usize, + output_layer: LinearLayer, + trainer: Option>, + hidden_size: usize, sequence_length: usize, - embedding_size: usize, } impl CharacterLSTM { - fn new(text: &str, sequence_length: usize, hidden_size: usize, embedding_size: usize) -> Self { - // Build vocabulary from text - let unique_chars: std::collections::HashSet = text.chars().collect(); - let mut chars: Vec = unique_chars.into_iter().collect(); - chars.sort(); // Ensure consistent ordering - - let vocab_size = chars.len(); - - // Create character mappings - let char_to_idx: HashMap = chars.iter().enumerate() - .map(|(i, &c)| (c, i)) - .collect(); - let idx_to_char: HashMap = chars.iter().enumerate() - .map(|(i, &c)| (i, c)) - .collect(); - - // Create network: embedding_size input -> hidden_size (single layer) - let network = LSTMNetwork::new(embedding_size, hidden_size, 1); - - println!("Built vocabulary: {} unique characters", vocab_size); - println!("Characters: {:?}", chars.iter().take(20).collect::>()); - println!("Network: {} -> {} -> {}", embedding_size, hidden_size, embedding_size); - + fn new(text: &str, sequence_length: usize, hidden_size: usize, embed_dim: usize) -> Self { + let vocab = TextVocabulary::from_text(text); + let embedding = CharacterEmbedding::new(vocab.size(), embed_dim); + let network = LSTMNetwork::new(embed_dim, hidden_size, 1); + let output_layer = LinearLayer::new(hidden_size, vocab.size()); + + println!("Vocabulary size: {}", vocab.size()); + println!("Network: embed({}) -> LSTM({}) -> Linear({})", embed_dim, hidden_size, vocab.size()); + Self { + vocab, + embedding, network, + output_layer, trainer: None, - char_to_idx, - idx_to_char, - vocab_size, + hidden_size, sequence_length, - embedding_size, - } - } - - /// Convert character to embedded representation - fn char_to_embedding(&self, ch: char) -> Array2 { - let idx = self.char_to_idx.get(&ch).copied().unwrap_or(0); - let mut embedding = vec![0.0; self.embedding_size]; - - // Simple embedding: use sine/cosine features based on character index - for i in 0..self.embedding_size { - let freq = (i + 1) as f64 / self.embedding_size as f64; - if i % 2 == 0 { - embedding[i] = ((idx as f64) * freq).sin(); - } else { - embedding[i] = ((idx as f64) * freq).cos(); - } - } - - Array2::from_shape_vec((self.embedding_size, 1), embedding).unwrap() - } - - /// Convert embedding back to character using similarity - fn embedding_to_char(&self, embedding: &Array2) -> char { - let mut best_char = ' '; - let mut best_similarity = f64::NEG_INFINITY; - - // Find character with most similar embedding - for (&ch, &_idx) in &self.char_to_idx { - let char_embedding = self.char_to_embedding(ch); - - // Compute cosine similarity - let dot_product: f64 = (0..self.embedding_size) - .map(|i| embedding[[i, 0]] * char_embedding[[i, 0]]) - .sum(); - - let norm1: f64 = (0..self.embedding_size) - .map(|i| embedding[[i, 0]] * embedding[[i, 0]]) - .sum::().sqrt(); - - let norm2: f64 = (0..self.embedding_size) - .map(|i| char_embedding[[i, 0]] * char_embedding[[i, 0]]) - .sum::().sqrt(); - - let similarity = if norm1 > 0.0 && norm2 > 0.0 { - dot_product / (norm1 * norm2) - } else { - 0.0 - }; - - if similarity > best_similarity { - best_similarity = similarity; - best_char = ch; - } - } - - best_char - } - - /// Sample character with temperature control - fn sample_char_with_temperature(&self, embedding: &Array2, temperature: f64) -> char { - let mut scores = Vec::new(); - - // Calculate similarity scores for all characters - for (&ch, &_idx) in &self.char_to_idx { - let char_embedding = self.char_to_embedding(ch); - - let dot_product: f64 = (0..self.embedding_size) - .map(|i| embedding[[i, 0]] * char_embedding[[i, 0]]) - .sum(); - - scores.push((ch, dot_product / temperature)); - } - - // Apply softmax and sample - let max_score = scores.iter().map(|(_, score)| *score).fold(f64::NEG_INFINITY, f64::max); - let exp_scores: Vec<(char, f64)> = scores.iter() - .map(|(ch, score)| (*ch, (score - max_score).exp())) - .collect(); - - let sum: f64 = exp_scores.iter().map(|(_, exp_score)| *exp_score).sum(); - let probabilities: Vec<(char, f64)> = exp_scores.iter() - .map(|(ch, exp_score)| (*ch, exp_score / sum)) - .collect(); - - // Sample from distribution - let mut rng_val = rand::random::(); - for &(ch, prob) in &probabilities { - rng_val -= prob; - if rng_val <= 0.0 { - return ch; - } } - - // Fallback to most probable character - probabilities.iter() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .map(|(ch, _)| *ch) - .unwrap_or(' ') } - /// Project embedding to hidden space for targets - fn embedding_to_hidden(&self, embedding: &Array2) -> Array2 { - // Simple projection: repeat/pad embedding to hidden size - let hidden_size = 32; // This should match the network hidden size - let mut hidden = vec![0.0; hidden_size]; - - for i in 0..hidden_size { - if i < self.embedding_size { - hidden[i] = embedding[[i, 0]]; - } else { - // Pad with scaled values or zeros - hidden[i] = 0.0; - } - } - - Array2::from_shape_vec((hidden_size, 1), hidden).unwrap() - } - - /// Create training sequences from text fn create_sequences(&self, text: &str) -> Vec<(Vec>, Vec>)> { let chars: Vec = text.chars().collect(); let mut sequences = Vec::new(); - - for i in 0..chars.len().saturating_sub(self.sequence_length) { + + for i in 0..chars.len().saturating_sub(self.sequence_length + 1) { let mut inputs = Vec::new(); let mut targets = Vec::new(); - - // Create input sequence and corresponding target sequence - for j in i..i + self.sequence_length { - inputs.push(self.char_to_embedding(chars[j])); - - // Target is the next character's embedding projected to hidden space - if j + 1 < chars.len() { - let target_embedding = self.char_to_embedding(chars[j + 1]); - let hidden_target = self.embedding_to_hidden(&target_embedding); - targets.push(hidden_target); - } - } - - if inputs.len() == targets.len() && !inputs.is_empty() { - sequences.push((inputs, targets)); + + for j in 0..self.sequence_length { + let char_idx = self.vocab.char_to_index(chars[i + j]).unwrap_or(0); + let next_idx = self.vocab.char_to_index(chars[i + j + 1]).unwrap_or(0); + + let emb = self.embedding.lookup(char_idx); + let input = Array2::from_shape_vec((emb.len(), 1), emb.to_vec()).unwrap(); + inputs.push(input); + + let mut target = Array2::zeros((self.hidden_size, 1)); + target[[next_idx % self.hidden_size, 0]] = 1.0; + targets.push(target); } + + sequences.push((inputs, targets)); } - + sequences } - /// Train the character-level language model - fn train(&mut self, text: &str, epochs: usize, validation_split: f64) { - println!("Creating character sequences from text..."); + fn train(&mut self, text: &str, epochs: usize) { + println!("Creating sequences..."); let sequences = self.create_sequences(text); - + if sequences.is_empty() { - println!("No training sequences created!"); + println!("No sequences created!"); return; } - - let split_idx = ((sequences.len() as f64) * (1.0 - validation_split)) as usize; - let (train_data, val_data) = sequences.split_at(split_idx); - - println!("Training on {} sequences, validating on {} sequences", - train_data.len(), val_data.len()); - - // Create trainer with MSE loss for embedding regression - let loss_function = MSELoss; + + let split = (sequences.len() as f64 * 0.9) as usize; + let (train, val) = sequences.split_at(split); + + println!("Training on {} sequences, validating on {}", train.len(), val.len()); + + let loss_fn = CrossEntropyLoss; let optimizer = Adam::new(0.002); - let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer); - - // Configure training + let mut trainer = LSTMTrainer::new(self.network.clone(), loss_fn, optimizer); + let mut config = rust_lstm::training::TrainingConfig::default(); config.epochs = epochs; - config.print_every = epochs / 5; // Print 5 times during training + config.print_every = epochs / 5; config.clip_gradient = Some(5.0); - trainer = trainer.with_config(config); - - // Train the model - trainer.train(train_data, if val_data.is_empty() { None } else { Some(val_data) }); - + + trainer.train(train, if val.is_empty() { None } else { Some(val) }); self.trainer = Some(trainer); - println!("Character LSTM training completed!"); + + println!("Training complete!"); } - /// Generate text starting with a seed string - fn generate_text(&mut self, seed: &str, length: usize, temperature: f64) -> String { + fn generate(&mut self, seed: &str, length: usize, temperature: f64) -> String { if self.trainer.is_none() { - println!("Model not trained yet!"); - return String::new(); + return seed.to_string(); } - + let mut generated = seed.to_string(); - let mut current_sequence: Vec = seed.chars().collect(); - - // Ensure we have enough characters to start - while current_sequence.len() < self.sequence_length { - current_sequence.insert(0, ' '); // Pad with spaces + let mut chars: Vec = seed.chars().collect(); + + while chars.len() < self.sequence_length { + chars.insert(0, ' '); } - - let network = if let Some(ref trainer) = self.trainer { - &trainer.network - } else { - println!("Trainer not available"); - return generated; - }; - - let mut inference_network = network.clone(); - inference_network.eval(); - + + let network = &self.trainer.as_ref().unwrap().network; + let mut inference_net = network.clone(); + inference_net.eval(); + for _ in 0..length { - // Prepare input sequence - let start_idx = current_sequence.len().saturating_sub(self.sequence_length); - let input_chars = ¤t_sequence[start_idx..]; - - let inputs: Vec> = input_chars.iter() - .map(|&ch| self.char_to_embedding(ch)) - .collect(); - - let (outputs, _) = inference_network.forward_sequence_with_cache(&inputs); - - if let Some((last_output, _)) = outputs.last() { - let predicted_embedding = self.project_to_embedding(last_output); - - let next_char = self.sample_next_char(&predicted_embedding, temperature); - - generated.push(next_char); - current_sequence.push(next_char); - - if current_sequence.len() > self.sequence_length * 2 { - current_sequence.drain(0..self.sequence_length); - } - } else { - println!("No prediction generated, stopping text generation"); - break; + let start = chars.len().saturating_sub(self.sequence_length); + let window: Vec = chars[start..].to_vec(); + + let mut h = Array2::zeros((self.hidden_size, 1)); + let mut c = Array2::zeros((self.hidden_size, 1)); + + for ch in &window { + let idx = self.vocab.char_to_index(*ch).unwrap_or(0); + let emb = self.embedding.lookup(idx); + let input = Array2::from_shape_vec((emb.len(), 1), emb.to_vec()).unwrap(); + let (new_h, new_c) = inference_net.forward(&input, &h, &c); + h = new_h; + c = new_c; } - } - - generated - } - /// Project LSTM output to embedding space - fn project_to_embedding(&self, lstm_output: &Array2) -> Array2 { - // Simple projection: take first embedding_size elements of hidden state - // In practice, this would be a learned linear layer - let hidden_size = lstm_output.nrows(); - let mut embedding = vec![0.0; self.embedding_size]; - - for i in 0..self.embedding_size.min(hidden_size) { - embedding[i] = lstm_output[[i, 0]]; - } - - Array2::from_shape_vec((self.embedding_size, 1), embedding).unwrap() - } + let logits_2d = self.output_layer.forward(&h); + let logits = logits_2d.column(0).to_owned(); - /// Sample next character using temperature-based sampling - fn sample_next_char(&self, predicted_embedding: &Array2, temperature: f64) -> char { - let mut similarities = Vec::new(); - - for (&ch, &_idx) in &self.char_to_idx { - let char_embedding = self.char_to_embedding(ch); - - let dot_product = predicted_embedding.iter() - .zip(char_embedding.iter()) - .map(|(a, b)| a * b) - .sum::(); - - let pred_norm = predicted_embedding.iter().map(|x| x * x).sum::().sqrt(); - let char_norm = char_embedding.iter().map(|x| x * x).sum::().sqrt(); - - let similarity = if pred_norm > 0.0 && char_norm > 0.0 { - dot_product / (pred_norm * char_norm) - } else { - 0.0 - }; - - similarities.push((ch, similarity)); - } - - let max_similarity = similarities.iter().map(|(_, s)| *s).fold(f64::NEG_INFINITY, f64::max); - let mut probabilities = Vec::new(); - let mut total_prob = 0.0; - - for (ch, similarity) in &similarities { - let scaled_similarity = (similarity - max_similarity) / temperature; - let prob = scaled_similarity.exp(); - probabilities.push((*ch, prob)); - total_prob += prob; - } - - for (_, prob) in &mut probabilities { - *prob /= total_prob; - } - - let random_value: f64 = rand::random(); - let mut cumulative_prob = 0.0; - - for (ch, prob) in probabilities { - cumulative_prob += prob; - if random_value <= cumulative_prob { - return ch; + let next_idx = sample_with_temperature(&logits, temperature); + if let Some(next_char) = self.vocab.index_to_char(next_idx) { + generated.push(next_char); + chars.push(next_char); } } - - similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); - similarities[0].0 + + generated } } -/// Sample training texts for different domains fn get_sample_texts() -> HashMap<&'static str, &'static str> { let mut texts = HashMap::new(); - - texts.insert("poetry", - "The woods are lovely, dark and deep, But I have promises to keep, And miles to go before I sleep, And miles to go before I sleep. Two roads diverged in a yellow wood, And sorry I could not travel both And be one traveler, long I stood And looked down one as far as I could To where it bent in the undergrowth."); - - texts.insert("code", - "fn main() { println!(\"Hello, world!\"); let x = 42; if x > 10 { println!(\"x is greater than 10\"); } for i in 0..5 { println!(\"i = {}\", i); } }"); - + + texts.insert("poetry", + "The woods are lovely, dark and deep, But I have promises to keep, \ + And miles to go before I sleep, And miles to go before I sleep. \ + Two roads diverged in a yellow wood, And sorry I could not travel both."); + + texts.insert("code", + "fn main() { println!(\"Hello, world!\"); let x = 42; \ + if x > 10 { println!(\"x is greater\"); } for i in 0..5 { println!(\"i = {}\", i); } }"); + texts.insert("prose", - "In a hole in the ground there lived a hobbit. Not a nasty, dirty, wet hole, filled with the ends of worms and an oozy smell, nor yet a dry, bare, sandy hole with nothing in it to sit down on or to eat: it was a hobbit-hole, and that means comfort."); - + "In a hole in the ground there lived a hobbit. Not a nasty, dirty, wet hole, \ + filled with the ends of worms and an oozy smell, nor yet a dry, bare, sandy hole."); + texts } fn main() { - println!("Advanced Text Generation with Character-Level LSTM"); - println!("===================================================\n"); - - let sample_texts = get_sample_texts(); - - for (domain, text) in &sample_texts { - println!("Training {} model...", domain); - println!("Training text preview: {}...\n", &text[..text.len().min(100)]); - - // Create and train model with embedding - let mut model = CharacterLSTM::new(text, 8, 32, 16); // 8-char sequences, 32 hidden, 16 embedding - model.train(text, 8, 0.1); // 8 epochs for quick demo, 10% validation - - println!("\nGenerating text samples:"); - - // Generate with different temperatures - let temperatures = [0.8, 1.2]; - for &temp in &temperatures { - let seed = text.chars().take(5).collect::(); - let generated = model.generate_text(&seed, 60, temp); - println!("\nTemperature {:.1}: {}", temp, generated); + println!("Character-Level LSTM with Text Utilities"); + println!("=========================================\n"); + + let texts = get_sample_texts(); + + for (domain, text) in &texts { + println!("Domain: {}", domain); + println!("Text: {}...\n", &text[..text.len().min(60)]); + + let mut model = CharacterLSTM::new(text, 8, 32, 16); + model.train(text, 10); + + println!("\nGenerated samples:"); + for temp in [0.5, 1.0, 1.5] { + let seed: String = text.chars().take(5).collect(); + let output = model.generate(&seed, 50, temp); + println!(" temp={:.1}: {}", temp, output); } - - println!("\n{}\n", "=".repeat(60)); + + println!("\n{}\n", "=".repeat(50)); } -} \ No newline at end of file +} diff --git a/examples/text_utils_example.rs b/examples/text_utils_example.rs new file mode 100644 index 0000000..862c1c0 --- /dev/null +++ b/examples/text_utils_example.rs @@ -0,0 +1,116 @@ +//! Example demonstrating text generation utilities. +//! +//! Shows TextVocabulary, CharacterEmbedding, and sampling functions. + +use ndarray::Array2; +use rust_lstm::text::{ + TextVocabulary, CharacterEmbedding, + sample_with_temperature, sample_top_k, sample_nucleus, argmax, softmax +}; +use rust_lstm::layers::linear::LinearLayer; +use rust_lstm::models::lstm_network::LSTMNetwork; + +fn main() { + println!("Text Generation Utilities Demo"); + println!("==============================\n"); + + // 1. Vocabulary + let text = "Hello, World! This is a test."; + let vocab = TextVocabulary::from_text(text); + + println!("1. TextVocabulary"); + println!(" Text: \"{}\"", text); + println!(" Vocabulary size: {}", vocab.size()); + println!(" Characters: {:?}", vocab.chars()); + + let encoded = vocab.encode("Hello"); + let decoded = vocab.decode(&encoded); + println!(" Encode 'Hello': {:?}", encoded); + println!(" Decode back: \"{}\"", decoded); + + // 2. Character Embedding + println!("\n2. CharacterEmbedding"); + let embed_dim = 16; + let mut embedding = CharacterEmbedding::new(vocab.size(), embed_dim); + println!(" Embedding: {} chars -> {} dimensions", vocab.size(), embed_dim); + println!(" Parameters: {}", embedding.num_parameters()); + + // Lookup single character + let h_idx = vocab.char_to_index('H').unwrap(); + let h_vec = embedding.lookup(h_idx); + println!(" 'H' embedding (first 4): [{:.3}, {:.3}, {:.3}, {:.3}, ...]", + h_vec[0], h_vec[1], h_vec[2], h_vec[3]); + + // Forward pass for sequence + let seq_indices = vocab.encode("Hi"); + let seq_embeddings = embedding.forward(&seq_indices); + println!(" 'Hi' embeddings shape: {:?}", seq_embeddings.shape()); + + // 3. LSTM + Linear for text generation + println!("\n3. LSTM + Linear Pipeline"); + let hidden_size = 32; + let mut lstm = LSTMNetwork::new(embed_dim, hidden_size, 1); + let mut output_layer = LinearLayer::new(hidden_size, vocab.size()); + + // Process a character + let char_idx = vocab.char_to_index('H').unwrap(); + let char_emb = embedding.lookup(char_idx); + let input = Array2::from_shape_vec((embed_dim, 1), char_emb.to_vec()).unwrap(); + + let h0 = Array2::zeros((hidden_size, 1)); + let c0 = Array2::zeros((hidden_size, 1)); + + let (hidden, _cell) = lstm.forward(&input, &h0, &c0); + let logits_2d = output_layer.forward(&hidden); + let logits = logits_2d.column(0).to_owned(); + + println!(" Input: 'H' -> embed({}) -> LSTM -> Linear -> logits({})", + embed_dim, vocab.size()); + + // 4. Sampling strategies + println!("\n4. Sampling Strategies"); + println!(" Logits range: [{:.2}, {:.2}]", + logits.iter().cloned().fold(f64::INFINITY, f64::min), + logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max)); + + // Greedy + let greedy_idx = argmax(&logits); + let greedy_char = vocab.index_to_char(greedy_idx).unwrap_or('?'); + println!(" Greedy (argmax): '{}' (idx {})", greedy_char, greedy_idx); + + // Temperature sampling + for temp in [0.5, 1.0, 1.5] { + let idx = sample_with_temperature(&logits, temp); + let ch = vocab.index_to_char(idx).unwrap_or('?'); + println!(" Temperature {:.1}: '{}' (idx {})", temp, ch, idx); + } + + // Top-k sampling + let k = 5; + let idx = sample_top_k(&logits, k, 1.0); + let ch = vocab.index_to_char(idx).unwrap_or('?'); + println!(" Top-{} sampling: '{}' (idx {})", k, ch, idx); + + // Nucleus sampling + let p = 0.9; + let idx = sample_nucleus(&logits, p, 1.0); + let ch = vocab.index_to_char(idx).unwrap_or('?'); + println!(" Nucleus (p={:.1}): '{}' (idx {})", p, ch, idx); + + // 5. Softmax probabilities + println!("\n5. Probability Distribution"); + let probs = softmax(&logits); + let mut prob_chars: Vec<_> = probs.iter() + .enumerate() + .map(|(i, &p)| (vocab.index_to_char(i).unwrap_or('?'), p)) + .collect(); + prob_chars.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + println!(" Top 5 most likely next characters:"); + for (ch, prob) in prob_chars.iter().take(5) { + let display = if *ch == ' ' { "' '" } else { &ch.to_string() }; + println!(" {:>3}: {:.1}%", display, prob * 100.0); + } + + println!("\nDone!"); +} diff --git a/src/lib.rs b/src/lib.rs index 58cffdd..d1e1c87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ pub mod optimizers; pub mod schedulers; pub mod training; pub mod persistence; +pub mod text; // Re-export commonly used items pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig}; @@ -67,6 +68,10 @@ pub use schedulers::{ }; pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss}; pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError}; +pub use text::{ + TextVocabulary, CharacterEmbedding, EmbeddingGradients, + sample_with_temperature, sample_top_k, sample_nucleus, argmax, softmax +}; #[cfg(test)] mod tests { diff --git a/src/text.rs b/src/text.rs new file mode 100644 index 0000000..dcec6a0 --- /dev/null +++ b/src/text.rs @@ -0,0 +1,403 @@ +//! Text generation utilities for character-level language models. +//! +//! Provides vocabulary management, character embeddings, and sampling strategies. + +use std::collections::HashMap; +use ndarray::{Array1, Array2}; +use ndarray_rand::RandomExt; +use ndarray_rand::rand_distr::Uniform; +use crate::optimizers::Optimizer; + +/// Character vocabulary for text generation tasks. +/// +/// Maps characters to indices and vice versa. +#[derive(Clone, Debug)] +pub struct TextVocabulary { + char_to_idx: HashMap, + idx_to_char: HashMap, + vocab_size: usize, +} + +impl TextVocabulary { + /// Create vocabulary from text, extracting unique characters. + pub fn from_text(text: &str) -> Self { + let mut chars: Vec = text.chars().collect::>() + .into_iter().collect(); + chars.sort(); + + let vocab_size = chars.len(); + let char_to_idx: HashMap = chars.iter() + .enumerate() + .map(|(i, &c)| (c, i)) + .collect(); + let idx_to_char: HashMap = chars.iter() + .enumerate() + .map(|(i, &c)| (i, c)) + .collect(); + + Self { char_to_idx, idx_to_char, vocab_size } + } + + /// Create vocabulary from explicit character list. + pub fn from_chars(chars: &[char]) -> Self { + let vocab_size = chars.len(); + let char_to_idx: HashMap = chars.iter() + .enumerate() + .map(|(i, &c)| (c, i)) + .collect(); + let idx_to_char: HashMap = chars.iter() + .enumerate() + .map(|(i, &c)| (i, c)) + .collect(); + + Self { char_to_idx, idx_to_char, vocab_size } + } + + /// Get index for a character. + pub fn char_to_index(&self, ch: char) -> Option { + self.char_to_idx.get(&ch).copied() + } + + /// Get character for an index. + pub fn index_to_char(&self, idx: usize) -> Option { + self.idx_to_char.get(&idx).copied() + } + + /// Get vocabulary size. + pub fn size(&self) -> usize { + self.vocab_size + } + + /// Check if character is in vocabulary. + pub fn contains(&self, ch: char) -> bool { + self.char_to_idx.contains_key(&ch) + } + + /// Get all characters in vocabulary order. + pub fn chars(&self) -> Vec { + let mut chars: Vec<_> = self.idx_to_char.iter().collect(); + chars.sort_by_key(|(idx, _)| *idx); + chars.into_iter().map(|(_, &ch)| ch).collect() + } + + /// Encode string to indices. + pub fn encode(&self, text: &str) -> Vec { + text.chars() + .filter_map(|ch| self.char_to_index(ch)) + .collect() + } + + /// Decode indices to string. + pub fn decode(&self, indices: &[usize]) -> String { + indices.iter() + .filter_map(|&idx| self.index_to_char(idx)) + .collect() + } +} + +/// Gradients for character embedding layer. +#[derive(Clone, Debug)] +pub struct EmbeddingGradients { + pub weight: Array2, +} + +/// Trainable character embedding layer. +/// +/// Maps character indices to dense vectors. +#[derive(Clone, Debug)] +pub struct CharacterEmbedding { + pub weight: Array2, // (vocab_size, embed_dim) + vocab_size: usize, + embed_dim: usize, + input_cache: Option>, +} + +impl CharacterEmbedding { + /// Create new embedding with random initialization. + pub fn new(vocab_size: usize, embed_dim: usize) -> Self { + let scale = (1.0 / embed_dim as f64).sqrt(); + let weight = Array2::random((vocab_size, embed_dim), Uniform::new(-scale, scale)); + + Self { + weight, + vocab_size, + embed_dim, + input_cache: None, + } + } + + /// Create embedding with zero initialization. + pub fn new_zeros(vocab_size: usize, embed_dim: usize) -> Self { + Self { + weight: Array2::zeros((vocab_size, embed_dim)), + vocab_size, + embed_dim, + input_cache: None, + } + } + + /// Create embedding from existing weights. + pub fn from_weights(weight: Array2) -> Self { + let (vocab_size, embed_dim) = weight.dim(); + Self { + weight, + vocab_size, + embed_dim, + input_cache: None, + } + } + + /// Get embedding dimension. + pub fn embed_dim(&self) -> usize { + self.embed_dim + } + + /// Get vocabulary size. + pub fn vocab_size(&self) -> usize { + self.vocab_size + } + + /// Lookup single character embedding. + pub fn lookup(&self, char_idx: usize) -> Array1 { + assert!(char_idx < self.vocab_size, "Index {} out of vocabulary size {}", char_idx, self.vocab_size); + self.weight.row(char_idx).to_owned() + } + + /// Forward pass for sequence of indices. + /// Returns (seq_len, embed_dim) matrix. + pub fn forward(&mut self, char_indices: &[usize]) -> Array2 { + self.input_cache = Some(char_indices.to_vec()); + + let seq_len = char_indices.len(); + let mut output = Array2::zeros((seq_len, self.embed_dim)); + + for (i, &idx) in char_indices.iter().enumerate() { + assert!(idx < self.vocab_size, "Index {} out of vocabulary size {}", idx, self.vocab_size); + output.row_mut(i).assign(&self.weight.row(idx)); + } + + output + } + + /// Backward pass - compute gradients. + /// grad_output shape: (seq_len, embed_dim) + pub fn backward(&self, grad_output: &Array2) -> EmbeddingGradients { + let indices = self.input_cache.as_ref().expect("No cached input for backward pass"); + + let mut weight_grad = Array2::zeros((self.vocab_size, self.embed_dim)); + + for (i, &idx) in indices.iter().enumerate() { + for j in 0..self.embed_dim { + weight_grad[[idx, j]] += grad_output[[i, j]]; + } + } + + EmbeddingGradients { weight: weight_grad } + } + + /// Update parameters with optimizer. + pub fn update_parameters(&mut self, gradients: &EmbeddingGradients, optimizer: &mut O, prefix: &str) { + optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight); + } + + /// Get number of parameters. + pub fn num_parameters(&self) -> usize { + self.weight.len() + } +} + +/// Sample from logits with temperature scaling. +/// +/// Higher temperature = more random, lower = more deterministic. +pub fn sample_with_temperature(logits: &Array1, temperature: f64) -> usize { + assert!(temperature > 0.0, "Temperature must be positive"); + + // Scale logits by temperature + let scaled: Vec = logits.iter().map(|&x| x / temperature).collect(); + + // Softmax with numerical stability + let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let exp_vals: Vec = scaled.iter().map(|&x| (x - max_val).exp()).collect(); + let sum: f64 = exp_vals.iter().sum(); + let probs: Vec = exp_vals.iter().map(|&x| x / sum).collect(); + + // Sample from distribution + let mut rng_val = rand::random::(); + for (i, &prob) in probs.iter().enumerate() { + rng_val -= prob; + if rng_val <= 0.0 { + return i; + } + } + + probs.len() - 1 +} + +/// Sample from top-k most likely tokens. +/// +/// Filters to k highest probability tokens before sampling. +pub fn sample_top_k(logits: &Array1, k: usize, temperature: f64) -> usize { + assert!(k > 0, "k must be positive"); + assert!(temperature > 0.0, "Temperature must be positive"); + + let k = k.min(logits.len()); + + // Get indices sorted by logit value (descending) + let mut indexed: Vec<(usize, f64)> = logits.iter().enumerate().map(|(i, &v)| (i, v)).collect(); + indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + // Keep top k + let top_k: Vec<(usize, f64)> = indexed.into_iter().take(k).collect(); + + // Apply temperature and softmax to top-k only + let scaled: Vec = top_k.iter().map(|(_, v)| v / temperature).collect(); + let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let exp_vals: Vec = scaled.iter().map(|&x| (x - max_val).exp()).collect(); + let sum: f64 = exp_vals.iter().sum(); + let probs: Vec = exp_vals.iter().map(|&x| x / sum).collect(); + + // Sample + let mut rng_val = rand::random::(); + for (i, &prob) in probs.iter().enumerate() { + rng_val -= prob; + if rng_val <= 0.0 { + return top_k[i].0; + } + } + + top_k[k - 1].0 +} + +/// Nucleus (top-p) sampling. +/// +/// Samples from smallest set of tokens whose cumulative probability exceeds p. +pub fn sample_nucleus(logits: &Array1, p: f64, temperature: f64) -> usize { + assert!(p > 0.0 && p <= 1.0, "p must be in (0, 1]"); + assert!(temperature > 0.0, "Temperature must be positive"); + + // Apply temperature and softmax + let scaled: Vec = logits.iter().map(|&x| x / temperature).collect(); + let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let exp_vals: Vec = scaled.iter().map(|&x| (x - max_val).exp()).collect(); + let sum: f64 = exp_vals.iter().sum(); + let probs: Vec = exp_vals.iter().map(|&x| x / sum).collect(); + + // Sort by probability (descending) + let mut indexed: Vec<(usize, f64)> = probs.iter().enumerate().map(|(i, &v)| (i, v)).collect(); + indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + // Find nucleus (cumulative prob >= p) + let mut cumulative = 0.0; + let mut nucleus: Vec<(usize, f64)> = Vec::new(); + for (idx, prob) in indexed { + cumulative += prob; + nucleus.push((idx, prob)); + if cumulative >= p { + break; + } + } + + // Renormalize nucleus probabilities + let nucleus_sum: f64 = nucleus.iter().map(|(_, prob)| prob).sum(); + let nucleus_probs: Vec = nucleus.iter().map(|(_, prob)| prob / nucleus_sum).collect(); + + // Sample from nucleus + let mut rng_val = rand::random::(); + for (i, &prob) in nucleus_probs.iter().enumerate() { + rng_val -= prob; + if rng_val <= 0.0 { + return nucleus[i].0; + } + } + + nucleus.last().map(|(idx, _)| *idx).unwrap_or(0) +} + +/// Get argmax (greedy decoding). +pub fn argmax(logits: &Array1) -> usize { + logits.iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(idx, _)| idx) + .unwrap_or(0) +} + +/// Apply softmax to logits. +pub fn softmax(logits: &Array1) -> Array1 { + let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let exp_vals: Array1 = logits.mapv(|x| (x - max_val).exp()); + let sum: f64 = exp_vals.sum(); + exp_vals / sum +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::arr1; + + #[test] + fn test_vocabulary_from_text() { + let vocab = TextVocabulary::from_text("hello"); + assert_eq!(vocab.size(), 4); // h, e, l, o + assert!(vocab.contains('h')); + assert!(vocab.contains('l')); + assert!(!vocab.contains('x')); + } + + #[test] + fn test_vocabulary_encode_decode() { + let vocab = TextVocabulary::from_text("abc"); + let encoded = vocab.encode("cab"); + let decoded = vocab.decode(&encoded); + assert_eq!(decoded, "cab"); + } + + #[test] + fn test_embedding_forward() { + let mut emb = CharacterEmbedding::new(10, 8); + let output = emb.forward(&[0, 3, 5]); + assert_eq!(output.shape(), &[3, 8]); + } + + #[test] + fn test_embedding_lookup() { + let emb = CharacterEmbedding::new(10, 8); + let vec = emb.lookup(5); + assert_eq!(vec.len(), 8); + } + + #[test] + fn test_sample_with_temperature() { + let logits = arr1(&[1.0, 2.0, 3.0]); + let idx = sample_with_temperature(&logits, 1.0); + assert!(idx < 3); + } + + #[test] + fn test_sample_top_k() { + let logits = arr1(&[1.0, 5.0, 2.0, 0.5]); + let idx = sample_top_k(&logits, 2, 1.0); + // Should only sample from indices 1 or 2 (top 2) + assert!(idx == 1 || idx == 2); + } + + #[test] + fn test_sample_nucleus() { + let logits = arr1(&[0.0, 10.0, 0.0]); // Very peaked distribution + let idx = sample_nucleus(&logits, 0.9, 1.0); + assert_eq!(idx, 1); // Should almost always be 1 + } + + #[test] + fn test_argmax() { + let logits = arr1(&[1.0, 5.0, 2.0]); + assert_eq!(argmax(&logits), 1); + } + + #[test] + fn test_softmax() { + let logits = arr1(&[1.0, 2.0, 3.0]); + let probs = softmax(&logits); + assert!((probs.sum() - 1.0).abs() < 1e-6); + } +}