Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 33 additions & 40 deletions fastgen/methods/distribution_matching/dmd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,39 +121,34 @@ def _generate_noise_and_time(
eps = torch.randn_like(real_data, device=self.device, dtype=real_data.dtype)
return input_student, t_student, t, eps

def _compute_teacher_prediction_gan_loss(
def _compute_teacher_x0(
self, perturbed_data: torch.Tensor, t: torch.Tensor, condition: Optional[Any] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute teacher prediction and optionally GAN loss for generator.

Args:
perturbed_data: Perturbed data tensor
t: Time steps
condition: Conditioning information

Returns:
tuple of (teacher_x0, fake_feat or None, gan_loss_gen)
"""
if self.config.gan_loss_weight_gen > 0:
teacher_x0, fake_feat = self.teacher(
perturbed_data,
t,
condition=condition,
feature_indices=self.discriminator.feature_indices,
fwd_pred_type="x0",
)
# Compute the GAN loss for the generator
gan_loss_gen = gan_loss_generator(self.discriminator(fake_feat))
else:
) -> torch.Tensor:
"""Compute the teacher x0 target used by VSD."""
with torch.no_grad():
teacher_x0 = self.teacher(
perturbed_data,
t,
condition=condition,
fwd_pred_type="x0",
)
gan_loss_gen = torch.tensor(0.0, device=self.device, dtype=teacher_x0.dtype)
return teacher_x0.detach()

return teacher_x0.detach(), gan_loss_gen
def _compute_generator_gan_loss(
self, perturbed_data: torch.Tensor, t: torch.Tensor, condition: Optional[Any] = None
) -> torch.Tensor:
"""Compute the generator-side GAN loss from fake_score features."""
if self.config.gan_loss_weight_gen <= 0:
return torch.tensor(0.0, device=self.device, dtype=perturbed_data.dtype)

fake_feat = self.fake_score(
perturbed_data,
t,
condition=condition,
return_features_early=True,
feature_indices=self.discriminator.feature_indices,
)
return gan_loss_generator(self.discriminator(fake_feat))

def _apply_classifier_free_guidance(
self,
Expand Down Expand Up @@ -223,7 +218,8 @@ def _student_update_step(
assert (
t.dtype == t_student.dtype == self.net.noise_scheduler.t_precision
), f"t.dtype: {t.dtype}, t_student.dtype: {t_student.dtype}, self.net.noise_scheduler.t_precision: {self.net.noise_scheduler.t_precision}"
teacher_x0, gan_loss_gen = self._compute_teacher_prediction_gan_loss(perturbed_data, t, condition=condition)
teacher_x0 = self._compute_teacher_x0(perturbed_data, t, condition=condition)
gan_loss_gen = self._compute_generator_gan_loss(perturbed_data, t, condition=condition)

# Apply classifier-free guidance if needed
if self.config.guidance_scale is not None:
Expand All @@ -250,7 +246,7 @@ def _student_update_step(
def _compute_real_feat(
self, real_data: torch.Tensor, t: torch.Tensor, eps: torch.Tensor, condition: Optional[Any] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute discriminator features for both real and fake data.
"""Compute real features for the discriminator using the fake_score backbone.

Args:
real_data: Real data tensor
Expand All @@ -272,9 +268,8 @@ def _compute_real_feat(
device=self.device,
)
eps_real = torch.randn_like(real_data)
# Perturb the real data according to the given forward process
perturbed_real = self.net.noise_scheduler.forward_process(real_data, eps_real, t_real)
real_feat = self.teacher(
real_feat = self.fake_score(
perturbed_real,
t_real,
condition=condition,
Expand Down Expand Up @@ -304,7 +299,7 @@ def _compute_r1_regularization(
"""
perturbed_real_alpha = real_data.add(self.config.gan_r1_reg_alpha * torch.randn_like(real_data))
with torch.no_grad():
real_feat_alpha = self.teacher(
real_feat_alpha = self.fake_score(
perturbed_real_alpha,
t_real,
condition=condition,
Expand Down Expand Up @@ -365,16 +360,14 @@ def _fake_score_discriminator_update_step(
gan_loss_ar1 = torch.zeros_like(loss_fakescore)
if self.config.gan_loss_weight_gen > 0:
# Compute the GAN loss for the discriminator
with torch.no_grad():
fake_feat = self.teacher(
x_t_sg,
t,
condition=condition,
return_features_early=True,
feature_indices=self.discriminator.feature_indices,
)

real_feat, t_real = self._compute_real_feat(real_data=real_data, t=t, eps=eps, condition=condition)
fake_feat = self.fake_score(
x_t_sg,
t,
condition=condition,
return_features_early=True,
feature_indices=self.discriminator.feature_indices,
)
real_feat, t_real = self._compute_real_feat(real_data=real_data, t=t, eps=eps, condition=condition)

real_feat_logit = self.discriminator(real_feat)
gan_loss_disc = gan_loss_discriminator(real_feat_logit, self.discriminator(fake_feat))
Expand Down