diff --git a/collie/driver/io/file.py b/collie/driver/io/file.py index 0196de45..18427fb3 100644 --- a/collie/driver/io/file.py +++ b/collie/driver/io/file.py @@ -4,13 +4,17 @@ import io import torch import shutil +from safetensors.torch import save_file, load_file class FileIODriver(IODriver): @staticmethod def load(path: str, mode: str): assert os.path.exists(path), f"File {path} does not exist." if 'b' in mode.lower(): - return torch.load(path, map_location=torch.device('cpu')) + if path.endswith(".safetensors"): + return load_file(path, device='cpu') + else: + return torch.load(path, map_location=torch.device('cpu')) else: with open(path, 'r') as f: return f.read() diff --git a/collie/models/__init__.py b/collie/models/__init__.py index 9a11a476..a60817e6 100644 --- a/collie/models/__init__.py +++ b/collie/models/__init__.py @@ -6,3 +6,4 @@ from .chatglm2 import ChatGLM2ForCausalLM from .moss_moon import Moss003MoonForCausalLM from .internlm2 import InternLM2ForCausalLM +from .mistral import MistralForCausalLM diff --git a/collie/models/mistral/__init__.py b/collie/models/mistral/__init__.py new file mode 100644 index 00000000..e998c291 --- /dev/null +++ b/collie/models/mistral/__init__.py @@ -0,0 +1,2 @@ +from .model import MistralForCausalLM +from .configuration_mistral import MistralConfig \ No newline at end of file diff --git a/collie/models/mistral/configuration_mistral.py b/collie/models/mistral/configuration_mistral.py new file mode 100644 index 00000000..ad6691b1 --- /dev/null +++ b/collie/models/mistral/configuration_mistral.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mistral model configuration""" + +from transformers.configuration_utils import PretrainedConfig +# from transformers.utils import logging +from collie.log.logger import logger + + +# logger = logging.get_logger(__name__) + +MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", + "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", +} + + +class MistralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an + Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. + + [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) + [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MistralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mistral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import MistralModel, MistralConfig + + >>> # Initializing a Mistral 7B style configuration + >>> configuration = MistralConfig() + + >>> # Initializing a model from the Mistral 7B style configuration + >>> model = MistralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + attention_dropout=0.0, + attn_implementation="flash_attention_2", + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + # 调用父类的初始化函数,将一些公共参数传递给父类处理 + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/collie/models/mistral/model.py b/collie/models/mistral/model.py new file mode 100644 index 00000000..a85d0c72 --- /dev/null +++ b/collie/models/mistral/model.py @@ -0,0 +1,1919 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mistral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel, dtype_byte_size +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_mistral import MistralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" + +#modified for collie +import torch.distributed as dist +import gc +import json +import os +from collections import OrderedDict +from megatron.core import parallel_state, tensor_parallel +from einops import rearrange +from deepspeed.pipe import LayerSpec, TiedLayerSpec + +from collie.config import CollieConfig +from collie.driver.io import IODriver +from collie.log.logger import logger +from collie.module import ( + ColumnParallelLinearWithoutBias, + ColumnParallelLMHead, + RowParallelLinearWithoutBias, +) +from collie.utils import concat_tensor, dict_as_params, env, progress +from collie.models.base import CollieModelForCausalLM +from collie.models.utils import ( + kv_cache_to_inputs_for_layer, inputs_to_kv_cache_for_layer, + kv_cache_to_inputs_for_model, inputs_to_kv_cache_for_model, +) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + ans = self.weight * hidden_states.to(input_dtype) + + return self.weight * hidden_states.to(input_dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +# TODO @Arthur no longer copied from LLama after static cache +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.gate_proj = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.intermediate_size, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.down_proj = RowParallelLinearWithoutBias( + self.intermediate_size, + self.hidden_size, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return output + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: CollieConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.num_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.k_proj = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.v_proj = ColumnParallelLinearWithoutBias( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + # aaaa + self.o_proj = RowParallelLinearWithoutBias( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, # 输入维度 [bsz, q_len, hidden_size] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) # [bsz, q_len, num_heads * head_dim] + key_states = self.k_proj(hidden_states) # [bsz, q_len, num_key_value_heads * head_dim] + value_states = self.v_proj(hidden_states) # [bsz, q_len, num_key_value_heads * head_dim] + + query_states, key_states, value_states = ( + rearrange(query_states, "b n (h d) -> b n h d", d=self.head_dim), # [bsz, q_len, num_heads, head_dim] + rearrange(key_states, "b n (h d) -> b n h d", d=self.head_dim), # [bsz, q_len, num_key_value_heads, head_dim] + rearrange(value_states, "b n (h d) -> b n h d", d=self.head_dim), # [bsz, q_len, num_key_value_heads, head_dim] + ) + + query_states = query_states.transpose(1, 2) # [bsz, num_heads, q_len, head_dim] + key_states = key_states.transpose(1, 2) # [bsz, num_key_value_heads, q_len, head_dim] + value_states = value_states.transpose(1, 2) # [bsz, num_key_value_heads, q_len, head_dim] + + + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads/self.config.tp_size, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads/self.config.tp_size, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads/self.config.tp_size, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads/self.config.tp_size, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, int(self.hidden_size/self.config.tp_size)) + + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states, key_states, value_states = ( + rearrange(query_states, "b n (h d) -> b n h d", d=self.head_dim), + rearrange(key_states, "b n (h d) -> b n h d", d=self.head_dim), + rearrange(value_states, "b n (h d) -> b n h d", d=self.head_dim), + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, int(self.hidden_size/self.config.tp_size)).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +# TODO @Arthur no longer copied from LLama after static cache +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states, key_states, value_states = ( + rearrange(query_states, "b n (h d) -> b n h d", d=self.head_dim), + rearrange(key_states, "b n (h d) -> b n h d", d=self.head_dim), + rearrange(value_states, "b n (h d) -> b n h d", d=self.head_dim), + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if self.config.pp_size > 1: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, int(self.hidden_size/self.config.tp_size)) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, +} + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: CollieConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + config._attn_implementation = "sdpa" + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.config = config + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.idx = layer_idx + # 务必保持变量名一致 + self.use_cache = self.config.model_config.use_cache + self.hidden_states = None + self.output_attentions = False + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: CollieConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + config._attn_implementation = "sdpa" + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.config = config + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.idx = layer_idx + # 务必保持变量名一致 + self.use_cache = self.config.model_config.use_cache + self.hidden_states = None + self.output_attentions = False + + def _forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + # output_attentions: Optional[bool] = False, + # use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + # if "padding_mask" in kwargs: + # warnings.warn( + # "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + # ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + # output_attentions=output_attentions, + # use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, present_key_value + + def forward(self, inputs: dict): + layer_past = inputs_to_kv_cache_for_layer(idx=self.idx, inputs=inputs) + + if self.config.checkpointing and self.training: + hidden_states, new_layer_past = torch.utils.checkpoint.checkpoint( + self._forward, + inputs["hidden_states"], + inputs.get("attention_mask", None), + inputs.get("position_ids", None), + layer_past, # inputs.get("past_key_values", None), + ) + else: + hidden_states, new_layer_past = self._forward( + inputs["hidden_states"], + inputs.get("attention_mask", None), + inputs.get("position_ids", None), + layer_past + ) # **inputs + inputs["hidden_states"] = hidden_states + + inputs.update(kv_cache_to_inputs_for_layer(idx=self.idx, new_layer_past=new_layer_past)) + return inputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: CollieConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = tensor_parallel.VocabParallelEmbedding( + config.vocab_size, config.hidden_size, params_dtype=torch.float32 + ) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + config._attn_implementation = "sdpa" + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # aaaa + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + inputs = { + "input_ids": input_ids, + "hidden_states": hidden_states, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "output_attentions": output_attentions, + "use_cache": use_cache, + } + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # for decoder_layer in self.layers: + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + # all_hidden_states += (hidden_states,) + all_hidden_states += (inputs["hidden_states"],) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + inputs, + ) + else: + layer_outputs = decoder_layer( + inputs, + ) + inputs.update(layer_outputs) + + # hidden_states = layer_outputs[0] + hidden_states = inputs["hidden_states"] + + if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = inputs["addition_info"][1 if output_attentions else 0] + + if output_attentions: + # all_self_attns += (layer_outputs[1],) + all_self_attns += (inputs["addition_info"][0],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + past_key_values=past_key_values, + ) + + @classmethod + def pipeline_layers(cls, config: CollieConfig): + """ + Get layers of pipeline. + :return: list + """ + if isinstance(config, str): + config = CollieConfig.from_pretrained(config) + + if config.tie_word_embeddings: + embed_tokens = TiedLayerSpec( + "embed_tokens", + dict_as_params(input_keys="input_ids", output_keys="hidden_states"), + tensor_parallel.VocabParallelEmbedding, + config.vocab_size, + config.hidden_size, + ) + else: + embed_tokens = LayerSpec( + dict_as_params(input_keys="input_ids", output_keys="hidden_states"), + tensor_parallel.VocabParallelEmbedding, + config.vocab_size, + config.hidden_size, + ) + + layers = [ + LayerSpec(MistralDecoderLayer, config, i) for i in range(config.num_hidden_layers) + ] + norm = LayerSpec( + dict_as_params(input_keys="hidden_states", output_keys="hidden_states"), + MistralRMSNorm, + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + ) + + return [ + ("embed_tokens", embed_tokens), + ("layers", layers), + ("norm", norm), + ] + +class MistralForCausalLM(CollieModelForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config:CollieConfig): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = ColumnParallelLinearWithoutBias( + self.collie_config.hidden_size, self.collie_config.vocab_size, bias=False + ) + # Initialize weights and apply final processing + # GenerationMixin 需要的额外参数 + self.config.is_decoder = True + if config.model_config.tie_word_embeddings: + self.lm_head.weight = self.embed_tokens.weight + self.main_input_name = "input_ids" + + def clean_cache(self): + self._clean_hidden_states([*self.model.layers, self.lm_head]) + self._set_use_cache(self.model.layers, False) + + def set_cache(self, use_cache): + self._set_use_cache(self.model.layers, use_cache) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + @classmethod + def pipeline_layers(cls, config: CollieConfig): + """ + Get layers of pipeline. + :return: list + """ + if isinstance(config, str): + config = CollieConfig.from_pretrained(config) + + if config.tie_word_embeddings: + output = TiedLayerSpec( + "embed_tokens", + dict_as_params(input_keys="hidden_states", output_keys="logits"), + ColumnParallelLMHead, + config.hidden_size, + config.vocab_size, + bias=False, + ) + else: + output = LayerSpec( + dict_as_params(input_keys="hidden_states", output_keys="logits"), + ColumnParallelLMHead, + config.hidden_size, + config.vocab_size, + bias=False, + ) + + return [("model", MistralModel.pipeline_layers(config)), ("lm_head", output)] + + @staticmethod + def load_parallel_state_dict( + path: str, + config: Union[CollieConfig, str], + process_exclusion: bool = False, + **kwargs, + ): + ... + + @staticmethod + def load_parallel_state_dict( + path: str, + config: Union[CollieConfig, str], + process_exclusion: bool = False, + protocol: str = "file", # 指定加载state_dict时使用的协议 + **kwargs, + ): + """ + Load state_dict from ``path``. + The format of pretrained model should be the same as that of + `huggingface`. + :return: state_dict. Note that the state_dict should be processed + properly to match the current rank. + """ + # 配置加载 + if isinstance(config, str): + config = CollieConfig.from_pretrained(config) + # IO驱动初始化 + io_driver = IODriver.from_protocol(protocol) + # 检查文件路径是否存在 + if not io_driver.exists(path): + raise FileNotFoundError(f"folder {path} not found.") + # 初始化存储和处理变量 + state_dict = OrderedDict() + weights = [] + parts = None # 变量用于存储模型分割的部分信息 + # 如果开启了进程互斥,那么每个进程都会显示进度条,否则只显示 RANK0 的 + hide_progress = not process_exclusion and int(os.environ.get("RANK", "0")) != 0 + if dist.is_initialized() and process_exclusion: + # 如果启动了进程互斥,则要进行 dist.get_world_size() 次循环 + rank_order = range(dist.get_world_size()) + else: + # 不开启只进行一次循环 + rank_order = range(1) + # 权重文件加载和处理 + for rank in rank_order: + # 如果开启了进程互斥,那么只有对应 RANK 的能进入循环;不开启进程互斥的话就都可以进 + if int(os.environ.get("RANK", "0")) == rank or not process_exclusion: + # PP 分层的方法保存在了 os.environ["COLLIE_PP_PARTS"], 格式类似于 [0, 17, 35], 左闭右开 + if env.is_pipeline: + # 保存的是 json 格式 + parts = env.pipeline_parts + if hasattr(config, "num_key_value_heads"): + # llama2 (transformers >= 4.31.0) + num_key_value_heads = config.num_key_value_heads + else: + num_key_value_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + # 如果存在 pytorch_model.bin.index.json 文件的话,此时不同的 pp 进程可以按需加载自己需要的权重 + if ( + io_driver.exists(os.path.join(path, "pytorch_model.bin.index.json")) + and "COLLIE_PP_PARTS" in os.environ.keys() + ): + weight_map = json.loads( + io_driver.load( + os.path.join(path, "pytorch_model.bin.index.json"), mode="r" + ) + )["weight_map"] + # layers 表示自己需要的层 + layers = env.pipeline_layers_idx + # 筛选出形似 model.layers.0 这样的层。包含两个条件:1. 有数字的层;2. 数字加一要在 layers 里面(因为最开始还有个 embedding 占一层) + weights.extend( + [ + value + for key, value in weight_map.items() + if len(key.split(".")) > 2 + and key.split(".")[2].isdigit() + and (int(key.split(".")[2]) + 1) in layers + ] + ) + # 去重 + weights = list(set(weights)) + # 继续筛选,如果有 0 层,那么就要加载 embedding;如果有最后一层,那么就要加载 lm_head;如果有倒数第二层,那么就要加载 norm + if 0 in layers: + weights.append(weight_map["model.embed_tokens.weight"]) + if max(parts) - 1 in layers: + weights.append(weight_map["lm_head.weight"]) + if max(parts) - 2 in layers: + weights.append(weight_map["model.norm.weight"]) + else: + # 如果没有 pytorch_model.bin.index.json 文件的话,那么就加载所有的权重 + weights = [ + weight + for weight in io_driver.list(path) + if weight.endswith(".bin") + ] + with progress( + weights, + desc="Loading state dict", + total=len(weights), + disable=hide_progress, + ) as pbar: + for weight in pbar: + part_state_dict = io_driver.load( + os.path.join(path, weight), mode="rb" + ) + state_dict.update(part_state_dict) + del part_state_dict + if parts is not None: + # 这一步是 pp 的复筛 + layers = env.pipeline_layers_idx + for key in list(state_dict.keys()): + if key.startswith("layers"): + layer = int(key.split(".")[1]) + if layer + 1 not in layers: + state_dict.pop(key) + # if key.endswith("tok_embeddings.weight"): + if key.endswith("embed_tokens.weight"): + if 0 not in layers: + state_dict.pop(key) + if key == "norm.weight": + if max(parts) - 2 not in layers: + state_dict.pop(key) + # if key.endswith("output.weight"): + if key.endswith("lm_head.weight"): + if max(parts) - 1 not in layers: + state_dict.pop(key) + # 根据用户配置的新的 tp size 进行分割 + for key in list(state_dict.keys()): + col_filter = [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "lm_head.weight", + "gate_proj.weight", + "up_proj.weight", + "embed_tokens.weight", + ] + col_split = any([key.endswith(filter) for filter in col_filter]) + + if col_split: + tensor = ( + list(torch.chunk(state_dict[key], config.tp_size, dim=0))[ + env.tp_rank + ] + .detach() + .clone() + ) + del state_dict[key] + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + state_dict[key] = tensor + elif key.endswith("o_proj.weight") or key.endswith("down_proj.weight"): + tensor = ( + list(torch.chunk(state_dict[key], config.tp_size, dim=1))[ + env.tp_rank + ] + .detach() + .clone() + ) + del state_dict[key] + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + state_dict[key] = tensor + if dist.is_initialized() and process_exclusion: + # 如果选择了进程互斥,那么本次循环中不需要加载权重的进程需等待 + dist.barrier() + return state_dict + + @staticmethod + def save_parallel_state_dict( + state_dict: dict, + path: str, + config: CollieConfig, + process_exclusion: bool = False, + **kwargs, + ): + ... + + @staticmethod + def save_parallel_state_dict( + state_dict: dict, + path: str, + config: CollieConfig, + process_exclusion: bool = False, + protocol: str = "file", + ): + """ + Save state_dict to ``path``. + The format of saved state dict should be the same as that of + `huggingface`. + """ + io_driver = IODriver.from_protocol(protocol) + # gather to tp rank 0 + if dist.is_initialized() and process_exclusion: + # 如果启动了进程互斥,则要进行 pp_size 次循环 + rank_order = range(config.pp_size) + else: + # 不开启只进行一次循环 + rank_order = range(1) + dst = parallel_state.get_tensor_model_parallel_src_rank() + with progress( + rank_order, + desc="Saving model", + disable=int(os.environ.get("RANK", "0")) != 0, + ) as pbar: + for rank in pbar: + if env.dp_rank == 0 and (env.pp_rank == rank or not process_exclusion): + for key in sorted(list(state_dict.keys())): + tensor_list = None + if env.tp_rank == 0: + tensor_list = [ + torch.zeros_like(state_dict[key]) + .to(state_dict[key].dtype) + .cuda() + for _ in range(config.tp_size) + ] + dist.gather( + state_dict[key].cuda(), + dst=dst, + gather_list=tensor_list, + group=env.tp_group, + ) + if env.tp_rank == 0: + col_filter = [ + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "lm_head.weight", + "gate_proj.weight", + "up_proj.weight", + "embed_tokens.weight", + ] + col_split = any( + [key.endswith(filter) for filter in col_filter] + ) + + if col_split: + state_dict[key] = concat_tensor(tensor_list, dim=0) + + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + + elif key.endswith("o_proj.weight") or key.endswith("down_proj.weight"): + state_dict[key] = concat_tensor(tensor_list, dim=1) + + if process_exclusion: + # CPU 内存回收(速度很慢) + gc.collect() + + if env.tp_rank == 0: + # Save gathered weights + if env.is_pipeline: + ckpt_name = f"pytorch_model-{env.pp_rank + 1:05d}-of-{config.pp_size:05d}.bin" + total_size = 0 + weight_map = {} + for name, weight in state_dict.items(): + weight_size = weight.numel() * dtype_byte_size( + weight.dtype + ) + weight_map[name] = ckpt_name + total_size += weight_size + index_dict = dict( + total_size=total_size, weight_map=weight_map + ) + index_dicts = [None for _ in range(env.pp_size)] + dist.gather_object( + index_dict, index_dicts if env.pp_rank == 0 else None, group=env.pp_group + ) + if env.pp_rank == 0: + total_size = 0 + weight_map = {} + for _index_dict in index_dicts: + total_size += _index_dict["total_size"] + weight_map.update(_index_dict["weight_map"]) + merged_dict = { + "metadata": {"total_size": total_size}, + "weight_map": weight_map, + } + io_driver.save( + json.dumps(merged_dict, indent=2, sort_keys=True) + + "\n", + os.path.join(path, "pytorch_model.bin.index.json"), + ) + + else: + ckpt_name = f"pytorch_model.bin" + ckpt_path = os.path.join(path, ckpt_name) + io_driver.save(state_dict, ckpt_path) + if dist.is_initialized() and process_exclusion: + dist.barrier() + if env.rank == 0: + config.save_pretrained(path, protocol=protocol) + dist.barrier() + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/tests/models/mistral/test_generation.py b/tests/models/mistral/test_generation.py new file mode 100644 index 00000000..ac06de86 --- /dev/null +++ b/tests/models/mistral/test_generation.py @@ -0,0 +1,37 @@ +import sys +sys.path.append("../../../") + +from transformers import AutoTokenizer, GenerationConfig + +from collie.models.mistral2 import MistralForCausalLM, MistralConfig +from collie import CollieConfig, env + +model_name_or_path = "mistralai/Mistral-7B-v0.1" + +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + +config = CollieConfig.from_pretrained(model_name_or_path) + +config.dp_size = 1 +config.tp_size = 2 +config.pp_size = 2 +# config.architectures = ["MistralForCausalLM"] +print("------------------------------") +model = MistralForCausalLM.from_pretrained(model_name_or_path, config=config).cuda() +model.eval() +print("------------------------------") +prompt = "Llama is a" +# prompt = "Q:What do we eat for tonight?A:" +inputs = tokenizer(prompt, return_tensors="pt") +print("inputs:") +print(inputs) + + +gen_config = GenerationConfig(max_new_tokens=256, early_stopping=False, eos_token_id=2) + +outs = model.generate(inputs["input_ids"].cuda(), generation_config=gen_config) +if env.local_rank == 0: + print("outs:") + print(outs) + print("last:") + print(tokenizer.decode(outs[0], skip_special_tokens=True)) \ No newline at end of file diff --git a/tests/models/mistral/test_raw.py b/tests/models/mistral/test_raw.py new file mode 100644 index 00000000..1f50b393 --- /dev/null +++ b/tests/models/mistral/test_raw.py @@ -0,0 +1,22 @@ +import sys + +import torch + +sys.path.append("../../../") + +from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM +from collie.models.mistral.modeling_mistral import MistralForCausalLM + +model_name_or_path = "mistralai/Mistral-7B-v0.1" +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) +model = MistralForCausalLM.from_pretrained(model_name_or_path).cuda() +model.eval() +prompt = "Llama is a" +# prompt = "Q:What do we eat for tonight?A:" +inputs = tokenizer(prompt, return_tensors="pt") +print(inputs) +gen_config = GenerationConfig(max_new_tokens=256, early_stopping=False, eos_token_id=2) +outs = model.generate(inputs["input_ids"].cuda(), generation_config=gen_config) + +print(outs) +print(tokenizer.decode(outs[0], skip_special_tokens=True)) \ No newline at end of file