Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,18 @@ def __init__(
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Init function for inv_freq
def init_inv_freq(device: torch.device):
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
return inv_freq

self.register_buffer("inv_freq", init_inv_freq(device), persistent=False)
self.original_inv_freq = self.inv_freq

# Save buffer init callback
LlamaModel.buf_init_callbacks.setdefault("rotary_emb.inv_freq", init_inv_freq)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this not be called in the `LlamaModel` directly? It's a bit odd for us to register something like this at the class level.



def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
Expand Down Expand Up @@ -891,6 +899,8 @@ class LlamaModel(LlamaPreTrainedModel):
Args:
config: LlamaConfig
"""
# A dict from buffer FQN to its init function
buf_init_callbacks = {}

def __init__(self, config: LlamaConfig):
super().__init__(config)
Expand All @@ -908,6 +918,7 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()


def get_input_embeddings(self):
    """Return the input embedding module (``self.embed_tokens``)."""
    input_embeddings = self.embed_tokens
    return input_embeddings

Expand Down Expand Up @@ -1111,6 +1122,9 @@ def _update_causal_mask(
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

# A dict from buffer FQN to its init function
buf_init_callbacks = {}

def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
Expand All @@ -1120,6 +1134,12 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

# Create buffer init callbacks by extending the one from `LlamaModel`,
# i.e. appending a prefix to all buffer FQNs.
for key, val in self.model.buf_init_callbacks.items():
new_key = ".".join(["model", key])
self.buf_init_callbacks[new_key] = val
Comment on lines +1137 to +1141
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this should go in the `LlamaPreTrainedModel` at most!


def get_input_embeddings(self):
    """Return the inner model's input embedding module (``self.model.embed_tokens``)."""
    inner_model = self.model
    return inner_model.embed_tokens

Expand Down