diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1a2b732e85e4..2cdb5257a944 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -72,11 +72,9 @@ def forward(self, hidden_states): variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) - def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - class MistralRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -84,8 +82,17 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Init function for inv_freq + def init_inv_freq(device: torch.device): + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + return inv_freq + + self.register_buffer("inv_freq", init_inv_freq(device), persistent=False) + self.original_inv_freq = self.inv_freq + + # Save buffer init callback + MistralModel.buf_init_callbacks.setdefault("rotary_emb.inv_freq", init_inv_freq) @torch.no_grad() # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward @@ -700,6 +707,8 @@ class MistralModel(MistralPreTrainedModel): Args: config: MistralConfig """ + # A dict from buffer FQN to its init function + buf_init_callbacks = {} def __init__(self, config: MistralConfig): super().__init__(config) @@ -953,6 +962,8 @@ def _update_causal_mask( class MistralForCausalLM(MistralPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] + # A dict from buffer FQN to its init function + buf_init_callbacks = {} def __init__(self, config): super().__init__(config) @@ -962,6 +973,11 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + # Create buffer init callbacks by extending the one from `LlamaModel`, + # i.e. appending a prefix to all buffer FQNs. + for key, val in self.model.buf_init_callbacks.items(): + new_key = ".".join(["pattern", key]) + self.buf_init_callbacks[new_key] = val def get_input_embeddings(self): return self.model.embed_tokens @@ -1098,10 +1114,10 @@ def prepare_inputs_for_generation( position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) - + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds}