diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index dd053c805fb8..52222584249c 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -172,10 +172,18 @@ def __init__(
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        # Init function for inv_freq
+        def init_inv_freq(device: torch.device):
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+            return inv_freq
+
+        self.register_buffer("inv_freq", init_inv_freq(device), persistent=False)
         self.original_inv_freq = self.inv_freq
 
+        # Save buffer init callback
+        LlamaModel.buf_init_callbacks.setdefault("rotary_emb.inv_freq", init_inv_freq)
+
+
     def _dynamic_frequency_update(self, position_ids, device):
         """
         dynamic RoPE layers should recompute `inv_freq` in the following situations:
@@ -891,6 +899,8 @@ class LlamaModel(LlamaPreTrainedModel):
     Args:
         config: LlamaConfig
     """
+    # A dict from buffer FQN to its init function
+    buf_init_callbacks = {}
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
@@ -908,6 +918,7 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
@@ -1111,6 +1122,9 @@ def _update_causal_mask(
 class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
+    # A dict from buffer FQN to its init function
+    buf_init_callbacks = {}
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
@@ -1120,6 +1134,12 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Create buffer init callbacks by extending the dict from `LlamaModel`,
+        # i.e. prepending the `model` prefix to all buffer FQNs.
+        for key, val in self.model.buf_init_callbacks.items():
+            new_key = ".".join(["model", key])
+            self.buf_init_callbacks[new_key] = val
+
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
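
To make the intent concrete, here is a minimal sketch of how a consumer could use `buf_init_callbacks`, assuming a meta-device initialization workflow (e.g. deferred init before sharding). The `materialize_buffers` helper below is hypothetical and not part of this diff; it re-runs each saved callback on a real device and swaps out the placeholder buffer:

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    # Instantiate on the meta device: no real storage is allocated, so
    # non-persistent buffers such as `inv_freq` hold no usable values.
    with torch.device("meta"):
        model = LlamaForCausalLM(LlamaConfig())

    def materialize_buffers(model, device):
        # Hypothetical helper, not part of the diff: re-run each saved init
        # callback on the target device and overwrite the placeholder buffer.
        for fqn, init_fn in model.buf_init_callbacks.items():
            module_path, _, buf_name = fqn.rpartition(".")
            submodule = model.get_submodule(module_path)
            submodule.register_buffer(buf_name, init_fn(device), persistent=False)

    materialize_buffers(model, torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Keying the callbacks by FQN is what makes this loop possible: the consumer resolves each owning submodule via `get_submodule` without knowing the model's structure. One gap the sketch ignores: `original_inv_freq` would still reference the old placeholder tensor and would need the same treatment.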