diff --git a/examples/AltDiffusion/dreambooth.py b/examples/AltDiffusion/dreambooth.py index 404a8569..006d1d92 100644 --- a/examples/AltDiffusion/dreambooth.py +++ b/examples/AltDiffusion/dreambooth.py @@ -11,6 +11,7 @@ from pathlib import Path import torch +from torch.cuda.amp import autocast as autocast from torch.utils.data import Dataset from PIL import Image from torchvision import transforms @@ -34,7 +35,7 @@ train_text_encoder = False train_only_unet = True -num_train_epochs = 500 +num_train_epochs = 10 batch_size = 4 learning_rate = 5e-6 adam_beta1 = 0.9 @@ -197,20 +198,23 @@ def collate_fn(examples): if with_prior_preservation: x, x_prior = torch.chunk(x, 2, dim=0) c, c_prior = torch.chunk(c, 2, dim=0) - loss, _ = model(x, c) + with autocast(): + loss, _ = model(x, c) prior_loss, _ = model(x_prior, c_prior) loss = loss + prior_loss_weight * prior_loss else: - loss, _ = model(x, c) + with autocast(): + loss, _ = model(x, c) print('*'*20, "loss=", str(loss.detach().item())) - loss.backward() - optimizer.step() - optimizer.zero_grad() + with autocast(): + loss.backward() + optimizer.step() + optimizer.zero_grad() ## mkdir ./checkpoints/DreamBooth and copy ./checkpoints/AltDiffusion to ./checkpoints/DreamBooth/AltDiffusion ## overwrite model.ckpt for latter usage -chekpoint_path = './checkpoints/DreamBooth/AltDiffusion/model.ckpt' +chekpoint_path = './checkpoints/AltDiffusion/dreambooth_model.ckpt' torch.save(model.state_dict(), chekpoint_path) diff --git a/examples/AltDiffusion/instance_images/0.jpg b/examples/AltDiffusion/instance_images/0.jpg new file mode 100644 index 00000000..0ee2ab23 Binary files /dev/null and b/examples/AltDiffusion/instance_images/0.jpg differ diff --git a/examples/AltDiffusion/instance_images/1.jpeg b/examples/AltDiffusion/instance_images/1.jpeg new file mode 100644 index 00000000..67cbf1eb Binary files /dev/null and b/examples/AltDiffusion/instance_images/1.jpeg differ diff --git a/examples/AltDiffusion/instance_images/10.jpeg b/examples/AltDiffusion/instance_images/10.jpeg new file mode 100644 index 00000000..c669c520 Binary files /dev/null and b/examples/AltDiffusion/instance_images/10.jpeg differ diff --git a/examples/AltDiffusion/instance_images/11.jpeg b/examples/AltDiffusion/instance_images/11.jpeg new file mode 100644 index 00000000..6660545c Binary files /dev/null and b/examples/AltDiffusion/instance_images/11.jpeg differ diff --git a/examples/AltDiffusion/instance_images/12.jpeg b/examples/AltDiffusion/instance_images/12.jpeg new file mode 100644 index 00000000..dc80d5b5 Binary files /dev/null and b/examples/AltDiffusion/instance_images/12.jpeg differ diff --git a/examples/AltDiffusion/instance_images/13.jpeg b/examples/AltDiffusion/instance_images/13.jpeg new file mode 100644 index 00000000..9ff429f3 Binary files /dev/null and b/examples/AltDiffusion/instance_images/13.jpeg differ diff --git a/examples/AltDiffusion/instance_images/14.jpeg b/examples/AltDiffusion/instance_images/14.jpeg new file mode 100644 index 00000000..9e2ee8e6 Binary files /dev/null and b/examples/AltDiffusion/instance_images/14.jpeg differ diff --git a/examples/AltDiffusion/instance_images/15.jpg b/examples/AltDiffusion/instance_images/15.jpg new file mode 100644 index 00000000..2ca6e9b7 Binary files /dev/null and b/examples/AltDiffusion/instance_images/15.jpg differ diff --git a/examples/AltDiffusion/instance_images/2.jpg b/examples/AltDiffusion/instance_images/2.jpg new file mode 100644 index 00000000..038fbd05 Binary files /dev/null and b/examples/AltDiffusion/instance_images/2.jpg differ diff --git a/examples/AltDiffusion/instance_images/3.jpeg b/examples/AltDiffusion/instance_images/3.jpeg new file mode 100644 index 00000000..092f1bfa Binary files /dev/null and b/examples/AltDiffusion/instance_images/3.jpeg differ diff --git a/examples/AltDiffusion/instance_images/4.jpeg b/examples/AltDiffusion/instance_images/4.jpeg new file mode 100644 index 00000000..6c4306b7 Binary files /dev/null and b/examples/AltDiffusion/instance_images/4.jpeg differ diff --git a/examples/AltDiffusion/instance_images/5.jpeg b/examples/AltDiffusion/instance_images/5.jpeg new file mode 100644 index 00000000..4852a74c Binary files /dev/null and b/examples/AltDiffusion/instance_images/5.jpeg differ diff --git a/examples/AltDiffusion/instance_images/6.jpg b/examples/AltDiffusion/instance_images/6.jpg new file mode 100644 index 00000000..b7c9e1ad Binary files /dev/null and b/examples/AltDiffusion/instance_images/6.jpg differ diff --git a/examples/AltDiffusion/instance_images/7.jpeg b/examples/AltDiffusion/instance_images/7.jpeg new file mode 100644 index 00000000..aa48f7a9 Binary files /dev/null and b/examples/AltDiffusion/instance_images/7.jpeg differ diff --git a/examples/AltDiffusion/instance_images/8.jpg b/examples/AltDiffusion/instance_images/8.jpg new file mode 100644 index 00000000..5889132a Binary files /dev/null and b/examples/AltDiffusion/instance_images/8.jpg differ diff --git a/examples/AltDiffusion/instance_images/9.jpeg b/examples/AltDiffusion/instance_images/9.jpeg new file mode 100644 index 00000000..61accb47 Binary files /dev/null and b/examples/AltDiffusion/instance_images/9.jpeg differ diff --git a/flagai/model/mm/AltDiffusion.py b/flagai/model/mm/AltDiffusion.py index 9a6b12cf..62c8ae60 100755 --- a/flagai/model/mm/AltDiffusion.py +++ b/flagai/model/mm/AltDiffusion.py @@ -1319,6 +1319,7 @@ def p_losses(self, x_start, cond, t, noise=None): loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + t = t.cpu() logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar @@ -1932,4 +1933,4 @@ def normal_kl(mean1, logvar1, mean2, logvar2): ] return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + - ((mean1 - mean2)**2) * torch.exp(-logvar2)) \ No newline at end of file + ((mean1 - mean2)**2) * torch.exp(-logvar2)) diff --git a/requirements.txt b/requirements.txt index f9250dbc..18f2f4f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ PyYAML==5.4.1 deepspeed==0.6.5 flash-attn==1.0.2 bminf +torch