diff --git a/src/transformers/models/gpt2/modeling_gpt2_moreh.py b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
index 533e292b81ea..7c33855e8cc2 100644
--- a/src/transformers/models/gpt2/modeling_gpt2_moreh.py
+++ b/src/transformers/models/gpt2/modeling_gpt2_moreh.py
@@ -1017,11 +1017,20 @@ def __init__(self, config):
         self.post_init()
 
         # Moreh Config
-        self.moreh_pipeline_layers = []
         moreh_config = getattr(config, "moreh_config", None)
+
+        # Moreh Pipeline Layers
+        self.moreh_pipeline_layers = []
         if moreh_config is not None and "pipeline_layers" in moreh_config:
             self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
 
+        # Moreh Gradient Checkpoint Layers Step
+        # If moreh_gradient_checkpoint_layers_step is N,
+        # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
+        self.moreh_gradient_checkpoint_layers_step = None
+        if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
+            self.moreh_gradient_checkpoint_layers_step = moreh_config[
+                "gradient_checkpoint_layers_step"]
 
     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -1212,6 +1221,12 @@ def forward(
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Gradient checkpoint assign
+            if self.moreh_gradient_checkpoint_layers_step is not None and (
+                    i %
+                    self.moreh_gradient_checkpoint_layers_step) == 0:
+                hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -2075,4 +2090,4 @@ def _reorder_cache(
 #             hidden_states=outputs.hidden_states,
 #             attentions=outputs.attentions,
 #         )
-#
\ No newline at end of file
+#
diff --git a/src/transformers/models/mistral/modeling_mistral_moreh.py b/src/transformers/models/mistral/modeling_mistral_moreh.py
index 0ff2da78034a..019ccdac4419 100644
--- a/src/transformers/models/mistral/modeling_mistral_moreh.py
+++ b/src/transformers/models/mistral/modeling_mistral_moreh.py
@@ -919,11 +919,21 @@ def __init__(self, config: MistralMorehConfig):
         self.post_init()
 
         # Moreh Config
-        self.moreh_pipeline_layers = []
         moreh_config = getattr(config, "moreh_config", None)
+
+        # Moreh Pipeline Layers
+        self.moreh_pipeline_layers = []
         if moreh_config is not None and "pipeline_layers" in moreh_config:
             self.moreh_pipeline_layers = moreh_config["pipeline_layers"]
 
+        # Moreh Gradient Checkpoint Layers Step
+        # If moreh_gradient_checkpoint_layers_step is N,
+        # then 1st, (1+N)th, (1+2N)th, ... layer's input activations will be checkpointed
+        self.moreh_gradient_checkpoint_layers_step = None
+        if moreh_config is not None and "gradient_checkpoint_layers_step" in moreh_config:
+            self.moreh_gradient_checkpoint_layers_step = moreh_config[
+                "gradient_checkpoint_layers_step"]
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
@@ -997,6 +1007,12 @@ def forward(
         next_decoder_cache = None
 
         for layer_idx, decoder_layer in enumerate(self.layers):
+            # Gradient checkpoint assign
+            if self.moreh_gradient_checkpoint_layers_step is not None and (
+                    layer_idx %
+                    self.moreh_gradient_checkpoint_layers_step) == 0:
+                hidden_states = torch.moreh.checkpoint_assign(hidden_states)
+
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -1579,4 +1595,4 @@ def _reorder_cache(past_key_values, beam_idx):
 #             hidden_states=outputs.hidden_states,
 #             attentions=outputs.attentions,
 #         )
-#
\ No newline at end of file
+#
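
For reference, a minimal sketch of how the new gradient_checkpoint_layers_step knob selects layers. Only the moreh_config key name, the getattr(config, "moreh_config", None) lookup, and the (layer_idx % step) == 0 condition come from the diff above; the helper function, the example step value, and the 32-layer count are illustrative assumptions, and torch.moreh.checkpoint_assign is mentioned only in comments as the call the selected layers would hit.

# Sketch only (assumptions noted above): shows which layer indices satisfy the
# diff's condition (layer_idx % moreh_gradient_checkpoint_layers_step) == 0,
# i.e. which layers would have their input hidden_states passed to
# torch.moreh.checkpoint_assign.

moreh_config = {
    # The only key this sketch uses; "pipeline_layers" is left out because its
    # value format is not defined in this diff.
    "gradient_checkpoint_layers_step": 4,
}

def checkpointed_layers(num_layers, step):
    """0-based indices of layers whose input activations get checkpointed."""
    return [i for i in range(num_layers) if i % step == 0]

# With a hypothetical 32-layer decoder and N = 4, layers 0, 4, 8, ..., 28 are
# selected -- the 1st, (1+N)th, (1+2N)th, ... layers, as the diff's comment states.
print(checkpointed_layers(32, moreh_config["gradient_checkpoint_layers_step"]))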