From 060d88685e4a4fbe1617e4f90166d091904afa85 Mon Sep 17 00:00:00 2001 From: Seventeen17 <17aloha@gmail.com> Date: Thu, 27 Jun 2024 19:40:53 +0800 Subject: [PATCH 1/5] patch fa --- apps/train.py | 5 ++- examples/llama3_acc.sh | 2 +- examples/llama_acc.sh | 2 +- examples/run.sh | 20 +++++------ .../accelerators/acc_llama_accelerator.py | 4 +-- .../accelerators/cuda_llama_accelerator.py | 5 ++- flashmodels/arguments.py | 7 +++- flashmodels/builder.py | 6 +++- flashmodels/patch/__init__.py | 1 - flashmodels/patch/llama_model.py | 4 ++- flashmodels/patch/patch.py | 34 ++++++++----------- hf_models/config/llama-1b/config.json | 2 +- requirements.txt | 2 +- 13 files changed, 47 insertions(+), 47 deletions(-) diff --git a/apps/train.py b/apps/train.py index 7c0a48a..b6d790a 100644 --- a/apps/train.py +++ b/apps/train.py @@ -1,16 +1,15 @@ import torch -from flashmodels import Builder, Trainer, accelerate, arguments - - def train(): torch.manual_seed(101) # parse args + from flashmodels import arguments args = arguments.parse() # build model, tokenizer, loader, optimizer and lr_scheduler # and use accelerator to speed up training + from flashmodels import Builder, Trainer, accelerate builder = Builder(args) model, loader, tokenizer = builder.build_model_dataloader() model, loader = accelerate(model, loader, args) diff --git a/examples/llama3_acc.sh b/examples/llama3_acc.sh index c29061f..530e4cf 100755 --- a/examples/llama3_acc.sh +++ b/examples/llama3_acc.sh @@ -3,4 +3,4 @@ set -ex # FSDP # note: this need transformers>=4.41.0 -./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --no_fa +./examples/run.sh --model ./hf_models/config/llama-3-1b --accelerator acc --gc --mbs 2 --fsdp 8 --max_seq_length 4096 --use_flash_attn diff --git a/examples/llama_acc.sh b/examples/llama_acc.sh index 3080336..a416f7b 100755 --- a/examples/llama_acc.sh +++ b/examples/llama_acc.sh @@ -2,7 +2,7 @@ set -ex # FSDP -./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 +./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn # TP # ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4 diff --git a/examples/run.sh b/examples/run.sh index ab396df..00458cb 100755 --- a/examples/run.sh +++ b/examples/run.sh @@ -18,9 +18,9 @@ DP_NUM=1 # data parallelism number PP_NUM=1 # pipeline parallelism number TP_NUM=1 # tensor parallelism number FSDP_NUM=1 # fsdp number -FLASH_ATTN=1 # enable flash-attn-2 DATA=./data/wikitext-2-raw-v1.json # data name or path MODEL_NAME_OR_PATH="./hf_models/config/llama-1b" # model name or path +USE_FLASH_ATTN=1 OTHER_ARGS="" @@ -28,7 +28,7 @@ OTHER_ARGS="" HELP_STR=("Usage: bash examples/run.sh [-h|--help] [--accelerator {acc, cuda}] [--model MODEL_NAME_OR_PATH] \n" "\t[--data DATASET_NAME_OR_PATH] [--mbs MICRO_BATCH_SIZE] [--max_seq_length MAX_SEQ_LENGTH] \n" "\t[--num_train_epochs NUM_TRAIN_EPOCHS] [--max_steps MAX_TRAIN_STEPS] [--pp PP_NUM] [--tp TP_NUM] [--fsdp FSDP_NUM] \n" - "\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--no_fa] [--log_interval LOG_INTERVAL] \n" + "\t[--ga GRADIENT_ACCUMULATION_STEPS] [--gc] [--bf16] [--fp16] [--fp32] [--use_flash_attn] [--log_interval LOG_INTERVAL] \n" "\t[other args for apps/train.py] \n" "Examples: \n" "\tbash examples/run.sh --accelerator cuda --model ./hf_models/config/llama-7b\n" @@ -125,8 +125,8 @@ while [[ $# -gt 0 ]]; do BF16=0 
shift ;; - --no_fa) - FLASH_ATTN=0 + --use_flash_attn) + ACC_FLASH_ATTN=1 shift ;; --log_interval) @@ -150,6 +150,11 @@ OPTION_ARGS="" [[ "$BF16" -eq 1 ]] && OPTION_ARGS+="--bf16 " [[ "$FP16" -eq 1 ]] && OPTION_ARGS+="--fp16 " +if [[ "$ACC_FLASH_ATTN" == 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then + OPTION_ARGS+="--use_flash_attn " + export ACC_FLASH_ATTN=1 +fi + if [ "$ACCELERATOR" == "cuda" ]; then [ "$PP_NUM" -gt 1 ] && echo "Error: Pipeline Parallelism is not supported for cuda accelerator." && exit 1 [ "$TP_NUM" -gt 1 ] && echo "Error: Tensor Parallelism is not supported for cuda accelerator." && exit 1 @@ -160,13 +165,6 @@ if [ "$TP_NUM" -gt "1" ]; then export XLA_USE_SPMD=1 fi - -if [[ "$ACCELERATOR" == "acc" && "FLASH_ATTN" -eq 1 && ( "$FP16" -eq 1 || "$BF16" -eq 1 ) ]]; then - export ACC_FLASH_ATTN=1 -fi - -export XLA_PERSISTENT_CACHE_PATH=./compiled_cache/ - MODEL_NAME=$(basename $MODEL_NAME_OR_PATH) JOB_NAME="${MODEL_NAME}_${ACCELERATOR}_bs${MBS}_seqlen${SEQLEN}_bf16-${BF16}_fp16-${FP16}_pp${PP_NUM}_tp${TP_NUM}_fsdp${FSDP_NUM}" diff --git a/flashmodels/accelerators/acc_llama_accelerator.py b/flashmodels/accelerators/acc_llama_accelerator.py index c2e7951..ab35bfe 100644 --- a/flashmodels/accelerators/acc_llama_accelerator.py +++ b/flashmodels/accelerators/acc_llama_accelerator.py @@ -85,9 +85,7 @@ def accelerate_internal(self, model, loader): model = self.tensor_parallel(model) return model, loader - if self.args.pp_num > 1: - # Prevent unnecessary model outputs - model.model.config.use_cache = False + model.model.config.use_cache = False # TODO: support this in torchacc if self.args.resume_from_checkpoint: assert self.args.fsdp_num == self.args.world_size, \ diff --git a/flashmodels/accelerators/cuda_llama_accelerator.py b/flashmodels/accelerators/cuda_llama_accelerator.py index 288f771..b20d47e 100644 --- a/flashmodels/accelerators/cuda_llama_accelerator.py +++ b/flashmodels/accelerators/cuda_llama_accelerator.py @@ -13,7 +13,6 @@ FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from transformers.models.llama.modeling_llama import LlamaDecoderLayer from flashmodels.accelerators.accelerator import (Accelerator, AcceleratorFactory) @@ -71,7 +70,7 @@ def apply_checkpointing(self, model): checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) - check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) + check_fn = lambda submodule: isinstance(submodule, transformers.models.llama.modeling_llama.LlamaDecoderLayer) apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, @@ -98,7 +97,7 @@ def fsdp(self, model): # Use auto_wrap_poliy for nested wrapping instead of only a top-level FSDP. 
auto_wrap_policy = ModuleWrapPolicy({ - LlamaDecoderLayer, + transformers.models.llama.modeling_llama.LlamaDecoderLayer, }) mixed_precision_policy = None diff --git a/flashmodels/arguments.py b/flashmodels/arguments.py index 3aac6bf..7a552de 100644 --- a/flashmodels/arguments.py +++ b/flashmodels/arguments.py @@ -212,6 +212,11 @@ def parse(): type=int, default=-1, help="Maximum training steps") + parser.add_argument( + "--use_flash_attn", + action="store_true", + default=False, + help="Use TriDao FlashAttention2") parser.add_argument( "--log_loss", action="store_true", help="Print loss when logging steps") @@ -271,7 +276,7 @@ def parse(): if args.model_type == "llama" and args.accelerator == 'acc' and ( args.fp16 or args.bf16): - patch_llama() + patch_llama(args.use_flash_attn) if args.model_type == "gemma" and args.accelerator == 'acc': patch_gemma() diff --git a/flashmodels/builder.py b/flashmodels/builder.py index aaa62bd..b5aca36 100644 --- a/flashmodels/builder.py +++ b/flashmodels/builder.py @@ -4,6 +4,7 @@ import torch import torchacc as ta + from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler) @@ -62,7 +63,9 @@ def build_model_from_ckpt(self): config = AutoConfig.from_pretrained( self.args.model_name_or_path, trust_remote_code=True) return self._init_fn( - AutoModelForCausalLM.from_config, config, trust_remote_code=True) + AutoModelForCausalLM.from_config, config, + attn_implementation="flash_attention_2" if self.args.use_flash_attn else "eager", + trust_remote_code=True) def build_model_from_pretrain(self): has_weight = False @@ -78,6 +81,7 @@ def build_model_from_pretrain(self): return self._init_fn( AutoModelForCausalLM.from_pretrained, self.args.model_name_or_path, + attn_implementation="flash_attention_2" if self.args.use_flash_attn else "eager", cache_dir=self.args.cache_dir, trust_remote_code=True) if self.args.local_rank == 0: diff --git a/flashmodels/patch/__init__.py b/flashmodels/patch/__init__.py index baf734f..4dc7330 100644 --- a/flashmodels/patch/__init__.py +++ b/flashmodels/patch/__init__.py @@ -5,6 +5,5 @@ def patch_amp(): import torchacc as ta ta.patch_amp() - def patch_peft(): patch_lora() diff --git a/flashmodels/patch/llama_model.py b/flashmodels/patch/llama_model.py index 745735f..58485ca 100644 --- a/flashmodels/patch/llama_model.py +++ b/flashmodels/patch/llama_model.py @@ -396,9 +396,11 @@ def flash_attn_fwd( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: from torchacc.ops import flash_attn_varlen_xla diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index f334ffb..bb66460 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -3,16 +3,10 @@ import os import re from typing import Any - import torch import transformers - from flashmodels.logger import logger -from flashmodels.patch.llama_model import (LlamaAttention, LlamaDecoderLayer, - LlamaMLP, flash_attn_fwd, - flash_attn_prep_mask, - make_causal_mask) - +from torchacc import patch_fa def rewrite_load(): """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU @@ -34,23 +28,25 @@ def rewrite_load(): exec(modified, 
transformers.modeling_utils.__dict__) -def patch_llama(): - transformers.models.llama.modeling_llama._make_causal_mask = make_causal_mask - if os.getenv("ACC_FLASH_ATTN", "0") == "1": - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = flash_attn_prep_mask - transformers.models.llama.modeling_llama.LlamaAttention.forward = flash_attn_fwd - elif os.environ.get("ACC_LLAMA_TP") == "1": - transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP - if os.getenv("XLA_USE_SPMD") == "1": - # use einsum in linear for SPMD TP/Ulysses. - transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer +def patch_llama(use_flash_attn): + if use_flash_attn: + patch_fa() + from transformers.cache_utils import Cache + def update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + return None + transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask = update_causal_mask # (wenting.swt): Delete me when merged in transformers if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))): rewrite_load() - # Set the attention_mask in LlamaAttention to None to match the pattern of FlashAttentionRewriter. def wrap_for_flash_attention(func): def wrapper(*args, **kwargs): diff --git a/hf_models/config/llama-1b/config.json b/hf_models/config/llama-1b/config.json index 04d6487..03476f1 100644 --- a/hf_models/config/llama-1b/config.json +++ b/hf_models/config/llama-1b/config.json @@ -1 +1 @@ -{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 4, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000} \ No newline at end of file +{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 4, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000} diff --git a/requirements.txt b/requirements.txt index 01a909c..0c07b9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ datasets numpy torch tokenizers>=0.13.3 -transformers==4.33.0 +transformers==4.41.0 accelerate transformers_stream_generator tiktoken From a0512f4a090241647d4deffe12248cfe012cf76a Mon Sep 17 00:00:00 2001 From: Seventeen17 <17aloha@gmail.com> Date: Thu, 27 Jun 2024 20:32:15 +0800 Subject: [PATCH 2/5] remove unused code --- examples/run.sh | 2 ++ flashmodels/patch/llama_model.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/run.sh b/examples/run.sh index 00458cb..57cedcb 100755 --- a/examples/run.sh +++ b/examples/run.sh @@ -165,6 +165,8 @@ if [ "$TP_NUM" -gt "1" ]; then export XLA_USE_SPMD=1 fi +export XLA_PERSISTENT_CACHE_PATH=./compiled_cache/ + MODEL_NAME=$(basename $MODEL_NAME_OR_PATH) JOB_NAME="${MODEL_NAME}_${ACCELERATOR}_bs${MBS}_seqlen${SEQLEN}_bf16-${BF16}_fp16-${FP16}_pp${PP_NUM}_tp${TP_NUM}_fsdp${FSDP_NUM}" diff --git 
a/flashmodels/patch/llama_model.py b/flashmodels/patch/llama_model.py index 58485ca..6b4a7ca 100644 --- a/flashmodels/patch/llama_model.py +++ b/flashmodels/patch/llama_model.py @@ -8,6 +8,7 @@ import torch_xla.core.xla_model as xm from torch import nn from torchacc.dist.tp import Mesh, mark_sharding +from transformer.cache_utils import Cache from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import (ACT2FN, LlamaRMSNorm, LlamaRotaryEmbedding, @@ -396,11 +397,9 @@ def flash_attn_fwd( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: from torchacc.ops import flash_attn_varlen_xla From 641b452668c2dbeca7b4f705d652ff6f4fc31599 Mon Sep 17 00:00:00 2001 From: Seventeen17 <17aloha@gmail.com> Date: Thu, 27 Jun 2024 20:45:54 +0800 Subject: [PATCH 3/5] remove unused code --- examples/llama_acc.sh | 2 +- flashmodels/patch/patch.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/examples/llama_acc.sh b/examples/llama_acc.sh index a416f7b..3080336 100755 --- a/examples/llama_acc.sh +++ b/examples/llama_acc.sh @@ -2,7 +2,7 @@ set -ex # FSDP -./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn +./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 # TP # ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4 diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index bb66460..f559acd 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -47,19 +47,6 @@ def update_causal_mask( if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))): rewrite_load() - def wrap_for_flash_attention(func): - - def wrapper(*args, **kwargs): - kwargs["attention_mask"] = None - return func(*args, **kwargs) - - return wrapper - - # always attention_mask=None - transformers.models.llama.modeling_llama.LlamaAttention.forward = wrap_for_flash_attention( - transformers.models.llama.modeling_llama.LlamaAttention. - forward) - def patch_gemma(): # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter. 
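The FlashAttention enablement in this series comes down to one monkey-patch: when --use_flash_attn is set, LlamaModel._update_causal_mask is overridden to return None, since the flash-attention path handles the causal constraint itself (PATCH 1/5 does this inline in patch_llama; PATCH 4/5 later replaces the inline version with a call to torchacc.utils.patch.patch_llama). Below is a minimal standalone sketch of that override, assuming transformers>=4.41.0 as pinned in requirements.txt; the helper name apply_flash_attn_mask_patch is illustrative and not part of the patches.

# Standalone sketch of the _update_causal_mask override from PATCH 1/5.
# Assumes transformers>=4.41.0, where LlamaModel exposes this private hook;
# the helper name below is illustrative, not part of the series.
import torch
import transformers
from transformers.cache_utils import Cache


def _no_causal_mask(self,
                    attention_mask: torch.Tensor,
                    input_tensor: torch.Tensor,
                    cache_position: torch.Tensor,
                    past_key_values: Cache,
                    output_attentions: bool):
    # FlashAttention enforces causality inside the kernel, so the additive
    # (seq_len, seq_len) mask normally built here is redundant; returning
    # None keeps it out of the traced graph entirely.
    return None


def apply_flash_attn_mask_patch():
    # Illustrative helper; the repository performs this assignment inside
    # patch_llama(use_flash_attn).
    llama = transformers.models.llama.modeling_llama
    llama.LlamaModel._update_causal_mask = _no_causal_mask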
From e200a3b485cf11ef3100813f534e92d8e0abf580 Mon Sep 17 00:00:00 2001 From: Seventeen17 <17aloha@gmail.com> Date: Fri, 28 Jun 2024 11:01:10 +0800 Subject: [PATCH 4/5] fix usage of acclerate --- examples/llama_acc.sh | 2 +- .../accelerators/acc_baichuan_accelerator.py | 2 +- .../accelerators/acc_gemma_accelerator.py | 2 +- .../accelerators/acc_glm_accelerator.py | 2 +- .../accelerators/acc_gpt_accelerator.py | 2 +- .../accelerators/acc_llama_accelerator.py | 2 +- .../accelerators/acc_olmo_accelerator.py | 2 +- .../accelerators/acc_qwen_accelerator.py | 2 +- .../accelerators/cuda_llama_accelerator.py | 7 +- flashmodels/arguments.py | 273 +++++++++--------- flashmodels/arguments.py.bak | 273 ++++++++++++++++++ flashmodels/patch/__init__.py | 4 - flashmodels/patch/__init__.py.bak | 5 + flashmodels/patch/patch.py | 26 +- flashmodels/patch/patch.py.bak | 111 +++++++ requirements.txt | 2 +- 16 files changed, 542 insertions(+), 175 deletions(-) create mode 100644 flashmodels/arguments.py.bak create mode 100644 flashmodels/patch/__init__.py.bak create mode 100644 flashmodels/patch/patch.py.bak diff --git a/examples/llama_acc.sh b/examples/llama_acc.sh index 3080336..a416f7b 100755 --- a/examples/llama_acc.sh +++ b/examples/llama_acc.sh @@ -2,7 +2,7 @@ set -ex # FSDP -./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 +./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 4 --fsdp 4 --use_flash_attn # TP # ./examples/run.sh --model ./hf_models/config/llama-1b --accelerator acc --gc --mbs 24 --tp 4 diff --git a/flashmodels/accelerators/acc_baichuan_accelerator.py b/flashmodels/accelerators/acc_baichuan_accelerator.py index 935e333..14d497d 100644 --- a/flashmodels/accelerators/acc_baichuan_accelerator.py +++ b/flashmodels/accelerators/acc_baichuan_accelerator.py @@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader): raise NotImplementedError("resume_from_checkpoint.") config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader def get_config(self, model): diff --git a/flashmodels/accelerators/acc_gemma_accelerator.py b/flashmodels/accelerators/acc_gemma_accelerator.py index f062105..db12458 100644 --- a/flashmodels/accelerators/acc_gemma_accelerator.py +++ b/flashmodels/accelerators/acc_gemma_accelerator.py @@ -12,7 +12,7 @@ def accelerate(self, model, loader): def accelerate_internal(self, model, loader): config = self.get_config() - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader def get_config(self): diff --git a/flashmodels/accelerators/acc_glm_accelerator.py b/flashmodels/accelerators/acc_glm_accelerator.py index 6c9801c..7695db0 100644 --- a/flashmodels/accelerators/acc_glm_accelerator.py +++ b/flashmodels/accelerators/acc_glm_accelerator.py @@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader): raise NotImplementedError("resume_from_checkpoint.") config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader def get_config(self, model): diff --git a/flashmodels/accelerators/acc_gpt_accelerator.py b/flashmodels/accelerators/acc_gpt_accelerator.py index 95a3231..971c680 100644 --- a/flashmodels/accelerators/acc_gpt_accelerator.py +++ b/flashmodels/accelerators/acc_gpt_accelerator.py @@ -20,7 +20,7 @@ def accelerate_internal(self, model, loader): raise NotImplementedError("resume_from_checkpoint.") 
config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader device = lazy_device() diff --git a/flashmodels/accelerators/acc_llama_accelerator.py b/flashmodels/accelerators/acc_llama_accelerator.py index ab35bfe..459f5e5 100644 --- a/flashmodels/accelerators/acc_llama_accelerator.py +++ b/flashmodels/accelerators/acc_llama_accelerator.py @@ -99,7 +99,7 @@ def accelerate_internal(self, model, loader): self.args.sp) config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) if self.args.tp_num > 1 and self.args.pp_num > 1: self.parallel_3d(model._get_underlay_model()) diff --git a/flashmodels/accelerators/acc_olmo_accelerator.py b/flashmodels/accelerators/acc_olmo_accelerator.py index 8009209..0af0250 100644 --- a/flashmodels/accelerators/acc_olmo_accelerator.py +++ b/flashmodels/accelerators/acc_olmo_accelerator.py @@ -17,7 +17,7 @@ def accelerate_internal(self, model, loader): raise NotImplementedError("resume_from_checkpoint.") config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader else: raise NotImplementedError("Currently, only FSDP is supported.") diff --git a/flashmodels/accelerators/acc_qwen_accelerator.py b/flashmodels/accelerators/acc_qwen_accelerator.py index 1a64edf..c1acc0e 100644 --- a/flashmodels/accelerators/acc_qwen_accelerator.py +++ b/flashmodels/accelerators/acc_qwen_accelerator.py @@ -37,7 +37,7 @@ def accelerate_internal(self, model, loader): raise NotImplementedError("resume_from_checkpoint.") config = self.get_config(model) - model = ta.accelerate(model, config) + model = ta.accelerate(model, config=config) return model, loader def get_config(self, model): diff --git a/flashmodels/accelerators/cuda_llama_accelerator.py b/flashmodels/accelerators/cuda_llama_accelerator.py index b20d47e..a27da9d 100644 --- a/flashmodels/accelerators/cuda_llama_accelerator.py +++ b/flashmodels/accelerators/cuda_llama_accelerator.py @@ -13,6 +13,7 @@ FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from transformers.models.llama.modeling_llama import LlamaDecoderLayer from flashmodels.accelerators.accelerator import (Accelerator, AcceleratorFactory) @@ -70,7 +71,7 @@ def apply_checkpointing(self, model): checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) - check_fn = lambda submodule: isinstance(submodule, transformers.models.llama.modeling_llama.LlamaDecoderLayer) + check_fn = lambda submodule: isinstance(LlamaDecoderLayer) apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, @@ -96,9 +97,7 @@ def fsdp(self, model): convert_outputs_to_fp32(model.forward.__func__), model) # Use auto_wrap_poliy for nested wrapping instead of only a top-level FSDP. 
- auto_wrap_policy = ModuleWrapPolicy({ - transformers.models.llama.modeling_llama.LlamaDecoderLayer, - }) + auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, }) mixed_precision_policy = None if self.args.fp16 or self.args.bf16: diff --git a/flashmodels/arguments.py b/flashmodels/arguments.py index 7a552de..8ad9eb7 100644 --- a/flashmodels/arguments.py +++ b/flashmodels/arguments.py @@ -4,7 +4,7 @@ import torch from flashmodels.logger import logger -from flashmodels.patch import patch_amp, patch_gemma, patch_llama, patch_peft +from flashmodels.patch import patch_gemma, patch_llama, patch_peft def print_args(args): @@ -16,10 +16,9 @@ def parse(): parser = argparse.ArgumentParser(description="Flash Models Arguments") # model args - parser.add_argument( - "--model_name_or_path", - type=str, - default="decapoda-research/llama-7b-hf") + parser.add_argument("--model_name_or_path", + type=str, + default="decapoda-research/llama-7b-hf") parser.add_argument("--cache_dir", type=str, default="./models/") parser.add_argument("--max_seq_length", type=int, default=1024) parser.add_argument( @@ -29,97 +28,95 @@ def parse(): choices=["gpt", "llama", "glm", "baichuan", "qwen", "olmo"]) # dataset args - parser.add_argument( - "--dataset_name_or_path", - type=str, - default="./data/wikitext-2-raw-v1.json") + parser.add_argument("--dataset_name_or_path", + type=str, + default="./data/wikitext-2-raw-v1.json") parser.add_argument("--dataset_config", type=str, default="") parser.add_argument("--micro_batch_size", type=int, default=8) parser.add_argument("--padding_side", type=str, default="right") - parser.add_argument( - "--disable_train_sampler", - action="store_true", - help="Disable Train Sampler") + parser.add_argument("--disable_train_sampler", + action="store_true", + help="Disable Train Sampler") # accelerator args - parser.add_argument( - "--accelerator", - type=str, - default="acc", - choices=["cuda", "acc", "megatron"], - help="accelerator name") - parser.add_argument( - "--fsdp_num", - type=int, - default=1, - help="Full sharded data parallel Number") - parser.add_argument( - "--gc", - action="store_true", - default=False, - help="Use gradients checkpoint") + parser.add_argument("--accelerator", + type=str, + default="acc", + choices=["cuda", "acc", "megatron"], + help="accelerator name") + parser.add_argument("--fsdp_num", + type=int, + default=1, + help="Full sharded data parallel Number") + parser.add_argument("--gc", + action="store_true", + default=False, + help="Use gradients checkpoint") parser.add_argument( "--gc_cnt", type=int, default=None, help="Number of decoder layers for gradient checkpointing") - parser.add_argument( - "--tp_num", type=int, default=1, help="Tensor Parallel Number") - parser.add_argument( - "--sp", - action="store_true", - default=False, - help="Use Sequence Parallelism.") + parser.add_argument("--tp_num", + type=int, + default=1, + help="Tensor Parallel Number") + parser.add_argument("--sp", + action="store_true", + default=False, + help="Use Sequence Parallelism.") parser.add_argument( "--sp_reshard_after_forward", action="store_true", default=False, help="To reduce memory usage, reshard weight after forward in TP-SP, \ and perform an extra all-gather in the backward pass") - parser.add_argument( - "--sp_num", - type=int, - default=1, - help="DeepSpeed Ulysses Sequence \ + parser.add_argument("--sp_num", + type=int, + default=1, + help="DeepSpeed Ulysses Sequence \ Parallel Number. 
") - parser.add_argument( - "--dp_num", type=int, default=1, help="Data Parallel Number") - parser.add_argument( - "--pp_num", type=int, default=1, help="Pipeline Parallel Number") - parser.add_argument( - "--fp16", action="store_true", help="Run model in fp16 mode.") - parser.add_argument( - "--bf16", action="store_true", help="Run model in bfloat16 mode.") - parser.add_argument( - "--force_use_syncfree_adam", - action="store_true", - help="Force to use \ + parser.add_argument("--dp_num", + type=int, + default=1, + help="Data Parallel Number") + parser.add_argument("--pp_num", + type=int, + default=1, + help="Pipeline Parallel Number") + parser.add_argument("--fp16", + action="store_true", + help="Run model in fp16 mode.") + parser.add_argument("--bf16", + action="store_true", + help="Run model in bfloat16 mode.") + parser.add_argument("--force_use_syncfree_adam", + action="store_true", + help="Force to use \ syncfree.Adam/AdamW for better tracing peformance.") - parser.add_argument( - "--use_zero2", - action="store_true", - help="Use \ + parser.add_argument("--use_zero2", + action="store_true", + help="Use \ distributed optimizer(ZeRO2) for SPMD-DP.") - parser.add_argument( - "--use_zero3", - action="store_true", - help="Use \ + parser.add_argument("--use_zero3", + action="store_true", + help="Use \ ZeRO3 for SPMD-DP.") # lora parser.add_argument("--lora", action="store_true", help="Use lora") - parser.add_argument( - "--lora_r", type=int, default=8, help="lora attention dimension") - parser.add_argument( - "--lora_alpha", - type=int, - default=8, - help="lora scaling alpha parameter") - parser.add_argument( - "--lora_dropout", - type=float, - default=0.0, - help="The dropout probability \ + parser.add_argument("--lora_r", + type=int, + default=8, + help="lora attention dimension") + parser.add_argument("--lora_alpha", + type=int, + default=8, + help="lora scaling alpha parameter") + parser.add_argument("--lora_dropout", + type=float, + default=0.0, + help="The dropout probability \ for Lora layers") parser.add_argument( "--lora_target_modules", @@ -131,55 +128,50 @@ def parse(): # training args parser.add_argument("--global_rank", type=int, default=0) - parser.add_argument( - "--resume_from_checkpoint", - action="store_true", - help="Resume from checkpoint, if true," - " load checkpoint from ckpt_dir") + parser.add_argument("--resume_from_checkpoint", + action="store_true", + help="Resume from checkpoint, if true," + " load checkpoint from ckpt_dir") parser.add_argument("--ckpt_dir", type=str, default="") - parser.add_argument( - "--ckpt_freq", - type=int, - default=100, - help="The checkpoint frequency of local steps.") - parser.add_argument( - "--profile", action="store_true", help="Open pytorch profiler") + parser.add_argument("--ckpt_freq", + type=int, + default=100, + help="The checkpoint frequency of local steps.") + parser.add_argument("--profile", + action="store_true", + help="Open pytorch profiler") parser.add_argument("--profile_dir", type=str, default="./profile/") - parser.add_argument( - "--profile_stop_step", - type=int, - default=10, - help="Maximum profiling steps") + parser.add_argument("--profile_stop_step", + type=int, + default=10, + help="Maximum profiling steps") parser.add_argument("--log_interval", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--max_step", type=int, default=-1) - parser.add_argument( - "--learning_rate", - type=float, - default=2e-5, - help="The initial learning rate for 
AdamW.") - parser.add_argument( - "--weight_decay", - type=float, - default=0.03, - help="Weight decay for AdamW if we apply some.") - parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="Beta1 for AdamW optimizer") - parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="Beta2 for AdamW optimizer") - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-8, - help="Epsilon for AdamW optimizer.") - parser.add_argument( - "--max_grad_norm", type=float, default=1.0, help="Max gradient norm.") + parser.add_argument("--learning_rate", + type=float, + default=2e-5, + help="The initial learning rate for AdamW.") + parser.add_argument("--weight_decay", + type=float, + default=0.03, + help="Weight decay for AdamW if we apply some.") + parser.add_argument("--adam_beta1", + type=float, + default=0.9, + help="Beta1 for AdamW optimizer") + parser.add_argument("--adam_beta2", + type=float, + default=0.999, + help="Beta2 for AdamW optimizer") + parser.add_argument("--adam_epsilon", + type=float, + default=1e-8, + help="Epsilon for AdamW optimizer.") + parser.add_argument("--max_grad_norm", + type=float, + default=1.0, + help="Max gradient norm.") parser.add_argument( "--lr_scheduler_type", type=str, @@ -195,39 +187,34 @@ def parse(): type=float, default=0.0, help="Linear warmup over warmup_ratio fraction of total steps.") - parser.add_argument( - "--warmup_steps", - type=int, - default=0, - help="Linear warmup over warmup_steps.") + parser.add_argument("--warmup_steps", + type=int, + default=0, + help="Linear warmup over warmup_steps.") parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--padding_strategy", - type=str, - default="max_length", - help="tokenizer padding strategy", - choices=["max_length", "longest"]) - parser.add_argument( - "--max_train_steps", - type=int, - default=-1, - help="Maximum training steps") - parser.add_argument( - "--use_flash_attn", - action="store_true", - default=False, - help="Use TriDao FlashAttention2") - parser.add_argument( - "--log_loss", action="store_true", help="Print loss when logging steps") + parser.add_argument("--padding_strategy", + type=str, + default="max_length", + help="tokenizer padding strategy", + choices=["max_length", "longest"]) + parser.add_argument("--max_train_steps", + type=int, + default=-1, + help="Maximum training steps") + parser.add_argument("--use_flash_attn", + action="store_true", + default=False, + help="Use TriDao FlashAttention2") + parser.add_argument("--log_loss", + action="store_true", + help="Print loss when logging steps") args = parser.parse_args() if args.lora: patch_peft() - if args.accelerator == "acc": - patch_amp() - else: + if args.accelerator == "cuda": torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) args.global_rank = int(os.getenv("RANK", 0)) diff --git a/flashmodels/arguments.py.bak b/flashmodels/arguments.py.bak new file mode 100644 index 0000000..8ad9eb7 --- /dev/null +++ b/flashmodels/arguments.py.bak @@ -0,0 +1,273 @@ +import argparse +import os + +import torch + +from flashmodels.logger import logger +from flashmodels.patch import patch_gemma, patch_llama, patch_peft + + +def print_args(args): + logger.info("FlashModels Arguments: ") + logger.info(" \n".join(f" {k} = {v}" for k, v in vars(args).items())) + + +def parse(): + parser = argparse.ArgumentParser(description="Flash Models Arguments") + + # model args + parser.add_argument("--model_name_or_path", + type=str, + default="decapoda-research/llama-7b-hf") + 
parser.add_argument("--cache_dir", type=str, default="./models/") + parser.add_argument("--max_seq_length", type=int, default=1024) + parser.add_argument( + "--model_type", + type=str, + default="", + choices=["gpt", "llama", "glm", "baichuan", "qwen", "olmo"]) + + # dataset args + parser.add_argument("--dataset_name_or_path", + type=str, + default="./data/wikitext-2-raw-v1.json") + parser.add_argument("--dataset_config", type=str, default="") + parser.add_argument("--micro_batch_size", type=int, default=8) + parser.add_argument("--padding_side", type=str, default="right") + parser.add_argument("--disable_train_sampler", + action="store_true", + help="Disable Train Sampler") + + # accelerator args + parser.add_argument("--accelerator", + type=str, + default="acc", + choices=["cuda", "acc", "megatron"], + help="accelerator name") + parser.add_argument("--fsdp_num", + type=int, + default=1, + help="Full sharded data parallel Number") + parser.add_argument("--gc", + action="store_true", + default=False, + help="Use gradients checkpoint") + parser.add_argument( + "--gc_cnt", + type=int, + default=None, + help="Number of decoder layers for gradient checkpointing") + parser.add_argument("--tp_num", + type=int, + default=1, + help="Tensor Parallel Number") + parser.add_argument("--sp", + action="store_true", + default=False, + help="Use Sequence Parallelism.") + parser.add_argument( + "--sp_reshard_after_forward", + action="store_true", + default=False, + help="To reduce memory usage, reshard weight after forward in TP-SP, \ + and perform an extra all-gather in the backward pass") + parser.add_argument("--sp_num", + type=int, + default=1, + help="DeepSpeed Ulysses Sequence \ + Parallel Number. ") + parser.add_argument("--dp_num", + type=int, + default=1, + help="Data Parallel Number") + parser.add_argument("--pp_num", + type=int, + default=1, + help="Pipeline Parallel Number") + parser.add_argument("--fp16", + action="store_true", + help="Run model in fp16 mode.") + parser.add_argument("--bf16", + action="store_true", + help="Run model in bfloat16 mode.") + parser.add_argument("--force_use_syncfree_adam", + action="store_true", + help="Force to use \ + syncfree.Adam/AdamW for better tracing peformance.") + parser.add_argument("--use_zero2", + action="store_true", + help="Use \ + distributed optimizer(ZeRO2) for SPMD-DP.") + parser.add_argument("--use_zero3", + action="store_true", + help="Use \ + ZeRO3 for SPMD-DP.") + + # lora + parser.add_argument("--lora", action="store_true", help="Use lora") + parser.add_argument("--lora_r", + type=int, + default=8, + help="lora attention dimension") + parser.add_argument("--lora_alpha", + type=int, + default=8, + help="lora scaling alpha parameter") + parser.add_argument("--lora_dropout", + type=float, + default=0.0, + help="The dropout probability \ + for Lora layers") + parser.add_argument( + "--lora_target_modules", + type=str, + default="QKV", + choices=["QKV", "ALL"], + help="The modules to apply Lora to. 
ALL means all linear layers in \ + decoder layer use lora, QKV means only qkv linears use lora") + + # training args + parser.add_argument("--global_rank", type=int, default=0) + parser.add_argument("--resume_from_checkpoint", + action="store_true", + help="Resume from checkpoint, if true," + " load checkpoint from ckpt_dir") + parser.add_argument("--ckpt_dir", type=str, default="") + parser.add_argument("--ckpt_freq", + type=int, + default=100, + help="The checkpoint frequency of local steps.") + parser.add_argument("--profile", + action="store_true", + help="Open pytorch profiler") + parser.add_argument("--profile_dir", type=str, default="./profile/") + parser.add_argument("--profile_stop_step", + type=int, + default=10, + help="Maximum profiling steps") + parser.add_argument("--log_interval", type=int, default=1) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--max_step", type=int, default=-1) + parser.add_argument("--learning_rate", + type=float, + default=2e-5, + help="The initial learning rate for AdamW.") + parser.add_argument("--weight_decay", + type=float, + default=0.03, + help="Weight decay for AdamW if we apply some.") + parser.add_argument("--adam_beta1", + type=float, + default=0.9, + help="Beta1 for AdamW optimizer") + parser.add_argument("--adam_beta2", + type=float, + default=0.999, + help="Beta2 for AdamW optimizer") + parser.add_argument("--adam_epsilon", + type=float, + default=1e-8, + help="Epsilon for AdamW optimizer.") + parser.add_argument("--max_grad_norm", + type=float, + default=1.0, + help="Max gradient norm.") + parser.add_argument( + "--lr_scheduler_type", + type=str, + default="cosine", + help="The scheduler type to use.", + choices=[ + "linear", "cosine", "cosine_with_restarts", "polynomial", + "constant", "constant_with_warmup" + ], + ) + parser.add_argument( + "--warmup_ratio", + type=float, + default=0.0, + help="Linear warmup over warmup_ratio fraction of total steps.") + parser.add_argument("--warmup_steps", + type=int, + default=0, + help="Linear warmup over warmup_steps.") + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--padding_strategy", + type=str, + default="max_length", + help="tokenizer padding strategy", + choices=["max_length", "longest"]) + parser.add_argument("--max_train_steps", + type=int, + default=-1, + help="Maximum training steps") + parser.add_argument("--use_flash_attn", + action="store_true", + default=False, + help="Use TriDao FlashAttention2") + parser.add_argument("--log_loss", + action="store_true", + help="Print loss when logging steps") + + args = parser.parse_args() + + if args.lora: + patch_peft() + + if args.accelerator == "cuda": + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + args.global_rank = int(os.getenv("RANK", 0)) + args.local_rank = int(os.getenv("LOCAL_RANK", 0)) + args.world_size = int(os.getenv("WORLD_SIZE", 1)) + if args.global_rank != 0: + args.profile = False + + # mkdir for ckpt_dir + if len(args.ckpt_dir) > 0: + os.makedirs(args.ckpt_dir, exist_ok=True) + + # amp checks. + args.dtype = torch.float + if args.fp16: + assert not args.bf16 + args.dtype = torch.half + + if args.bf16: + assert not args.fp16 + args.dtype = torch.bfloat16 + + # DP/MP checks. + args.mp_num = args.pp_num * args.tp_num # model parallel size. 
+ args.dp_num = max( + 1, args.world_size // (args.mp_num * args.fsdp_num * args.sp_num)) + + if not args.model_type: + if "llama" in args.model_name_or_path.lower(): + args.model_type = "llama" + elif "gpt" in args.model_name_or_path.lower(): + args.model_type = "gpt" + elif "glm" in args.model_name_or_path.lower(): + args.model_type = "glm" + elif "baichuan" in args.model_name_or_path.lower(): + args.model_type = "baichuan" + elif "qwen" in args.model_name_or_path.lower(): + args.model_type = "qwen" + elif "olmo" in args.model_name_or_path.lower(): + args.model_type = "olmo" + elif "gemma" in args.model_name_or_path.lower(): + args.model_type = "gemma" + else: + raise NotImplementedError( + f"Unsupported model: {args.model_name_or_path}") + + if args.model_type == "llama" and args.accelerator == 'acc' and ( + args.fp16 or args.bf16): + patch_llama(args.use_flash_attn) + if args.model_type == "gemma" and args.accelerator == 'acc': + patch_gemma() + + if args.local_rank == 0: + print_args(args) + + return args diff --git a/flashmodels/patch/__init__.py b/flashmodels/patch/__init__.py index 4dc7330..452a833 100644 --- a/flashmodels/patch/__init__.py +++ b/flashmodels/patch/__init__.py @@ -1,9 +1,5 @@ from flashmodels.patch.patch import patch_gemma, patch_llama, patch_lora -def patch_amp(): - import torchacc as ta - ta.patch_amp() - def patch_peft(): patch_lora() diff --git a/flashmodels/patch/__init__.py.bak b/flashmodels/patch/__init__.py.bak new file mode 100644 index 0000000..452a833 --- /dev/null +++ b/flashmodels/patch/__init__.py.bak @@ -0,0 +1,5 @@ +from flashmodels.patch.patch import patch_gemma, patch_llama, patch_lora + + +def patch_peft(): + patch_lora() diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index f559acd..bc76295 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -3,10 +3,13 @@ import os import re from typing import Any + import torch +import torchacc.utils.patch as patch import transformers + from flashmodels.logger import logger -from torchacc import patch_fa + def rewrite_load(): """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU @@ -29,19 +32,13 @@ def rewrite_load(): def patch_llama(use_flash_attn): - if use_flash_attn: - patch_fa() - from transformers.cache_utils import Cache - def update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - return None - transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask = update_causal_mask + patch.patch_llama(use_flash_attn) + if os.environ.get("ACC_LLAMA_TP") == "1": + transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP + if os.getenv("XLA_USE_SPMD") == "1": + # use einsum in linear for SPMD TP/Ulysses. + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention + transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer # (wenting.swt): Delete me when merged in transformers if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))): @@ -51,7 +48,6 @@ def update_causal_mask( def patch_gemma(): # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter. 
def wrap_for_flash_attention(func): - def wrapper(*args, **kwargs): kwargs["attention_mask"] = None return func(*args, **kwargs) diff --git a/flashmodels/patch/patch.py.bak b/flashmodels/patch/patch.py.bak new file mode 100644 index 0000000..bc76295 --- /dev/null +++ b/flashmodels/patch/patch.py.bak @@ -0,0 +1,111 @@ +import difflib +import inspect +import os +import re +from typing import Any + +import torch +import torchacc.utils.patch as patch +import transformers + +from flashmodels.logger import logger + + +def rewrite_load(): + """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU + memory pressure of loading multiple copies of data under multiple processes""" + source = inspect.getsource(transformers.modeling_utils) + modified = re.sub(r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", + r"torch.load(\g<1>, mmap=True)", source) + modified = re.sub(r"partial\(torch.load,(?![^)]*mmap[^)]*\))([^)]*)\)", + r"partial(torch.load,\g<1>, mmap=True)", modified) + if (int(os.environ.get("LOCAL_RANK", 0)) == 0): + lines = difflib.ndiff(source.split("\n"), modified.split("\n")) + diff = "\n".join([ + line for line in lines + if line.startswith("+") or line.startswith("-") + ]) + logger.warning( + f"When set LOW_CPU_MEM_USAGE, all the `torch.load` in transfomers.modeling_utils " + f"are called with `mmap=True`, diff: \n{diff}") + exec(modified, transformers.modeling_utils.__dict__) + + +def patch_llama(use_flash_attn): + patch.patch_llama(use_flash_attn) + if os.environ.get("ACC_LLAMA_TP") == "1": + transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP + if os.getenv("XLA_USE_SPMD") == "1": + # use einsum in linear for SPMD TP/Ulysses. + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention + transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer + + # (wenting.swt): Delete me when merged in transformers + if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))): + rewrite_load() + + +def patch_gemma(): + # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter. + def wrap_for_flash_attention(func): + def wrapper(*args, **kwargs): + kwargs["attention_mask"] = None + return func(*args, **kwargs) + + return wrapper + + xla_flags = os.getenv('XLA_FLAGS', '').split(' ') + pattern = r'--xla_gpu_enable_flash_attention=(\w+)' + for flag in xla_flags: + match = re.search(pattern, flag) + if match: + value = match.group(1) + if str(value).lower() == "true": + transformers.models.gemma.modeling_gemma.GemmaAttention.forward = wrap_for_flash_attention( + transformers.models.gemma.modeling_gemma.GemmaAttention. 
+ forward) + + +def patch_lora(): + try: + import peft + from peft.tuners import lora + except ImportError: + logger.errors("import lora fail, please install peft.") + + def _forward_linear(self, x: torch.Tensor, *args: Any, + **kwargs: Any) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + if version.parse(peft.__version__) > version.parse("0.6.2"): + result = self.base_layer(x, *args, **kwargs) + else: + result = self._linear(x) + elif self.merged: + if version.parse(peft.__version__) > version.parse("0.6.2"): + result = self.base_layer(x, *args, **kwargs) + else: + result = self._linear(x) + else: + if version.parse(peft.__version__) > version.parse("0.6.2"): + result = self.base_layer(x, *args, **kwargs) + else: + result = self._linear(x) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + result += lora_B(lora_A(dropout(x))) * scaling + + result = result.to(torch_result_dtype) + return result + + # TODO(baole): delete this patch after + # https://github.com/huggingface/peft/pull/1010 is merged. + lora.Linear.forward = _forward_linear diff --git a/requirements.txt b/requirements.txt index 0c07b9b..01a909c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ datasets numpy torch tokenizers>=0.13.3 -transformers==4.41.0 +transformers==4.33.0 accelerate transformers_stream_generator tiktoken From f35bda90ad284275d974f1523a6ef5dfb4976461 Mon Sep 17 00:00:00 2001 From: Seventeen17 <17aloha@gmail.com> Date: Fri, 28 Jun 2024 15:53:55 +0800 Subject: [PATCH 5/5] move import back --- apps/train.py | 5 +- .../accelerators/cuda_llama_accelerator.py | 6 +- flashmodels/arguments.py.bak | 273 ------------------ flashmodels/patch/__init__.py.bak | 5 - flashmodels/patch/llama_model.py | 1 - flashmodels/patch/patch.py | 2 + flashmodels/patch/patch.py.bak | 111 ------- 7 files changed, 9 insertions(+), 394 deletions(-) delete mode 100644 flashmodels/arguments.py.bak delete mode 100644 flashmodels/patch/__init__.py.bak delete mode 100644 flashmodels/patch/patch.py.bak diff --git a/apps/train.py b/apps/train.py index b6d790a..7c0a48a 100644 --- a/apps/train.py +++ b/apps/train.py @@ -1,15 +1,16 @@ import torch +from flashmodels import Builder, Trainer, accelerate, arguments + + def train(): torch.manual_seed(101) # parse args - from flashmodels import arguments args = arguments.parse() # build model, tokenizer, loader, optimizer and lr_scheduler # and use accelerator to speed up training - from flashmodels import Builder, Trainer, accelerate builder = Builder(args) model, loader, tokenizer = builder.build_model_dataloader() model, loader = accelerate(model, loader, args) diff --git a/flashmodels/accelerators/cuda_llama_accelerator.py b/flashmodels/accelerators/cuda_llama_accelerator.py index a27da9d..288f771 100644 --- a/flashmodels/accelerators/cuda_llama_accelerator.py +++ b/flashmodels/accelerators/cuda_llama_accelerator.py @@ -71,7 +71,7 @@ def apply_checkpointing(self, model): checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) - check_fn = lambda submodule: isinstance(LlamaDecoderLayer) + check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, @@ -97,7 +97,9 @@ 
def fsdp(self, model): convert_outputs_to_fp32(model.forward.__func__), model) # Use auto_wrap_poliy for nested wrapping instead of only a top-level FSDP. - auto_wrap_policy = ModuleWrapPolicy({LlamaDecoderLayer, }) + auto_wrap_policy = ModuleWrapPolicy({ + LlamaDecoderLayer, + }) mixed_precision_policy = None if self.args.fp16 or self.args.bf16: diff --git a/flashmodels/arguments.py.bak b/flashmodels/arguments.py.bak deleted file mode 100644 index 8ad9eb7..0000000 --- a/flashmodels/arguments.py.bak +++ /dev/null @@ -1,273 +0,0 @@ -import argparse -import os - -import torch - -from flashmodels.logger import logger -from flashmodels.patch import patch_gemma, patch_llama, patch_peft - - -def print_args(args): - logger.info("FlashModels Arguments: ") - logger.info(" \n".join(f" {k} = {v}" for k, v in vars(args).items())) - - -def parse(): - parser = argparse.ArgumentParser(description="Flash Models Arguments") - - # model args - parser.add_argument("--model_name_or_path", - type=str, - default="decapoda-research/llama-7b-hf") - parser.add_argument("--cache_dir", type=str, default="./models/") - parser.add_argument("--max_seq_length", type=int, default=1024) - parser.add_argument( - "--model_type", - type=str, - default="", - choices=["gpt", "llama", "glm", "baichuan", "qwen", "olmo"]) - - # dataset args - parser.add_argument("--dataset_name_or_path", - type=str, - default="./data/wikitext-2-raw-v1.json") - parser.add_argument("--dataset_config", type=str, default="") - parser.add_argument("--micro_batch_size", type=int, default=8) - parser.add_argument("--padding_side", type=str, default="right") - parser.add_argument("--disable_train_sampler", - action="store_true", - help="Disable Train Sampler") - - # accelerator args - parser.add_argument("--accelerator", - type=str, - default="acc", - choices=["cuda", "acc", "megatron"], - help="accelerator name") - parser.add_argument("--fsdp_num", - type=int, - default=1, - help="Full sharded data parallel Number") - parser.add_argument("--gc", - action="store_true", - default=False, - help="Use gradients checkpoint") - parser.add_argument( - "--gc_cnt", - type=int, - default=None, - help="Number of decoder layers for gradient checkpointing") - parser.add_argument("--tp_num", - type=int, - default=1, - help="Tensor Parallel Number") - parser.add_argument("--sp", - action="store_true", - default=False, - help="Use Sequence Parallelism.") - parser.add_argument( - "--sp_reshard_after_forward", - action="store_true", - default=False, - help="To reduce memory usage, reshard weight after forward in TP-SP, \ - and perform an extra all-gather in the backward pass") - parser.add_argument("--sp_num", - type=int, - default=1, - help="DeepSpeed Ulysses Sequence \ - Parallel Number. 
") - parser.add_argument("--dp_num", - type=int, - default=1, - help="Data Parallel Number") - parser.add_argument("--pp_num", - type=int, - default=1, - help="Pipeline Parallel Number") - parser.add_argument("--fp16", - action="store_true", - help="Run model in fp16 mode.") - parser.add_argument("--bf16", - action="store_true", - help="Run model in bfloat16 mode.") - parser.add_argument("--force_use_syncfree_adam", - action="store_true", - help="Force to use \ - syncfree.Adam/AdamW for better tracing peformance.") - parser.add_argument("--use_zero2", - action="store_true", - help="Use \ - distributed optimizer(ZeRO2) for SPMD-DP.") - parser.add_argument("--use_zero3", - action="store_true", - help="Use \ - ZeRO3 for SPMD-DP.") - - # lora - parser.add_argument("--lora", action="store_true", help="Use lora") - parser.add_argument("--lora_r", - type=int, - default=8, - help="lora attention dimension") - parser.add_argument("--lora_alpha", - type=int, - default=8, - help="lora scaling alpha parameter") - parser.add_argument("--lora_dropout", - type=float, - default=0.0, - help="The dropout probability \ - for Lora layers") - parser.add_argument( - "--lora_target_modules", - type=str, - default="QKV", - choices=["QKV", "ALL"], - help="The modules to apply Lora to. ALL means all linear layers in \ - decoder layer use lora, QKV means only qkv linears use lora") - - # training args - parser.add_argument("--global_rank", type=int, default=0) - parser.add_argument("--resume_from_checkpoint", - action="store_true", - help="Resume from checkpoint, if true," - " load checkpoint from ckpt_dir") - parser.add_argument("--ckpt_dir", type=str, default="") - parser.add_argument("--ckpt_freq", - type=int, - default=100, - help="The checkpoint frequency of local steps.") - parser.add_argument("--profile", - action="store_true", - help="Open pytorch profiler") - parser.add_argument("--profile_dir", type=str, default="./profile/") - parser.add_argument("--profile_stop_step", - type=int, - default=10, - help="Maximum profiling steps") - parser.add_argument("--log_interval", type=int, default=1) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--max_step", type=int, default=-1) - parser.add_argument("--learning_rate", - type=float, - default=2e-5, - help="The initial learning rate for AdamW.") - parser.add_argument("--weight_decay", - type=float, - default=0.03, - help="Weight decay for AdamW if we apply some.") - parser.add_argument("--adam_beta1", - type=float, - default=0.9, - help="Beta1 for AdamW optimizer") - parser.add_argument("--adam_beta2", - type=float, - default=0.999, - help="Beta2 for AdamW optimizer") - parser.add_argument("--adam_epsilon", - type=float, - default=1e-8, - help="Epsilon for AdamW optimizer.") - parser.add_argument("--max_grad_norm", - type=float, - default=1.0, - help="Max gradient norm.") - parser.add_argument( - "--lr_scheduler_type", - type=str, - default="cosine", - help="The scheduler type to use.", - choices=[ - "linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup" - ], - ) - parser.add_argument( - "--warmup_ratio", - type=float, - default=0.0, - help="Linear warmup over warmup_ratio fraction of total steps.") - parser.add_argument("--warmup_steps", - type=int, - default=0, - help="Linear warmup over warmup_steps.") - parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument("--padding_strategy", - type=str, - default="max_length", - help="tokenizer padding 
strategy", - choices=["max_length", "longest"]) - parser.add_argument("--max_train_steps", - type=int, - default=-1, - help="Maximum training steps") - parser.add_argument("--use_flash_attn", - action="store_true", - default=False, - help="Use TriDao FlashAttention2") - parser.add_argument("--log_loss", - action="store_true", - help="Print loss when logging steps") - - args = parser.parse_args() - - if args.lora: - patch_peft() - - if args.accelerator == "cuda": - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - - args.global_rank = int(os.getenv("RANK", 0)) - args.local_rank = int(os.getenv("LOCAL_RANK", 0)) - args.world_size = int(os.getenv("WORLD_SIZE", 1)) - if args.global_rank != 0: - args.profile = False - - # mkdir for ckpt_dir - if len(args.ckpt_dir) > 0: - os.makedirs(args.ckpt_dir, exist_ok=True) - - # amp checks. - args.dtype = torch.float - if args.fp16: - assert not args.bf16 - args.dtype = torch.half - - if args.bf16: - assert not args.fp16 - args.dtype = torch.bfloat16 - - # DP/MP checks. - args.mp_num = args.pp_num * args.tp_num # model parallel size. - args.dp_num = max( - 1, args.world_size // (args.mp_num * args.fsdp_num * args.sp_num)) - - if not args.model_type: - if "llama" in args.model_name_or_path.lower(): - args.model_type = "llama" - elif "gpt" in args.model_name_or_path.lower(): - args.model_type = "gpt" - elif "glm" in args.model_name_or_path.lower(): - args.model_type = "glm" - elif "baichuan" in args.model_name_or_path.lower(): - args.model_type = "baichuan" - elif "qwen" in args.model_name_or_path.lower(): - args.model_type = "qwen" - elif "olmo" in args.model_name_or_path.lower(): - args.model_type = "olmo" - elif "gemma" in args.model_name_or_path.lower(): - args.model_type = "gemma" - else: - raise NotImplementedError( - f"Unsupported model: {args.model_name_or_path}") - - if args.model_type == "llama" and args.accelerator == 'acc' and ( - args.fp16 or args.bf16): - patch_llama(args.use_flash_attn) - if args.model_type == "gemma" and args.accelerator == 'acc': - patch_gemma() - - if args.local_rank == 0: - print_args(args) - - return args diff --git a/flashmodels/patch/__init__.py.bak b/flashmodels/patch/__init__.py.bak deleted file mode 100644 index 452a833..0000000 --- a/flashmodels/patch/__init__.py.bak +++ /dev/null @@ -1,5 +0,0 @@ -from flashmodels.patch.patch import patch_gemma, patch_llama, patch_lora - - -def patch_peft(): - patch_lora() diff --git a/flashmodels/patch/llama_model.py b/flashmodels/patch/llama_model.py index 6b4a7ca..745735f 100644 --- a/flashmodels/patch/llama_model.py +++ b/flashmodels/patch/llama_model.py @@ -8,7 +8,6 @@ import torch_xla.core.xla_model as xm from torch import nn from torchacc.dist.tp import Mesh, mark_sharding -from transformer.cache_utils import Cache from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import (ACT2FN, LlamaRMSNorm, LlamaRotaryEmbedding, diff --git a/flashmodels/patch/patch.py b/flashmodels/patch/patch.py index bc76295..4d6d36f 100644 --- a/flashmodels/patch/patch.py +++ b/flashmodels/patch/patch.py @@ -33,6 +33,8 @@ def rewrite_load(): def patch_llama(use_flash_attn): patch.patch_llama(use_flash_attn) + from flashmodels.patch.llama_model import (LlamaAttention, + LlamaDecoderLayer, LlamaMLP) if os.environ.get("ACC_LLAMA_TP") == "1": transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP if os.getenv("XLA_USE_SPMD") == "1": diff --git a/flashmodels/patch/patch.py.bak b/flashmodels/patch/patch.py.bak deleted file mode 
100644 index bc76295..0000000 --- a/flashmodels/patch/patch.py.bak +++ /dev/null @@ -1,111 +0,0 @@ -import difflib -import inspect -import os -import re -from typing import Any - -import torch -import torchacc.utils.patch as patch -import transformers - -from flashmodels.logger import logger - - -def rewrite_load(): - """Rewrite `torch.load` in `from_pretrain` in case to use mmap to reduce the CPU - memory pressure of loading multiple copies of data under multiple processes""" - source = inspect.getsource(transformers.modeling_utils) - modified = re.sub(r"torch\.load\((?![^)]*mmap[^)]*\))([^)]*)\)", - r"torch.load(\g<1>, mmap=True)", source) - modified = re.sub(r"partial\(torch.load,(?![^)]*mmap[^)]*\))([^)]*)\)", - r"partial(torch.load,\g<1>, mmap=True)", modified) - if (int(os.environ.get("LOCAL_RANK", 0)) == 0): - lines = difflib.ndiff(source.split("\n"), modified.split("\n")) - diff = "\n".join([ - line for line in lines - if line.startswith("+") or line.startswith("-") - ]) - logger.warning( - f"When set LOW_CPU_MEM_USAGE, all the `torch.load` in transfomers.modeling_utils " - f"are called with `mmap=True`, diff: \n{diff}") - exec(modified, transformers.modeling_utils.__dict__) - - -def patch_llama(use_flash_attn): - patch.patch_llama(use_flash_attn) - if os.environ.get("ACC_LLAMA_TP") == "1": - transformers.models.llama.modeling_llama.LlamaMLP = LlamaMLP - if os.getenv("XLA_USE_SPMD") == "1": - # use einsum in linear for SPMD TP/Ulysses. - transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - - # (wenting.swt): Delete me when merged in transformers - if bool(int(os.environ.get("LOW_CPU_MEM_USAGE", "0"))): - rewrite_load() - - -def patch_gemma(): - # Set the attention_mask in GemmaAttention to None to match the pattern of FlashAttentionRewriter. - def wrap_for_flash_attention(func): - def wrapper(*args, **kwargs): - kwargs["attention_mask"] = None - return func(*args, **kwargs) - - return wrapper - - xla_flags = os.getenv('XLA_FLAGS', '').split(' ') - pattern = r'--xla_gpu_enable_flash_attention=(\w+)' - for flag in xla_flags: - match = re.search(pattern, flag) - if match: - value = match.group(1) - if str(value).lower() == "true": - transformers.models.gemma.modeling_gemma.GemmaAttention.forward = wrap_for_flash_attention( - transformers.models.gemma.modeling_gemma.GemmaAttention. 
- forward) - - -def patch_lora(): - try: - import peft - from peft.tuners import lora - except ImportError: - logger.errors("import lora fail, please install peft.") - - def _forward_linear(self, x: torch.Tensor, *args: Any, - **kwargs: Any) -> torch.Tensor: - if self.disable_adapters: - if self.merged: - self.unmerge() - if version.parse(peft.__version__) > version.parse("0.6.2"): - result = self.base_layer(x, *args, **kwargs) - else: - result = self._linear(x) - elif self.merged: - if version.parse(peft.__version__) > version.parse("0.6.2"): - result = self.base_layer(x, *args, **kwargs) - else: - result = self._linear(x) - else: - if version.parse(peft.__version__) > version.parse("0.6.2"): - result = self.base_layer(x, *args, **kwargs) - else: - result = self._linear(x) - torch_result_dtype = result.dtype - for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): - continue - lora_A = self.lora_A[active_adapter] - lora_B = self.lora_B[active_adapter] - dropout = self.lora_dropout[active_adapter] - scaling = self.scaling[active_adapter] - x = x.to(lora_A.weight.dtype) - result += lora_B(lora_A(dropout(x))) * scaling - - result = result.to(torch_result_dtype) - return result - - # TODO(baole): delete this patch after - # https://github.com/huggingface/peft/pull/1010 is merged. - lora.Linear.forward = _forward_linear
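End to end, the series replaces the old FLASH_ATTN/--no_fa shell plumbing with a single --use_flash_attn switch that reaches both the Hugging Face model constructor (via attn_implementation) and the llama patch. A minimal sketch of that flow follows; the argument definition and the attn_implementation choice mirror arguments.py and builder.py from the diffs above, while the small driver around them is illustrative only (the real code additionally calls torchacc.utils.patch.patch_llama(args.use_flash_attn), omitted here to keep the sketch self-contained).

# Illustrative sketch of how --use_flash_attn propagates: argparse flag ->
# HF attn_implementation string. Mirrors arguments.py and builder.py from
# the diffs above; the __main__ driver is for demonstration only.
import argparse


def parse(argv=None):
    parser = argparse.ArgumentParser(description="Flash Models Arguments")
    parser.add_argument("--use_flash_attn",
                        action="store_true",
                        default=False,
                        help="Use TriDao FlashAttention2")
    return parser.parse_args(argv)


def attn_implementation(use_flash_attn: bool) -> str:
    # Same expression builder.py passes to AutoModelForCausalLM.from_config
    # and from_pretrained.
    return "flash_attention_2" if use_flash_attn else "eager"


if __name__ == "__main__":
    args = parse(["--use_flash_attn"])
    print(attn_implementation(args.use_flash_attn))  # -> flash_attention_2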