From 13c2ddb7266f74508d03e6ec85b1d9c980dd5e77 Mon Sep 17 00:00:00 2001
From: Aman Priyanshu
Date: Wed, 26 Feb 2025 14:27:16 -0800
Subject: [PATCH] enabling mps inference

---
 chat.py     |  4 ++--
 generate.py | 19 +++++++++++++------
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/chat.py b/chat.py
index dd4eff49..0b882a09 100644
--- a/chat.py
+++ b/chat.py
@@ -5,7 +5,7 @@
 
 
 def chat():
-    device = 'cuda'
+    device = 'mps'
     model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
 
@@ -29,7 +29,7 @@ def chat():
         else:
             prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)
 
-        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
+        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence', is_mps=True)
 
         answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
         print(f"Bot's reply: {answer}")
diff --git a/generate.py b/generate.py
index 47b5bfa8..977a6af5 100644
--- a/generate.py
+++ b/generate.py
@@ -5,14 +5,18 @@
 from transformers import AutoTokenizer, AutoModel
 
 
-def add_gumbel_noise(logits, temperature):
+def add_gumbel_noise(logits, temperature, is_mps):
     '''
     The Gumbel max is a method for sampling categorical distributions.
     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
     Thus, we use float64.
     '''
-    logits = logits.to(torch.float64)
-    noise = torch.rand_like(logits, dtype=torch.float64)
+    if not is_mps:
+        logits = logits.to(torch.float64)
+        noise = torch.rand_like(logits, dtype=torch.float64)
+    else:
+        logits = logits.to(torch.float32)
+        noise = torch.rand_like(logits, dtype=torch.float32)
     gumbel_noise = (- torch.log(noise)) ** temperature
     return logits.exp() / gumbel_noise
 
@@ -40,7 +44,7 @@ def get_num_transfer_tokens(mask_index, steps):
 
 @ torch.no_grad()
 def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
-             cfg_scale=0., remasking='low_confidence', mask_id=126336):
+             cfg_scale=0., remasking='low_confidence', mask_id=126336, is_mps=False):
     '''
     Args:
         model: Mask predictor.
@@ -79,11 +83,14 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera
         else:
             logits = model(x).logits
 
-        logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+        logits_with_noise = add_gumbel_noise(logits, temperature=temperature, is_mps=is_mps)
         x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
 
         if remasking == 'low_confidence':
-            p = F.softmax(logits.to(torch.float64), dim=-1)
+            if not is_mps:
+                p = F.softmax(logits.to(torch.float64), dim=-1)
+            else:
+                p = F.softmax(logits.to(torch.float32), dim=-1)
             x0_p = torch.squeeze(
                 torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
         elif remasking == 'random':
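
Note (reviewer sketch, not part of the patch): the float32 fallback exists because PyTorch's MPS backend does not support float64, so the docstring's "Thus, we use float64." now only holds off-MPS. Rather than hard-coding device = 'mps' in chat.py, the device and the is_mps flag could be derived at runtime. A minimal, untested sketch assuming the patched generate() above; the prompt text is a placeholder:

import torch
from transformers import AutoTokenizer, AutoModel
from generate import generate

# Prefer MPS on Apple silicon, fall back to CUDA, then CPU.
device = 'mps' if torch.backends.mps.is_available() else (
    'cuda' if torch.cuda.is_available() else 'cpu')

model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

m = [{"role": "user", "content": "What is the capital of France?"}]  # placeholder prompt
text = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
prompt = torch.tensor(tokenizer(text)['input_ids']).unsqueeze(0).to(device)

# is_mps tracks the chosen device, so float64 Gumbel noise is still used off-MPS.
out = generate(model, prompt, steps=128, gen_length=128, block_length=32,
               temperature=0., cfg_scale=0., remasking='low_confidence',
               is_mps=(device == 'mps'))
print(tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0])

This would also let chat.py keep working on CUDA boxes instead of unconditionally passing is_mps=True.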