diff --git a/fastgen/methods/distribution_matching/dmd2.py b/fastgen/methods/distribution_matching/dmd2.py index b106e61..6421901 100644 --- a/fastgen/methods/distribution_matching/dmd2.py +++ b/fastgen/methods/distribution_matching/dmd2.py @@ -121,39 +121,34 @@ def _generate_noise_and_time( eps = torch.randn_like(real_data, device=self.device, dtype=real_data.dtype) return input_student, t_student, t, eps - def _compute_teacher_prediction_gan_loss( + def _compute_teacher_x0( self, perturbed_data: torch.Tensor, t: torch.Tensor, condition: Optional[Any] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute teacher prediction and optionally GAN loss for generator. - - Args: - perturbed_data: Perturbed data tensor - t: Time steps - condition: Conditioning information - - Returns: - tuple of (teacher_x0, fake_feat or None, gan_loss_gen) - """ - if self.config.gan_loss_weight_gen > 0: - teacher_x0, fake_feat = self.teacher( - perturbed_data, - t, - condition=condition, - feature_indices=self.discriminator.feature_indices, - fwd_pred_type="x0", - ) - # Compute the GAN loss for the generator - gan_loss_gen = gan_loss_generator(self.discriminator(fake_feat)) - else: + ) -> torch.Tensor: + """Compute the teacher x0 target used by VSD.""" + with torch.no_grad(): teacher_x0 = self.teacher( perturbed_data, t, condition=condition, fwd_pred_type="x0", ) - gan_loss_gen = torch.tensor(0.0, device=self.device, dtype=teacher_x0.dtype) + return teacher_x0.detach() - return teacher_x0.detach(), gan_loss_gen + def _compute_generator_gan_loss( + self, perturbed_data: torch.Tensor, t: torch.Tensor, condition: Optional[Any] = None + ) -> torch.Tensor: + """Compute the generator-side GAN loss from fake_score features.""" + if self.config.gan_loss_weight_gen <= 0: + return torch.tensor(0.0, device=self.device, dtype=perturbed_data.dtype) + + fake_feat = self.fake_score( + perturbed_data, + t, + condition=condition, + return_features_early=True, + feature_indices=self.discriminator.feature_indices, + ) + return gan_loss_generator(self.discriminator(fake_feat)) def _apply_classifier_free_guidance( self, @@ -223,7 +218,8 @@ def _student_update_step( assert ( t.dtype == t_student.dtype == self.net.noise_scheduler.t_precision ), f"t.dtype: {t.dtype}, t_student.dtype: {t_student.dtype}, self.net.noise_scheduler.t_precision: {self.net.noise_scheduler.t_precision}" - teacher_x0, gan_loss_gen = self._compute_teacher_prediction_gan_loss(perturbed_data, t, condition=condition) + teacher_x0 = self._compute_teacher_x0(perturbed_data, t, condition=condition) + gan_loss_gen = self._compute_generator_gan_loss(perturbed_data, t, condition=condition) # Apply classifier-free guidance if needed if self.config.guidance_scale is not None: @@ -250,7 +246,7 @@ def _student_update_step( def _compute_real_feat( self, real_data: torch.Tensor, t: torch.Tensor, eps: torch.Tensor, condition: Optional[Any] = None ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute discriminator features for both real and fake data. + """Compute real features for the discriminator using the fake_score backbone. Args: real_data: Real data tensor @@ -272,9 +268,8 @@ def _compute_real_feat( device=self.device, ) eps_real = torch.randn_like(real_data) - # Perturb the real data according to the given forward process perturbed_real = self.net.noise_scheduler.forward_process(real_data, eps_real, t_real) - real_feat = self.teacher( + real_feat = self.fake_score( perturbed_real, t_real, condition=condition, @@ -304,7 +299,7 @@ def _compute_r1_regularization( """ perturbed_real_alpha = real_data.add(self.config.gan_r1_reg_alpha * torch.randn_like(real_data)) with torch.no_grad(): - real_feat_alpha = self.teacher( + real_feat_alpha = self.fake_score( perturbed_real_alpha, t_real, condition=condition, @@ -365,16 +360,14 @@ def _fake_score_discriminator_update_step( gan_loss_ar1 = torch.zeros_like(loss_fakescore) if self.config.gan_loss_weight_gen > 0: # Compute the GAN loss for the discriminator - with torch.no_grad(): - fake_feat = self.teacher( - x_t_sg, - t, - condition=condition, - return_features_early=True, - feature_indices=self.discriminator.feature_indices, - ) - - real_feat, t_real = self._compute_real_feat(real_data=real_data, t=t, eps=eps, condition=condition) + fake_feat = self.fake_score( + x_t_sg, + t, + condition=condition, + return_features_early=True, + feature_indices=self.discriminator.feature_indices, + ) + real_feat, t_real = self._compute_real_feat(real_data=real_data, t=t, eps=eps, condition=condition) real_feat_logit = self.discriminator(real_feat) gan_loss_disc = gan_loss_discriminator(real_feat_logit, self.discriminator(fake_feat))