diff --git a/generate.py b/generate.py index 3f71e513..c1582ffb 100644 --- a/generate.py +++ b/generate.py @@ -1,5 +1,6 @@ import torch import numpy as np +import time import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel @@ -40,9 +41,10 @@ def get_num_transfer_tokens(mask_index, steps): return num_transfer_tokens -@ torch.no_grad() +@torch.no_grad() def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0., - cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False): + cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, + confidence_eos_eot_inf=False): ''' Args: model: Mask predictor. @@ -61,7 +63,9 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc x[:, :prompt.shape[1]] = prompt.clone() if attention_mask is not None: - attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1) + attention_mask = torch.cat([attention_mask, + torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, + device=model.device)], dim=-1) prompt_index = (x != mask_id) @@ -72,7 +76,8 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc steps = steps // num_blocks for num_block in range(num_blocks): - block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + ( + num_block + 1) * block_length:] == mask_id) num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) for i in range(steps): mask_index = (x == mask_id) @@ -92,15 +97,15 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc logits[:, :, 126081] = -torch.inf logits_with_noise = add_gumbel_noise(logits, temperature=temperature) - x0 = torch.argmax(logits_with_noise, dim=-1) # b, l - + x0 = torch.argmax(logits_with_noise, dim=-1) # b, l + if confidence_eos_eot_inf: logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf if remasking == 'low_confidence': p = F.softmax(logits, dim=-1) x0_p = torch.squeeze( - torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l elif remasking == 'random': x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) else: @@ -123,7 +128,8 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc def main(): device = 'cuda' - model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval() + model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, + torch_dtype=torch.bfloat16).to(device).eval() tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) # The LLaDA architecture theoretically supports both left-padding and right-padding. @@ -134,13 +140,15 @@ def main(): # If the padding ID equals the mask ID, you need to modify our generate function to achieve correct inference. assert tokenizer.pad_token_id != 126336 - prompts = [ "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?", - "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?", - "Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?"] + prompts = [ + "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?", + "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?", + "Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?"] # Add special tokens for the Instruct model. The Base model does not require the following two lines. messages = [{"role": "user", "content": prompt} for prompt in prompts] - prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages] + prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in + messages] encoded_outputs = tokenizer( prompts, @@ -151,11 +159,14 @@ def main(): input_ids = encoded_outputs['input_ids'].to(device) attention_mask = encoded_outputs['attention_mask'].to(device) - out = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence') + out = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=128, temperature=0.5, + cfg_scale=0., remasking='low_confidence') output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + for o in output: print(o) print('-' * 50) + if __name__ == '__main__': main() diff --git a/generate_optimized.py b/generate_optimized.py new file mode 100644 index 00000000..1b70da5f --- /dev/null +++ b/generate_optimized.py @@ -0,0 +1,203 @@ +import torch +import numpy as np +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModel + + +def add_gumbel_noise_original(logits, temperature): + ''' + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + ''' + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (- torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def sample_gumbel_chunked(logits, temperature): + ''' + Optimized implementation: + Iterates through the batch dimension, performing gumbel + argmax sequentially. + + Reduces Peak VRAM for sampling by a factor of Batch_Size. + ''' + if temperature == 0: + return torch.argmax(logits, dim=-1) + + output_indices = torch.empty(logits.shape[:-1], dtype=torch.long, device=logits.device) + + for i in range(logits.shape[0]): + logit_slice = logits[i].to(torch.float64) + + noise = torch.rand_like(logit_slice, dtype=torch.float64) + + gumbel_noise = (-torch.log(noise)) ** temperature + noisy_logits = logit_slice.exp() / gumbel_noise + + output_indices[i] = torch.argmax(noisy_logits, dim=-1) + + return output_indices + + +def get_num_transfer_tokens(mask_index, steps): + ''' + In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. + Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), + the expected number of tokens transitioned at each step should be consistent. + + This function is designed to precompute the number of tokens that need to be transitioned at each step. + ''' + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base + for i in range(mask_num.size(0)): + num_transfer_tokens[i, :remainder[i]] += 1 + return num_transfer_tokens + + +@torch.no_grad() +def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0., + cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, + confidence_eos_eot_inf=False, + use_efficient_sampling=True): # <--- New Flag + ''' + Args: + model: Mask predictor. + prompt: A tensor of shape (1, L). + steps: Sampling steps, less than or equal to gen_length. + gen_length: Generated answer length. + block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. + temperature: Categorical distribution sampling temperature. + cfg_scale: Unsupervised classifier-free guidance scale. + remasking: Remasking strategy. 'low_confidence' or 'random'. + mask_id: The toke id of [MASK] is 126336. + logits_eos_inf: Whether to set the logits of EOS token to -inf. See Appendix B.4 of LLaDA for details + confidence_eos_eot_inf: Whether to set the confidence of EOS and EoT token to -inf. See Appendix B.4 of LLaDA for details + use_efficient_sampling: Whether to set batched gumbel noise to save VRAM. + ''' + + x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, + torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, + device=model.device)], dim=-1) + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + ( + num_block + 1) * block_length:] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) + for i in range(steps): + mask_index = (x == mask_id) + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + if attention_mask is not None: + attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0) + logits = model(x_, attention_mask=attention_mask_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_mask=attention_mask).logits + + if logits_eos_inf: + logits[:, :, 126081] = -torch.inf + + # --- Updated gumel sampling --- + if use_efficient_sampling: + x0 = sample_gumbel_chunked(logits, temperature) + else: + # Original Path: Returns full float64 logit tensor before argmax + logits_with_noise = add_gumbel_noise_original(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) + + if confidence_eos_eot_inf: + # ---- original code below, possible bug as logits_with_noise isn't used upstream --- + # logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf + # ---- We keep behaviour of old code in this PR so just set 126348 to -inf --- + logits[:, :, 126081] = logits[:, :, 126348] = -torch.inf + + if remasking == 'low_confidence': + p = F.softmax(logits, dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + x[transfer_index] = x0[transfer_index] + + return x + + +def main(): + device = 'cuda' + + model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, + torch_dtype=torch.bfloat16).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) + + # The LLaDA architecture theoretically supports both left-padding and right-padding. + # However, the sampling code implementation is simpler with left-padding. + if tokenizer.padding_side != 'left': + tokenizer.padding_side = 'left' + + # If the padding ID equals the mask ID, you need to modify our generate function to achieve correct inference. + assert tokenizer.pad_token_id != 126336 + + prompts = [ + "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?", + "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?", + "Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?"] + + # Add special tokens for the Instruct model. The Base model does not require the following two lines. + messages = [{"role": "user", "content": prompt} for prompt in prompts] + prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in + messages] + + encoded_outputs = tokenizer( + prompts, + add_special_tokens=False, + padding=True, + return_tensors="pt" + ) + input_ids = encoded_outputs['input_ids'].to(device) + attention_mask = encoded_outputs['attention_mask'].to(device) + + # Example usage with efficient sampling enabled + out = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=32, temperature=0.7, + cfg_scale=0., remasking='low_confidence', use_efficient_sampling=True) + + output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True) + for o in output: + print(o) + print('-' * 50) + + +if __name__ == '__main__': + main() diff --git a/gumbel_test_output.txt b/gumbel_test_output.txt new file mode 100644 index 00000000..7a650889 --- /dev/null +++ b/gumbel_test_output.txt @@ -0,0 +1,35 @@ + +>>> [Distribution Check: GPU] Batch=8, Temp=1.0 + Sampled Mean Index: 490.88 + Expected Mean Index: 500.00 +SUCCESS: GPU output distribution looks statistically valid. + +>>> [Performance Test] Batch=8, Seq=128, Vocab=128000 +Original: Time = 0.3060s | Peak VRAM = 5250.00 MB +Efficient: Time = 0.1930s | Peak VRAM = 2006.01 MB +VRAM Saved: 3243.99 MB + +>>> Batch=32 (High VRAM) + +>>> [Performance Test] Batch=32, Seq=128, Vocab=128000 +Original: Time = 0.3961s | Peak VRAM = 21000.00 MB +Efficient: Time = 0.6998s | Peak VRAM = 5756.03 MB +VRAM Saved: 15243.97 MB + +>>> [Logic Check: CPU] Batch=8, Temp=1 +Original: Time = 0.0041s +Efficient: Time = 0.0009s +SUCCESS: CPU outputs are identical. Logic is correct. + +>>> [Logic Check: CPU] Batch=8, Temp=10 +Original: Time = 0.0009s +Efficient: Time = 0.0011s +SUCCESS: CPU outputs are identical. Logic is correct. + +>>> [Logit match check CPU] Batch=8, Temp=1 + Max Difference: 0.0 +SUCCESS: The floating point values are bit-wise identical. + +>>> [Logit match check CPU] Batch=8, Temp=10 + Max Difference: 0.0 +SUCCESS: The floating point values are bit-wise identical. diff --git a/test_efficient_gumbel.py b/test_efficient_gumbel.py new file mode 100644 index 00000000..01405bb5 --- /dev/null +++ b/test_efficient_gumbel.py @@ -0,0 +1,179 @@ +import torch +import time +import argparse +import numpy as np +from generate_optimized import sample_gumbel_chunked, add_gumbel_noise_original + + +def get_peak_memory_mb(): + return torch.cuda.max_memory_allocated() / (1024 * 1024) + + +def reset_memory_stats(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + +def test_logic_strict_cpu(batch_size, seq_len, vocab_size, temperature=1.0): + """ + Runs on CPU to verify the output is identical. + CPU RNG is serial, so this should match exactly. + """ + print(f"\n>>> [Logic Check: CPU] Batch={batch_size}, Temp={temperature}") + device = 'cpu' + logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32) # float32 for CPU speed + + torch.manual_seed(42) + start_t = time.time() + noisy_logits_orig = add_gumbel_noise_original(logits, temperature) + indices_orig = torch.argmax(noisy_logits_orig, dim=-1) + print (f'Original: Time = {time.time() - start_t:.4f}s') + + torch.manual_seed(42) + start_t = time.time() + indices_eff = sample_gumbel_chunked(logits, temperature) + print (f'Efficient: Time = {time.time() - start_t:.4f}s') + mismatch = (indices_orig != indices_eff).sum().item() + if mismatch == 0: + print("SUCCESS: CPU outputs are identical. Logic is correct.") + else: + print(f"FAILURE: CPU outputs differ by {mismatch} elements.") + + +def test_statistical_gpu(batch_size, seq_len, vocab_size, temperature=1.0, device='cuda'): + """ + Runs on GPU to verify the distribution matches. + We cannot expect identity due to Philox RNG. + """ + print(f"\n>>> [Distribution Check: GPU] Batch={batch_size}, Temp={temperature}") + + logits = torch.zeros(batch_size, seq_len, vocab_size, device=device, dtype=torch.float16) + + torch.manual_seed(42) + indices_eff = sample_gumbel_chunked(logits, temperature) + + # Since logits are 0, we are essentially sampling purely from Gumbel noise. + # Uniform distribution of indices means the noise is working. + # We check if the distribution of selected indices is roughly uniform across vocab. + + # Expected mean index should be approx vocab_size / 2 + mean_idx = indices_eff.float().mean().item() + expected_idx = vocab_size / 2.0 + + # We allow a loose tolerance because random is random + diff = abs(mean_idx - expected_idx) + + print(f" Sampled Mean Index: {mean_idx:.2f}") + print(f" Expected Mean Index: {expected_idx:.2f}") + + if diff < (vocab_size * 0.05): # within 5% + print("SUCCESS: GPU output distribution looks statistically valid.") + else: + print("WARNING: Distribution might be skewed (or sample size too small).") + + +def test_performance(batch_size, seq_len, vocab_size, temperature=0.7, device='cuda'): + print(f"\n>>> [Performance Test] Batch={batch_size}, Seq={seq_len}, Vocab={vocab_size}") + + logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.bfloat16) + + reset_memory_stats() + start_t = time.time() + + peak_mem_orig = 0 + try: + _ = add_gumbel_noise_original(logits, temperature) + torch.cuda.synchronize() + end_t = time.time() + peak_mem_orig = get_peak_memory_mb() + time_orig = end_t - start_t + print(f"Original: Time = {time_orig:.4f}s | Peak VRAM = {peak_mem_orig:.2f} MB") + except RuntimeError: + print(f"Original: CRASHED") + peak_mem_orig = float('inf') + + reset_memory_stats() + start_t = time.time() + + _ = sample_gumbel_chunked(logits, temperature) + + torch.cuda.synchronize() + end_t = time.time() + + peak_mem_eff = get_peak_memory_mb() + time_eff = end_t - start_t + + print(f"Efficient: Time = {time_eff:.4f}s | Peak VRAM = {peak_mem_eff:.2f} MB") + + if peak_mem_orig != float('inf'): + saved = peak_mem_orig - peak_mem_eff + print(f"VRAM Saved: {saved:.2f} MB") + + +def sample_gumbel_chunked_debug(logits, temperature): + """ + A 'Debug' version of the efficient function that returns the floats + instead of the argmax, allowing us to verify the appraoch. + """ + if temperature == 0: + return logits + + output_floats = torch.empty_like(logits, dtype=torch.float64) + + for i in range(logits.shape[0]): + logit_slice = logits[i].to(torch.float64) + noise = torch.rand_like(logit_slice, dtype=torch.float64) + gumbel_noise = (-torch.log(noise)) ** temperature + + # Store the float result instead of argmaxing + output_floats[i] = logit_slice.exp() / gumbel_noise + + return output_floats + + +def test_bit_exactness_cpu(batch_size=4, seq_len=32, vocab_size=100, temperature=1.0): + print(f"\n>>> [Logit match check CPU] Batch={batch_size}, Temp={temperature}") + + device = 'cpu' + logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32) + + torch.manual_seed(42) + noisy_logits_orig = add_gumbel_noise_original(logits, temperature) + + torch.manual_seed(42) + noisy_logits_chunked = sample_gumbel_chunked_debug(logits, temperature) + + diff = (noisy_logits_orig - noisy_logits_chunked).abs() + max_diff = diff.max().item() + + print(f" Max Difference: {max_diff}") + + if max_diff == 0.0: + print("SUCCESS: The floating point values are bit-wise identical.") + else: + print(f"FAILURE: Max difference is {max_diff}. Math is not identical.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--seq_len", type=int, default=128) + parser.add_argument("--vocab_size", type=int, default=128000) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + + if torch.cuda.is_available(): + test_statistical_gpu(8, 128, 1000, device='cuda') + test_performance(args.batch_size, args.seq_len, args.vocab_size, device='cuda') + + print("\n>>> Batch=32 (High VRAM)") + test_performance(32, 128, args.vocab_size, device='cuda') + else: + print("Skipping GPU tests.") + + # Show the algorithms are identical with CPU RNG + + test_logic_strict_cpu(8, 32, 100, temperature=1) + test_logic_strict_cpu(8, 32, 100, temperature=10) + test_bit_exactness_cpu(8, 32, 100, temperature=1) + test_bit_exactness_cpu(8, 32, 100, temperature=10)