From 56c2eb23f7e1d448d4bf2393e7bc23b889fab822 Mon Sep 17 00:00:00 2001 From: Aditya Tomar Date: Thu, 18 Sep 2025 21:28:05 -0700 Subject: [PATCH] fix batch size and single-gpu accelerator bug --- eval_llada.py | 3 ++- generate.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/eval_llada.py b/eval_llada.py index 9e2a9a70..4588521c 100644 --- a/eval_llada.py +++ b/eval_llada.py @@ -274,7 +274,8 @@ def _tokenize(e): generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True) out.append(generated_answer) - self.accelerator.wait_for_everyone() + if self.accelerator is not None: + self.accelerator.wait_for_everyone() return out diff --git a/generate.py b/generate.py index c2cef3b2..f97799bb 100644 --- a/generate.py +++ b/generate.py @@ -55,7 +55,7 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera remasking: Remasking strategy. 'low_confidence' or 'random'. mask_id: The toke id of [MASK] is 126336. ''' - x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) + x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) x[:, :prompt.shape[1]] = prompt.clone() prompt_index = (x != mask_id)