From 7bb876436666c7a6496f92562f8db9c48df31863 Mon Sep 17 00:00:00 2001 From: Seas0 Date: Thu, 5 Jun 2025 15:30:34 +0800 Subject: [PATCH] generate: skip unnecessary unmask steps and warn user about it Running the model against inputs without any `[mask]` token is absurd and useless; intercept this situation and skip the model call when sampling. --- generate.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/generate.py b/generate.py index c2cef3b2..e34a68b3 100644 --- a/generate.py +++ b/generate.py @@ -70,6 +70,10 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, tempera block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id) num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) for i in range(steps): + # Skip non-informative steps + if num_transfer_tokens[:, i].sum() == 0: + warning('Detected unnecessary unmask steps w/o masked inputs, please lower the total step cnt.') + continue mask_index = (x == mask_id) if cfg_scale > 0.: un_x = x.clone()