src/transformers/models/mistral/modeling_mistral.py (30 changes: 23 additions & 7 deletions)
@@ -72,20 +72,27 @@ def forward(self, hidden_states):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class MistralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Init function for inv_freq
def init_inv_freq(device: torch.device):
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
return inv_freq

self.register_buffer("inv_freq", init_inv_freq(device), persistent=False)
self.original_inv_freq = self.inv_freq

# Save buffer init callback
MistralModel.buf_init_callbacks.setdefault("rotary_emb.inv_freq", init_inv_freq)

@torch.no_grad()
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
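The change above is the core of the mechanism: the rotary embedding still builds `inv_freq` at construction time, but it now also registers `init_inv_freq` under a buffer FQN in `MistralModel.buf_init_callbacks`, so the buffer can be re-created later on an arbitrary device. A minimal sketch of what a consumer of that dict could look like is below; `rematerialize_buffers` is a hypothetical helper written for illustration, not something added by this PR, and it assumes each stored key resolves to a real submodule path via `nn.Module.get_submodule`.

```python
import torch
from torch import nn


def rematerialize_buffers(model: nn.Module, callbacks: dict, device: torch.device) -> None:
    """Hypothetical helper: re-run each registered init callback and re-attach the
    result as a non-persistent buffer on the submodule named by the FQN."""
    for fqn, init_fn in callbacks.items():
        *module_path, buffer_name = fqn.split(".")
        owner = model.get_submodule(".".join(module_path)) if module_path else model
        owner.register_buffer(buffer_name, init_fn(device), persistent=False)
```

With such a helper, a model whose buffers were created on the `meta` device (or on the wrong device) could have them rebuilt in place before the first forward pass, which appears to be the scenario these callbacks target.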
@@ -700,6 +707,8 @@ class MistralModel(MistralPreTrainedModel):
Args:
config: MistralConfig
"""
# A dict from buffer FQN to its init function
buf_init_callbacks = {}

def __init__(self, config: MistralConfig):
super().__init__(config)
@@ -953,6 +962,8 @@ def _update_causal_mask(

class MistralForCausalLM(MistralPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
# A dict from buffer FQN to its init function
buf_init_callbacks = {}

def __init__(self, config):
super().__init__(config)
@@ -962,6 +973,11 @@ def __init__(self, config):

# Initialize weights and apply final processing
self.post_init()
# Create buffer init callbacks by extending the ones from `MistralModel`,
# i.e. prepending a prefix to all buffer FQNs.
for key, val in self.model.buf_init_callbacks.items():
new_key = ".".join(["pattern", key])
self.buf_init_callbacks[new_key] = val

def get_input_embeddings(self):
return self.model.embed_tokens
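At the `MistralForCausalLM` level the same callbacks are re-registered under prefixed keys, since the base model is nested as `self.model` and the FQNs seen from the wrapper differ from those seen from `MistralModel`. A hypothetical end-to-end use, reusing the `rematerialize_buffers` sketch from above (again an assumption for illustration, not an API added by this PR):

```python
import torch
from transformers import MistralForCausalLM

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16
)
# Re-create every registered non-persistent buffer on the target device,
# assuming the stored FQNs resolve to submodules of `model`.
rematerialize_buffers(model, model.buf_init_callbacks, torch.device("cuda"))
```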
@@ -1098,10 +1114,10 @@ def prepare_inputs_for_generation(
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient, as in the batch size = 1 case `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
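For context on the comment above the `clone` call in the last hunk: a size-1 slice of `position_ids` is reported as contiguous regardless of its stride, so `.contiguous()` is a no-op there, yet the stride still changes from one decoding step to the next, which is what retriggers CUDA graph capture under `torch.compile(mode="reduce-overhead")`. A small standalone illustration (not part of the diff):

```python
import torch

# Two consecutive "decoding steps": position_ids sliced down to the last position.
step_a = torch.arange(32).reshape(1, 32)[:, -1:]   # shape (1, 1), stride (32, 1)
step_b = torch.arange(33).reshape(1, 33)[:, -1:]   # shape (1, 1), stride (33, 1)

# Both are already "contiguous" (size-1 dims ignore stride), so .contiguous() changes nothing...
print(step_a.is_contiguous(), step_a.stride())      # True (32, 1)
print(step_b.is_contiguous(), step_b.stride())      # True (33, 1)

# ...while clone(memory_format=torch.contiguous_format) normalizes the stride every step.
print(step_a.clone(memory_format=torch.contiguous_format).stride())  # (1, 1)
print(step_b.clone(memory_format=torch.contiguous_format).stride())  # (1, 1)
```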