From 874ffbd4f3a36c779d49200a13ba0243413824fb Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 22 Apr 2024 08:33:26 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- pippy/PipelineSchedule.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/pippy/PipelineSchedule.py b/pippy/PipelineSchedule.py index 09532e579..20420fc24 100644 --- a/pippy/PipelineSchedule.py +++ b/pippy/PipelineSchedule.py @@ -44,9 +44,9 @@ def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): f"[{stage.stage_index}] Loss of microbatch {mb_index}: {loss}" ) - def _maybe_get_loss(self, mb_index): + def _maybe_get_loss(self, stage, mb_index): valid_index = 0 <= mb_index < len(self._internal_losses) - if self._has_backward and valid_index: + if stage.is_last and self._has_backward and valid_index: return self._internal_losses[mb_index] elif len(self._internal_losses) != 0 and not valid_index: raise RuntimeError( @@ -56,12 +56,17 @@ def _maybe_get_loss(self, mb_index): else: return None - def _update_losses(self, losses): + def _update_losses(self, stages, losses): """ Update the losses to those in the internal state """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any([stage.is_last for stage in stages]) + # Return losses if there is a container passed in - if losses is not None: + if contains_last_stage and losses is not None: if len(self._internal_losses) != self._n_microbatches: raise RuntimeError( f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" @@ -330,7 +335,7 @@ def step_microbatches( for work in works.values(): work.wait() - loss = self._maybe_get_loss(i) + loss = self._maybe_get_loss(self._stage, i) self._stage.backward_one_chunk(loss=loss) ops = self._stage.get_bwd_send_ops() @@ -342,7 +347,7 @@ def step_microbatches( ) # Return losses if there is a container passed in - self._update_losses(losses) + self._update_losses(self._stage, losses) # Wait for all backward sends to finish for work in bwd_sends_to_wait: @@ -423,7 +428,7 @@ def step_microbatches( for work in works.values(): work.wait() - loss = self._maybe_get_loss(bwd_mb_index) + loss = self._maybe_get_loss(self._stage, bwd_mb_index) self._stage.backward_one_chunk(loss=loss) ops = self._stage.get_bwd_send_ops() @@ -440,7 +445,7 @@ def step_microbatches( work.wait() # Return losses if there is a container passed in - self._update_losses(losses) + self._update_losses(self._stage, losses) class PipelineScheduleMulti(PipelineSchedule): @@ -553,14 +558,14 @@ def step_microbatches( if ops: dist.batch_isend_irecv(ops).pop().wait() - loss = self._maybe_get_loss(i) + loss = self._maybe_get_loss(stage, i) stage.backward_one_chunk(loss=loss) ops = stage.get_bwd_send_ops() if ops: dist.batch_isend_irecv(ops) - self._update_losses(losses) + self._update_losses(self._stages, losses) class ScheduleInterleaved1F1B(PipelineScheduleMulti): @@ -739,7 +744,7 @@ def backward_stage_local_index(step): ) # bwd - loss = self._maybe_get_loss(bwd_mb_index) + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) bwd_stage.backward_one_chunk(loss=loss) ops.extend(bwd_stage.get_bwd_send_ops()) @@ -764,7 +769,7 @@ def backward_stage_local_index(step): for work in works.values(): work.wait() - loss = self._maybe_get_loss(bwd_mb_index) + loss = self._maybe_get_loss(bwd_stage, bwd_mb_index) bwd_stage.backward_one_chunk(loss=loss) ops = bwd_stage.get_bwd_send_ops() @@ -776,4 +781,4 @@ def backward_stage_local_index(step): work.wait() # Return losses if there is a container passed in - self._update_losses(losses) + self._update_losses(self._stages, losses)