From cb5116629d68bed67da2974ad4cb40f22491a746 Mon Sep 17 00:00:00 2001
From: Hyeongjun Jeon
Date: Thu, 23 Oct 2025 04:17:26 +0000
Subject: [PATCH 1/2] apply gradient checkpoint config

---
 .../models/gpt2/modeling_gpt2_moreh.py       | 20 ++++++++++++++++++--
 .../models/mistral/modeling_mistral_moreh.py | 15 +++++++++++++--
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gpt2/modeling_gpt2_moreh.py b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
index 533e292b81ea..2048dcb14903 100644
--- a/src/transformers/models/gpt2/modeling_gpt2_moreh.py
+++ b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
@@ -1017,11 +1017,21 @@ def __init__(self, config):
         self.post_init()
 
         # Moreh Config
-        self.moreh_pipeline_layers = []
         moreh_config = getattr(config, "moreh_config", None)
+
+        # Moreh Pipeline Layers
+        self.moreh_pipeline_layers = []
         if moreh_config is not None and "pipeline_layers" in moreh_config:
             self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
 
+        # Moreh Gradient Checkpoint Layers Step
+        # If moreh_gradient_checkpoint_layers_step is N,
+        # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
+        self.moreh_gradient_checkpoint_layers_step = None
+        if self.moreh_gradient_checkpoint_layers_step is not None and (
+                layer_idx %
+                self.moreh_gradient_checkpoint_layers_step) == 0:
+            hidden_states = torch.moreh.checkpoint_assign(hidden_states)
 
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -1212,6 +1222,12 @@ def forward(
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Gradient checkpoint assign
+            if self.moreh_gradient_checkpoint_layers_step is not None and (
+                    layer_idx %
+                    self.moreh_gradient_checkpoint_layers_step) == 0:
+                hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -2075,4 +2091,4 @@ def _reorder_cache(
 #             hidden_states=outputs.hidden_states,
 #             attentions=outputs.attentions,
 #         )
-#
\ No newline at end of file
+#
diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index 0ff2da78034a..5aeb17b3acbe 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -919,11 +919,22 @@ def __init__(self, config: MistralMorehConfig):
         self.post_init()
 
         # Moreh Config
-        self.moreh_pipeline_layers = []
         moreh_config = getattr(config, "moreh_config", None)
+
+        # Moreh Pipeline Layers
+        self.moreh_pipeline_layers = []
         if moreh_config is not None and "pipeline_layers" in moreh_config:
             self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
 
+        # Moreh Gradient Checkpoint Layers Step
+        # If moreh_gradient_checkpoint_layers_step is N,
+        # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
+        self.moreh_gradient_checkpoint_layers_step = None
+        if self.moreh_gradient_checkpoint_layers_step is not None and (
+                layer_idx %
+                self.moreh_gradient_checkpoint_layers_step) == 0:
+            hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
@@ -1579,4 +1590,4 @@ def _reorder_cache(past_key_values, beam_idx):
 #             hidden_states=outputs.hidden_states,
 #             attentions=outputs.attentions,
 #         )
-#
\ No newline at end of file
+#

From f976c4090b56d586e685f34bf1ce6bd50f074f3a Mon Sep 17 00:00:00 2001
From: Hyeongjun Jeon
Date: Thu, 23 Oct 2025 04:58:29 +0000
Subject: [PATCH 2/2] fix

---
 src/transformers/models/gpt2/modeling_gpt2_moreh.py |  9 ++++-----
 .../models/mistral/modeling_mistral_moreh.py        | 13 +++++++++----
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/gpt2/modeling_gpt2_moreh.py b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
index 2048dcb14903..7c33855e8cc2 100644
--- a/src/transformers/models/gpt2/modeling_gpt2_moreh.py
+++ b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
@@ -1028,10 +1028,9 @@ def __init__(self, config):
         # If moreh_gradient_checkpoint_layers_step is N,
         # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
         self.moreh_gradient_checkpoint_layers_step = None
-        if self.moreh_gradient_checkpoint_layers_step is not None and (
-                layer_idx %
-                self.moreh_gradient_checkpoint_layers_step) == 0:
-            hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+        if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
+            self.moreh_gradient_checkpoint_layers_step = moreh_config[
+                "gradient_checkpoint_layers_step"]
 
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -1224,7 +1223,7 @@ def forward(
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             # Gradient checkpoint assign
             if self.moreh_gradient_checkpoint_layers_step is not None and (
-                    layer_idx %
+                    i %
                     self.moreh_gradient_checkpoint_layers_step) == 0:
                 hidden_states = torch.moreh.checkpoint_assign(hidden_states)
 
diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index 5aeb17b3acbe..019ccdac4419 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -930,10 +930,9 @@ def __init__(self, config: MistralMorehConfig):
         # If moreh_gradient_checkpoint_layers_step is N,
         # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
         self.moreh_gradient_checkpoint_layers_step = None
-        if self.moreh_gradient_checkpoint_layers_step is not None and (
-                layer_idx %
-                self.moreh_gradient_checkpoint_layers_step) == 0:
-            hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+        if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
+            self.moreh_gradient_checkpoint_layers_step = moreh_config[
+                "gradient_checkpoint_layers_step"]
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1008,6 +1007,12 @@ def forward(
         next_decoder_cache = None
 
         for layer_idx, decoder_layer in enumerate(self.layers):
+            # Gradient checkpoint assign
+            if self.moreh_gradient_checkpoint_layers_step is not None and (
+                    layer_idx %
+                    self.moreh_gradient_checkpoint_layers_step) == 0:
+                hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
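
The layer-selection rule this series ends up with (checkpoint the input activations of every N-th block, starting from the first) can be illustrated without the Moreh runtime. The sketch below is a minimal, hypothetical example: should_checkpoint is not part of the patched models, the moreh_config dict is hand-built, and the real code calls torch.moreh.checkpoint_assign(hidden_states) inline at the selected layers rather than merely recording their indices.

    from typing import Optional

    # Illustrative sketch only: reproduces the layer-selection logic added by this
    # patch series, without the Moreh runtime. `should_checkpoint` is a hypothetical
    # helper; the patched models call torch.moreh.checkpoint_assign(hidden_states)
    # inline inside the layer loop instead.
    def should_checkpoint(layer_idx: int, step: Optional[int]) -> bool:
        """True if this layer's input activations should be checkpointed."""
        return step is not None and layer_idx % step == 0

    # The dict that the patched __init__ reads via getattr(config, "moreh_config", None).
    moreh_config = {
        "pipeline_layers": [],                 # pre-existing key, unchanged by the series
        "gradient_checkpoint_layers_step": 4,  # new key: checkpoint every 4th layer
    }
    step = moreh_config.get("gradient_checkpoint_layers_step")

    # With step = 4 and a 12-layer model, layers 0, 4 and 8 are selected,
    # i.e. the 1st, (1+N)th, (1+2N)th layers described in the in-code comment.
    selected = [i for i in range(12) if should_checkpoint(i, step)]
    print(selected)  # [0, 4, 8]

The second patch keys this check on the loop counter actually defined in each forward (i in GPT-2, layer_idx in Mistral) and moves the config lookup into __init__, which is the behaviour the sketch mirrors.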