Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import numpy as np
import time
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
Expand Down Expand Up @@ -40,9 +41,10 @@ def get_num_transfer_tokens(mask_index, steps):
return num_transfer_tokens


@ torch.no_grad()
@torch.no_grad()
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False):
cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False,
confidence_eos_eot_inf=False):
'''
Args:
model: Mask predictor.
Expand All @@ -61,7 +63,9 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc
x[:, :prompt.shape[1]] = prompt.clone()

if attention_mask is not None:
attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
attention_mask = torch.cat([attention_mask,
torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype,
device=model.device)], dim=-1)

prompt_index = (x != mask_id)

Expand All @@ -72,7 +76,8 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc
steps = steps // num_blocks

for num_block in range(num_blocks):
block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (
num_block + 1) * block_length:] == mask_id)
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
for i in range(steps):
mask_index = (x == mask_id)
Expand All @@ -92,15 +97,15 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc
logits[:, :, 126081] = -torch.inf

logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

if confidence_eos_eot_inf:
logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf

if remasking == 'low_confidence':
p = F.softmax(logits, dim=-1)
x0_p = torch.squeeze(
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
elif remasking == 'random':
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
else:
Expand All @@ -123,7 +128,8 @@ def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, bloc
def main():
device = 'cuda'

model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

# The LLaDA architecture theoretically supports both left-padding and right-padding.
Expand All @@ -134,13 +140,15 @@ def main():
# If the padding ID equals the mask ID, you need to modify our generate function to achieve correct inference.
assert tokenizer.pad_token_id != 126336

prompts = [ "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?",
"Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?",
"Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?"]
prompts = [
"Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?",
"Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?",
"Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?"]

# Add special tokens for the Instruct model. The Base model does not require the following two lines.
messages = [{"role": "user", "content": prompt} for prompt in prompts]
prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages]
prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in
messages]

encoded_outputs = tokenizer(
prompts,
Expand All @@ -151,11 +159,14 @@ def main():
input_ids = encoded_outputs['input_ids'].to(device)
attention_mask = encoded_outputs['attention_mask'].to(device)

out = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
out = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=128, temperature=0.5,
cfg_scale=0., remasking='low_confidence')
output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)

for o in output:
print(o)
print('-' * 50)


if __name__ == '__main__':
main()
203 changes: 203 additions & 0 deletions generate_optimized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def add_gumbel_noise_original(logits, temperature):
    '''
    Perturb a full logits tensor so that an argmax over the result draws a Gumbel-max sample.

    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves the perplexity
    score but reduces generation quality, so all arithmetic here is done in float64.

    Args:
        logits: Tensor of logits; the caller takes argmax over the last dimension of the result.
        temperature: Sampling temperature. 0 disables the noise entirely.

    Returns:
        `logits` itself (same object, same dtype) when temperature is 0; otherwise a new
        float64 tensor equal to exp(logits) / (-log(u)) ** temperature, where u is uniform
        noise drawn with torch.rand_like.
    '''
    # Greedy decoding: nothing to perturb, hand the input back untouched.
    if temperature == 0:
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / denominator


def sample_gumbel_chunked(logits, temperature):
    '''
    Gumbel-max sampling that processes one batch element at a time.

    Instead of materialising float64 copies and noise for the whole batch at once, the
    float64 cast, the noise and the argmax are computed per index along dimension 0, so
    only a single batch element's float64 temporaries are live at any moment.

    Args:
        logits: Tensor of logits; dimension 0 is iterated, the last dimension is reduced.
        temperature: Sampling temperature. 0 means plain argmax (no noise).

    Returns:
        A torch.long tensor of shape logits.shape[:-1] holding the selected indices,
        on the same device as `logits`.
    '''
    # Greedy decoding needs no noise and no per-element loop.
    if temperature == 0:
        return torch.argmax(logits, dim=-1)

    sampled = torch.empty(logits.shape[:-1], dtype=torch.long, device=logits.device)

    for row, row_logits in enumerate(logits):
        row_logits = row_logits.to(torch.float64)
        uniform = torch.rand_like(row_logits, dtype=torch.float64)
        # Same formula as the full-batch version: exp(logits) / (-log(u)) ** temperature.
        perturbed = torch.exp(row_logits) / torch.pow(-torch.log(uniform), temperature)
        sampled[row] = perturbed.argmax(dim=-1)

    return sampled


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens are to be transitioned at each sampling step.

    In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
    Because LLaDA employs a linear noise schedule (as defined in Eq. (8)), the expected number
    of tokens transitioned at each step should be consistent, so the masked tokens of every
    row are split as evenly as possible across the steps.

    Args:
        mask_index: 2-D tensor whose row sums give the number of masked positions
            per batch row (the caller passes a boolean mask).
        steps: Number of sampling steps to spread the masked tokens over.

    Returns:
        Tensor of shape (mask_index.size(0), steps) on mask_index's device (int64 for a
        boolean mask). Each row sums to that row's mask count; the first
        `count % steps` steps receive one more token than the remaining ones.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps

    # Vectorized replacement for the per-row loop `out[i, :remainder[i]] += 1`:
    # step s of row r gets the extra token exactly when s < remainder[r]. This avoids
    # one tensor-to-Python-int conversion (and device sync) per batch row.
    step_ids = torch.arange(steps, device=mask_index.device)
    extra = (step_ids.unsqueeze(0) < remainder).to(torch.int64)
    return base + extra


@torch.no_grad()
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False,
             confidence_eos_eot_inf=False,
             use_efficient_sampling=True):
    '''
    Iteratively unmask `gen_length` tokens appended to `prompt`, block by block.

    Args:
        model: Mask predictor. Called as model(x, attention_mask=...) and must return an
            object with a `.logits` attribute; `model.device` is used for new tensors.
        prompt: A tensor of token ids of shape (B, L).
        attention_mask: Optional mask for the prompt, shape (B, L). It is extended with ones
            for the generated positions. May be None.
        steps: Total sampling steps, less than or equal to gen_length. Must be divisible by
            the number of blocks (gen_length // block_length).
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length and a divisor of it.
            If less than gen_length, it means using semi-autoregressive remasking.
        temperature: Categorical distribution sampling temperature (0 means greedy argmax).
        cfg_scale: Unsupervised classifier-free guidance scale (applied only when > 0).
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        logits_eos_inf: Whether to set the logits of the EOS token to -inf.
            See Appendix B.4 of LLaDA for details.
        confidence_eos_eot_inf: Whether to set the confidence of the EOS and EoT tokens to -inf.
            See Appendix B.4 of LLaDA for details.
        use_efficient_sampling: If True, sample with sample_gumbel_chunked (Gumbel noise and
            argmax computed one batch element at a time to save VRAM). If False, use
            add_gumbel_noise_original on the whole batch followed by argmax.

    Returns:
        A torch.long tensor of shape (B, L + gen_length) on model.device: the prompt followed
        by the generated token ids.

    Raises:
        AssertionError: If gen_length is not divisible by block_length, or steps is not
            divisible by the number of blocks.
        NotImplementedError: If `remasking` is not one of the supported strategies.
    '''
    eos_id = 126081  # EOS token id
    eot_id = 126348  # EoT token id

    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        # Generated positions are always attended to.
        attention_mask = torch.cat([attention_mask,
                                    torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype,
                                               device=model.device)], dim=-1)

    # Positions holding prompt tokens; these are masked out for the unconditional CFG branch.
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks  # from here on: steps per block

    for num_block in range(num_blocks):
        block_start = prompt.shape[1] + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                # Fix: without an attention mask the doubled mask used to be left unbound,
                # raising UnboundLocalError on the model call below.
                if attention_mask is not None:
                    attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                else:
                    attention_mask_ = None
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, eos_id] = -torch.inf

            # --- Gumbel sampling ---
            if use_efficient_sampling:
                x0 = sample_gumbel_chunked(logits, temperature)
            else:
                # Original path: builds the full noisy logits tensor before argmax.
                logits_with_noise = add_gumbel_noise_original(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)

            if confidence_eos_eot_inf:
                # Runs after x0 is chosen, so it only affects the confidence scores below.
                # Both EOS and EoT are masked in the noise-free `logits`.
                # NOTE(review): the earlier code wrote the EOS mask into `logits_with_noise`
                # (no effect after argmax) and only EoT into `logits`; masking both here
                # changes the 'low_confidence' scores relative to that code. Confirm intended.
                logits[:, :, eos_id] = -torch.inf
                logits[:, :, eot_id] = -torch.inf

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never unmask positions that belong to later blocks.
            x0_p[:, block_end:] = -np.inf

            # Only currently masked positions are candidates; everything else keeps its token.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                # Commit the k most confident predictions scheduled for this step.
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


def main():
    '''Demo: answer three word problems in one batch with LLaDA-8B-Instruct on CUDA and print them.'''
    device = 'cuda'
    checkpoint = 'GSAI-ML/LLaDA-8B-Instruct'

    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    # The LLaDA architecture theoretically supports both left-padding and right-padding,
    # but the sampling code is simpler with left-padding, so force it.
    if tokenizer.padding_side != 'left':
        tokenizer.padding_side = 'left'

    # generate() tells prompt and generated positions apart by the mask id, so a padding id
    # equal to the mask id would require modifying generate() for correct inference.
    assert tokenizer.pad_token_id != 126336

    questions = [
        "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?",
        "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?",
        "Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?",
    ]

    # Wrap each question in the chat template (special tokens for the Instruct model).
    # The Base model does not require this step.
    chat_prompts = [
        tokenizer.apply_chat_template([{"role": "user", "content": question}],
                                      add_generation_prompt=True, tokenize=False)
        for question in questions
    ]

    encoded = tokenizer(chat_prompts, add_special_tokens=False, padding=True, return_tensors="pt")
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Example usage with efficient (per-batch-element) sampling enabled.
    generated = generate(model, input_ids, attention_mask, steps=128, gen_length=128, block_length=32,
                         temperature=0.7, cfg_scale=0., remasking='low_confidence',
                         use_efficient_sampling=True)

    # Drop the prompt tokens and decode only the newly generated part.
    answers = tokenizer.batch_decode(generated[:, input_ids.shape[1]:], skip_special_tokens=True)
    separator = '-' * 50
    for answer in answers:
        print(answer)
        print(separator)


if __name__ == '__main__':
    main()
35 changes: 35 additions & 0 deletions gumbel_test_output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

>>> [Distribution Check: GPU] Batch=8, Temp=1.0
Sampled Mean Index: 490.88
Expected Mean Index: 500.00
SUCCESS: GPU output distribution looks statistically valid.

>>> [Performance Test] Batch=8, Seq=128, Vocab=128000
Original: Time = 0.3060s | Peak VRAM = 5250.00 MB
Efficient: Time = 0.1930s | Peak VRAM = 2006.01 MB
VRAM Saved: 3243.99 MB

>>> Batch=32 (High VRAM)

>>> [Performance Test] Batch=32, Seq=128, Vocab=128000
Original: Time = 0.3961s | Peak VRAM = 21000.00 MB
Efficient: Time = 0.6998s | Peak VRAM = 5756.03 MB
VRAM Saved: 15243.97 MB

>>> [Logic Check: CPU] Batch=8, Temp=1
Original: Time = 0.0041s
Efficient: Time = 0.0009s
SUCCESS: CPU outputs are identical. Logic is correct.

>>> [Logic Check: CPU] Batch=8, Temp=10
Original: Time = 0.0009s
Efficient: Time = 0.0011s
SUCCESS: CPU outputs are identical. Logic is correct.

>>> [Logit match check CPU] Batch=8, Temp=1
Max Difference: 0.0
SUCCESS: The floating point values are bit-wise identical.

>>> [Logit match check CPU] Batch=8, Temp=10
Max Difference: 0.0
SUCCESS: The floating point values are bit-wise identical.
Loading