diff --git a/.gitignore b/.gitignore index c7b25d0..efa6a9a 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,6 @@ iccv21 # data data -stats *.png _cache_* figures diff --git a/Readme.md b/Readme.md index 6f32f6e..65106d5 100644 --- a/Readme.md +++ b/Readme.md @@ -1,11 +1,6 @@ -# Gradient Normalization for Generative Adversarial Networks - -Yi-Lun Wu, Hong-Han Shuai, Zhi-Rui Tam, Hong-Yu Chiu - -Paper: [https://arxiv.org/abs/2109.02235](https://arxiv.org/abs/2109.02235) - -This is the official implementation of Gradient Normalized GAN (GN-GAN). - +## Note +This is a fork of a fork, created to test spectral normalization (SN). +The SN test files are in the `sn_testing/` directory. ## Requirements - Python 3.8.9 - Python packages @@ -116,44 +111,3 @@ All the reported values (Inception Score and FID) in our paper are calculated by --eval \ --save path/to/generated/images ``` - -## How to integrate Gradient Normalization into your work? -The function `normalize_gradient` is implemented based on `torch.autograd` module, which can easily normalize your forward propagation of discriminator by updating a single line. -```python -from torch.nn import BCEWithLogitsLoss -from models.gradnorm import normalize_gradient - -net_D = ... # discriminator -net_G = ... # generator -loss_fn = BCEWithLogitsLoss() - -# Update discriminator -x_real = ... # real data -x_fake = net_G(torch.randn(64, 3, 32, 32)) # fake data -pred_real = normalize_gradient(net_D, x_real) # net_D(x_real) -pred_fake = normalize_gradient(net_D, x_fake) # net_D(x_fake) -loss_real = loss_fn(pred_real, torch.ones_like(pred_real)) -loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake)) -(loss_real + loss_fake).backward() # backward propagation -... - -# Update generator -x_fake = net_G(torch.randn(64, 3, 32, 32)) # fake data -pred_fake = normalize_gradient(net_D, x_fake) # net_D(x_fake) -loss_fake = loss_fn(pred_fake, torch.ones_like(pred_fake)) -loss.backward() # backward propagation -... 
- -``` - -## Citation -If you find our work is relevant to your research, please cite: -``` -@InProceedings{GNGAN_2021_ICCV, - author = {Yi-Lun Wu, Hong-Han Shuai, Zhi Rui Tam, Hong-Yu Chiu}, - title = {Gradient Normalization for Generative Adversarial Networks}, - booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, - month = {Oct}, - year = {2021} -} -``` \ No newline at end of file diff --git a/config/GN-GAN-CR_CIFAR10_BIGGAN.txt b/config/GN-GAN-CR_CIFAR10_BIGGAN.txt index 96f5208..41fb099 100644 --- a/config/GN-GAN-CR_CIFAR10_BIGGAN.txt +++ b/config/GN-GAN-CR_CIFAR10_BIGGAN.txt @@ -5,7 +5,7 @@ --lr_decay_start=125000 --batch_size_D=50 --batch_size_G=50 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0001 --betas=0.0 diff --git a/config/GN-GAN-CR_CIFAR10_CNN.txt b/config/GN-GAN-CR_CIFAR10_CNN.txt index b9edbf4..be68522 100644 --- a/config/GN-GAN-CR_CIFAR10_CNN.txt +++ b/config/GN-GAN-CR_CIFAR10_CNN.txt @@ -5,7 +5,7 @@ --lr_decay_start=200000 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=1 diff --git a/config/GN-GAN-CR_CIFAR10_RES.txt b/config/GN-GAN-CR_CIFAR10_RES.txt index 8bdf67e..1cbec9d 100644 --- a/config/GN-GAN-CR_CIFAR10_RES.txt +++ b/config/GN-GAN-CR_CIFAR10_RES.txt @@ -5,7 +5,7 @@ --lr_decay_start=0 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0004 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN-CR_STL10_CNN.txt b/config/GN-GAN-CR_STL10_CNN.txt index 2b4b8ac..64c5822 100644 --- a/config/GN-GAN-CR_STL10_CNN.txt +++ b/config/GN-GAN-CR_STL10_CNN.txt @@ -5,7 +5,7 @@ --lr_decay_start=200000 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=1 diff --git a/config/GN-GAN-CR_STL10_CNN_MODULE.txt b/config/GN-GAN-CR_STL10_CNN_MODULE.txt new file mode 100644 index 0000000..64c5822 --- /dev/null +++ b/config/GN-GAN-CR_STL10_CNN_MODULE.txt @@ -0,0 +1,25 @@ 
+--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=5 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN-CR_STL10_CNN_0 diff --git a/config/GN-GAN-CR_STL10_RES.txt b/config/GN-GAN-CR_STL10_RES.txt index ecb2a6d..59f0a6c 100644 --- a/config/GN-GAN-CR_STL10_RES.txt +++ b/config/GN-GAN-CR_STL10_RES.txt @@ -5,7 +5,7 @@ --lr_decay_start=0 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0004 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN_CELEBAHQ128_RES.txt b/config/GN-GAN_CELEBAHQ128_RES.txt index 5bc6c14..9218e51 100644 --- a/config/GN-GAN_CELEBAHQ128_RES.txt +++ b/config/GN-GAN_CELEBAHQ128_RES.txt @@ -4,7 +4,7 @@ --batch_size_D=64 --batch_size_G=128 --accumulation=1 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN_CELEBAHQ256_RES.txt b/config/GN-GAN_CELEBAHQ256_RES.txt index ba5c995..ea820a6 100644 --- a/config/GN-GAN_CELEBAHQ256_RES.txt +++ b/config/GN-GAN_CELEBAHQ256_RES.txt @@ -4,7 +4,7 @@ --batch_size_D=64 --batch_size_G=128 --accumulation=1 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN_CHURCH256_RES.txt b/config/GN-GAN_CHURCH256_RES.txt index 8a84438..c90e41d 100644 --- a/config/GN-GAN_CHURCH256_RES.txt +++ b/config/GN-GAN_CHURCH256_RES.txt @@ -4,7 +4,7 @@ --batch_size_D=64 --batch_size_G=128 --accumulation=1 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN_CIFAR10_BIGGAN.txt b/config/GN-GAN_CIFAR10_BIGGAN.txt index 1ed4053..c807fb3 100644 --- a/config/GN-GAN_CIFAR10_BIGGAN.txt +++ b/config/GN-GAN_CIFAR10_BIGGAN.txt @@ -1,8 +1,8 @@ --dataset=cifar10.32 
--arch=biggan.32 --loss=hinge ---total_steps=125000 ---lr_decay_start=125000 +--total_steps=75000 +--lr_decay_start=75000 --batch_size_D=50 --batch_size_G=50 --num_workers=8 diff --git a/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt b/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt new file mode 100644 index 0000000..88db8a5 --- /dev/null +++ b/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt @@ -0,0 +1,27 @@ +--dataset=cifar10.32 +--arch=biggan.32 +--loss=hinge +--total_steps=75000 +--lr_decay_start=75000 +--batch_size_D=50 +--batch_size_G=50 +--num_workers=8 +--lr_D=0.0002 +--lr_G=0.0001 +--betas=0.0 +--betas=0.999 +--n_dis=4 +--z_dim=128 +--cr=0 +--n_classes=10 + +--ema_decay=0.9999 +--ema_start=1000 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_BIGGAN_MODULE_0 diff --git a/config/GN-GAN_CIFAR10_CNN.txt b/config/GN-GAN_CIFAR10_CNN.txt index 58c3ff4..734d26f 100644 --- a/config/GN-GAN_CIFAR10_CNN.txt +++ b/config/GN-GAN_CIFAR10_CNN.txt @@ -5,7 +5,7 @@ --lr_decay_start=200000 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=1 diff --git a/config/GN-GAN_CIFAR10_CNN_MODULE.txt b/config/GN-GAN_CIFAR10_CNN_MODULE.txt new file mode 100644 index 0000000..3ad5aa5 --- /dev/null +++ b/config/GN-GAN_CIFAR10_CNN_MODULE.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=dcgan.32 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_CNN_MODULE_0 diff --git a/config/GN-GAN_CIFAR10_RES.txt b/config/GN-GAN_CIFAR10_RES.txt index 9dddafa..4310265 100644 --- a/config/GN-GAN_CIFAR10_RES.txt 
+++ b/config/GN-GAN_CIFAR10_RES.txt @@ -5,7 +5,7 @@ --lr_decay_start=0 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0004 --lr_G=0.0002 --n_dis=5 diff --git a/config/GN-GAN_STL10_CNN.txt b/config/GN-GAN_STL10_CNN.txt index 6908ac0..6c9f938 100644 --- a/config/GN-GAN_STL10_CNN.txt +++ b/config/GN-GAN_STL10_CNN.txt @@ -5,7 +5,7 @@ --lr_decay_start=200000 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0002 --lr_G=0.0002 --n_dis=1 diff --git a/config/GN-GAN_STL10_CNN_MODULE.txt b/config/GN-GAN_STL10_CNN_MODULE.txt new file mode 100644 index 0000000..3dcc243 --- /dev/null +++ b/config/GN-GAN_STL10_CNN_MODULE.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN_STL10_CNN_MODULE_0 diff --git a/config/GN-GAN_STL10_RES.txt b/config/GN-GAN_STL10_RES.txt index 6c4cfdb..a00e627 100644 --- a/config/GN-GAN_STL10_RES.txt +++ b/config/GN-GAN_STL10_RES.txt @@ -5,7 +5,7 @@ --lr_decay_start=0 --batch_size_D=64 --batch_size_G=128 ---num_workers=8 +--num_workers=10 --lr_D=0.0004 --lr_G=0.0002 --n_dis=5 diff --git a/generated/GN-GAN_CIFAR10_CNN_0/Results/results.txt b/generated/GN-GAN_CIFAR10_CNN_0/Results/results.txt new file mode 100644 index 0000000..48a3c0b --- /dev/null +++ b/generated/GN-GAN_CIFAR10_CNN_0/Results/results.txt @@ -0,0 +1 @@ +IS: 7.683(0.082), FID: 21.932 \ No newline at end of file diff --git a/models/biggan_module.py b/models/biggan_module.py new file mode 100644 index 0000000..c0b9c57 --- /dev/null +++ b/models/biggan_module.py @@ -0,0 +1,279 @@ +from functools import partial + +import torch +import 
class Attention(nn.Module):
    """Self-attention layer from SA-GAN (https://arxiv.org/abs/1805.08318).

    Queries are taken at full resolution; keys and values are 2x2 max-pooled
    first. The attended features are scaled by the learnable scalar
    ``gamma`` (initialised to 0, so the layer initially returns its input
    unchanged) and added back to the input.

    Args:
        ch: number of input (and output) channels.
        use_spectral_norm: when true, wrap the four 1x1 convolutions with
            the module-level ``sn`` spectral-norm helper.
    """

    def __init__(self, ch, use_spectral_norm):
        super().__init__()
        if use_spectral_norm:
            spectral_norm = sn
        else:
            spectral_norm = (lambda x: x)
        self.q = spectral_norm(nn.Conv2d(
            ch, ch // 8, kernel_size=1, padding=0, bias=False))
        self.k = spectral_norm(nn.Conv2d(
            ch, ch // 8, kernel_size=1, padding=0, bias=False))
        self.v = spectral_norm(nn.Conv2d(
            ch, ch // 2, kernel_size=1, padding=0, bias=False))
        self.o = spectral_norm(nn.Conv2d(
            ch // 2, ch, kernel_size=1, padding=0, bias=False))
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Return ``gamma * attention(x) + x``; ``y`` is accepted but unused."""
        B, C, H, W = x.size()
        q = self.q(x)
        k = F.max_pool2d(self.k(x), [2, 2])
        v = F.max_pool2d(self.v(x), [2, 2])
        # flatten
        q = q.view(B, C // 8, H * W)    # query
        # The pooled map has (H // 2) * (W // 2) positions, which differs
        # from H * W // 4 when H or W is odd, so let view() infer the length.
        k = k.view(B, C // 8, -1)       # key
        v = v.view(B, C // 2, -1)       # value
        # attention weights
        w = F.softmax(torch.bmm(q.transpose(1, 2), k), -1)
        # attend and project
        o = self.o(torch.bmm(v, w.transpose(1, 2)).view(B, C // 2, H, W))
        return self.gamma * o + x
class Generator32(nn.Module):
    """Conditional generator that emits 3 x 32 x 32 images.

    A spectrally normalised linear layer maps ``z`` to a 4 x 4 map with
    ``4 * ch`` channels. Three ``GenBlock`` stages (built with
    ``cbn_linear=False``) follow, and a BatchNorm/ReLU/conv/Tanh head
    produces the image. Weights are initialised by ``res32_weights_init``.
    """

    def __init__(self, z_dim=128, n_classes=10, ch=64):
        super().__init__()
        width = ch * 4  # every stage keeps 4 * ch channels
        self.linear = sn(nn.Linear(z_dim, width * 4 * 4))
        # stage outputs: 4ch x 8 x 8 -> 4ch x 16 x 16 -> 4ch x 32 x 32
        self.blocks = nn.ModuleList(
            [GenBlock(width, width, n_classes, False) for _ in range(3)])
        head = [
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            sn(nn.Conv2d(width, 3, 3, padding=1)),  # 3 x 32 x 32
            nn.Tanh(),
        ]
        self.output_layer = nn.Sequential(*head)
        res32_weights_init(self)

    def forward(self, z, y):
        """Map latent ``z`` and condition ``y`` to an image tensor."""
        feat = self.linear(z).view(z.size(0), -1, 4, 4)
        for stage in self.blocks:
            feat = stage(feat, y)
        return self.output_layer(feat)
class OptimizedDisblock(nn.Module):
    """Downsampling residual block with a GradNorm-wrapped shortcut.

    The shortcut path (average pool, then 1x1 conv) is wrapped in
    ``GradNorm``. The residual path (conv, ReLU, conv, average pool) is
    left unwrapped. Both paths halve the spatial size and are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shortcut is created first so parameter creation order is unchanged.
        skip = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(in_channels, out_channels, 1, padding=0),
        )
        self.shortcut = GradNorm(skip)
        body = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.AvgPool2d(2),
        ]
        self.residual = nn.Sequential(*body)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)`` (residual is evaluated first)."""
        main = self.residual(x)
        skip = self.shortcut(x)
        return main + skip
class Discriminator128(nn.Module):
    """Projection discriminator for 3 x 128 x 128 inputs.

    The previous constructor built ``GradNorm(Discriminator128(...))``
    inside ``Discriminator128.__init__``. That recursed without bound, so
    the class could never be instantiated, and it also passed a channel
    count where ``n_classes`` was expected. Each of those entries followed
    a ``DisBlock`` and carried that block's output width, so the
    normalisation is now applied by wrapping the ``DisBlock`` itself in
    ``GradNorm``.

    NOTE(review): wrapping each DisBlock is the closest working reading of
    the original intent - confirm with the author. The recursive
    ``self.gradnorm`` head has been dropped; pooled features feed the
    linear/projection head directly, as in ``Discriminator32``.

    Args:
        n_classes: number of rows in the label embedding.
        ch: base channel width.
    """

    def __init__(self, n_classes=1000, ch=96):
        super().__init__()
        # channels_multipler = [1, 2, 4, 8, 16, 16]
        # Trailing comments give the input shape of each stage.
        self.blocks = nn.Sequential(
            OptimizedDisblock(3, ch * 1),                      # 3 x 128 x 128
            Attention(ch, False),                              # ch*1 x 64 x 64
            GradNorm(DisBlock(ch * 1, ch * 2, down=True)),     # ch*1 x 64 x 64
            GradNorm(DisBlock(ch * 2, ch * 4, down=True)),     # ch*2 x 32 x 32
            GradNorm(DisBlock(ch * 4, ch * 8, down=True)),     # ch*4 x 16 x 16
            GradNorm(DisBlock(ch * 8, ch * 16, down=True)),    # ch*8 x 8 x 8
            GradNorm(DisBlock(ch * 16, ch * 16)),              # ch*16 x 4 x 4
            nn.ReLU(inplace=True),                             # ch*16 x 4 x 4
        )
        self.linear = nn.Linear(ch * 16, 1)
        self.embedding = nn.Embedding(n_classes, ch * 16)
        # res128_weights_init(self)

    def forward(self, x, y):
        """Return one logit per sample for images ``x`` with labels ``y``."""
        # Sum-pool over the spatial dimensions.
        h = self.blocks(x).sum(dim=[2, 3])
        # Unconditional logit plus the class-projection term.
        h = self.linear(h) + (self.embedding(y) * h).sum(dim=1, keepdim=True)
        return h
class Discriminator(nn.Module):
    """CNN discriminator whose conv blocks and final linear layer are each
    wrapped in ``GradNorm``.

    Three stride-2 convolutions reduce an ``M x M`` input to
    ``(M // 8) x (M // 8)`` with 512 channels before the linear head.

    Args:
        M: input height/width in pixels.
    """

    def __init__(self, M=32):
        super().__init__()
        self.M = M

        # M
        self.block1 = self._conv_block(3, 64, kernel_size=3, stride=1)
        self.block2 = self._conv_block(64, 128, kernel_size=4, stride=2)
        # M / 2
        self.block3 = self._conv_block(128, 128, kernel_size=3, stride=1)
        self.block4 = self._conv_block(128, 256, kernel_size=4, stride=2)
        # M / 4
        self.block5 = self._conv_block(256, 256, kernel_size=3, stride=1)
        self.block6 = self._conv_block(256, 512, kernel_size=4, stride=2)
        # M / 8
        self.block7 = self._conv_block(512, 512, kernel_size=3, stride=1)

        # `M // 8 * M // 8 * 512` parsed as ((M // 8) * M) // 8 * 512, which
        # only equals the true feature count when M is a multiple of 8.
        side = M // 8
        self.linear = GradNorm(nn.Linear(side * side * 512, 1))
        self.initialize()

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, stride):
        """Build one GradNorm-wrapped Conv2d + LeakyReLU(0.1) block."""
        return GradNorm(nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=1),
            nn.LeakyReLU(0.1, inplace=False)
        ))

    def initialize(self):
        """Draw conv/linear weights from N(0, 0.02^2) and zero the biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.normal_(m.weight, std=0.02)
                init.zeros_(m.bias)

    def rescale_weight(self, min_norm=1.0, max_norm=1.33):
        """Divide each conv/linear weight by its clamped L2 norm, in place.

        Layers are visited in ``self.modules()`` order. Each weight norm is
        printed, clamped to ``[min_norm, max_norm]`` and used to divide the
        weight; the bias is divided by the running product of all clamped
        norms seen so far.
        """
        a = 1.0
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    w_norm = m.weight.norm(p=2)
                    print(m, w_norm)
                    w_norm = max(w_norm, min_norm)
                    w_norm = min(w_norm, max_norm)
                    a = a * w_norm
                    m.weight.data.div_(w_norm)
                    m.bias.data.div_(a)

    def forward(self, x, *args, **kwargs):
        """Return one score per sample; extra arguments are ignored."""
        for block in (self.block1, self.block2, self.block3, self.block4,
                      self.block5, self.block6, self.block7):
            x = block(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x
class GradNorm(nn.Module):
    """Wrap a module so its output is normalised by its input-gradient norm.

    For the wrapped function ``f`` the forward pass returns::

                        f
        f_hat = --------------------
                || grad_f || + | f |

    where ``grad_f`` is the gradient of ``sum(f)`` with respect to the
    input and the L2 norm is taken per sample over all non-batch
    dimensions.

    Args:
        module: the module to wrap.
        out_channels: stored on the instance; not read by this class.
    """

    def __init__(self, module, out_channels=None):
        super().__init__()
        self.module = module
        self.out_channels = out_channels

    def forward(self, input):
        """Return the gradient-normalised output of the wrapped module.

        When autograd is enabled, ``input`` is flagged with
        ``requires_grad_(True)`` in place (as before) and the result stays
        differentiable. When the caller is inside ``torch.no_grad()``, the
        gradient is still computed locally, the caller's tensor is left
        untouched, and a detached result is returned.
        """
        grad_was_enabled = torch.is_grad_enabled()
        # autograd.grad needs a graph; without this the call raises under
        # torch.no_grad() because the wrapped output has no grad_fn.
        with torch.enable_grad():
            if grad_was_enabled:
                x = input.requires_grad_(True)
            else:
                # Work on a detached alias so the caller's tensor is not
                # left requiring grad after an inference-only call.
                x = input.detach().requires_grad_(True)
            f = self.module(x)
            grad = torch.autograd.grad(
                f, [x], torch.ones_like(f),
                create_graph=grad_was_enabled, retain_graph=True)[0]
            grad_norm = torch.norm(
                torch.flatten(grad, start_dim=1), p=2, dim=1)
            # Reshape to (batch, 1, ..., 1) so it broadcasts against f.
            grad_norm = grad_norm.view(-1, *([1] * (f.dim() - 1)))
            f_hat = f / (grad_norm + torch.abs(f))
        return f_hat if grad_was_enabled else f_hat.detach()
class GenBlock(nn.Module):
    """Upsampling residual block for the ResNet generators.

    Both paths double the spatial size: the shortcut is upsample + 1x1
    conv, the residual is BN-ReLU-upsample-conv3x3-BN-ReLU-conv3x3. Their
    outputs are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shortcut layers are created first to keep parameter creation order.
        skip_layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0),
        ]
        self.shortcut = nn.Sequential(*skip_layers)
        main_layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
        ]
        self.residual = nn.Sequential(*main_layers)
        self.initialize()

    def initialize(self):
        """Kaiming-normal init for every conv weight; zero every conv bias."""
        convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in convs:
            init.kaiming_normal_(conv.weight)
            init.zeros_(conv.bias)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        main = self.residual(x)
        skip = self.shortcut(x)
        return main + skip
class ResGenerator128(nn.Module):
    """ResNet generator built from five ``GenBlock`` stages.

    A linear layer maps ``z`` to a 1024 x 4 x 4 tensor; the stages step
    the channel width 1024 -> 1024 -> 512 -> 256 -> 128 -> 64, and a
    BN/ReLU/conv3x3/Tanh head produces a 3-channel output.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * 1024)

        widths = [1024, 1024, 512, 256, 128, 64]
        stages = [GenBlock(c_in, c_out)
                  for c_in, c_out in zip(widths, widths[1:])]
        self.blocks = nn.Sequential(*stages)
        head = [
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.output = nn.Sequential(*head)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Kaiming-normal init for the linear layer and output convs; zero biases."""
        init.kaiming_normal_(self.linear.weight)
        init.zeros_(self.linear.bias)
        for layer in self.output.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            init.kaiming_normal_(layer.weight)
            init.zeros_(layer.bias)

    def forward(self, z):
        """Map latent ``z`` to an image tensor."""
        seed_map = self.linear(z).view(-1, 1024, 4, 4)
        features = self.blocks(seed_map)
        return self.output(features)
class DisBlock(nn.Module):
    """Residual discriminator block with optional 2x downsampling.

    The shortcut is empty (identity) unless the channel count changes or
    ``down`` is set, in which case it is a 1x1 conv. With ``down`` both
    paths end in a 2x2 average pool.
    """

    def __init__(self, in_channels, out_channels, down=False):
        super().__init__()
        needs_projection = down or in_channels != out_channels
        # The shortcut conv is created before the residual convs, as before.
        skip_layers = []
        if needs_projection:
            skip_layers.append(
                nn.Conv2d(in_channels, out_channels, 1, 1, 0))
        main_layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        ]
        if down:
            skip_layers.append(nn.AvgPool2d(2))
            main_layers.append(nn.AvgPool2d(2))
        self.shortcut = nn.Sequential(*skip_layers)
        self.residual = nn.Sequential(*main_layers)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Kaiming-normal init for every conv weight; zero every conv bias."""
        convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in convs:
            init.kaiming_normal_(conv.weight)
            init.zeros_(conv.bias)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        main = self.residual(x)
        skip = self.shortcut(x)
        return main + skip
class ResDiscriminator48(nn.Module):
    """ResNet discriminator: one ``OptimizedDisblock``, three downsampling
    ``DisBlock`` stages (64 -> 128 -> 256 -> 512 channels), ReLU, global
    average pooling and a linear layer producing one score per sample.
    """

    def __init__(self, *args):
        super().__init__()
        layers = [OptimizedDisblock(3, 64)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.append(DisBlock(c_in, c_out, down=True))
        layers.append(nn.ReLU(True))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.model = nn.Sequential(*layers)
        self.linear = nn.Linear(512, 1)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Kaiming-normal init for the linear weight; zero its bias."""
        head = self.linear
        init.kaiming_normal_(head.weight)
        init.zeros_(head.bias)

    def forward(self, x, *args, **kwargs):
        """Return one score per sample; extra arguments are ignored."""
        pooled = torch.flatten(self.model(x), start_dim=1)
        return self.linear(pooled)
class Attention(nn.Module):
    """Self-attention layer from SA-GAN (https://arxiv.org/abs/1805.08318).

    Queries are taken at full resolution; keys and values are 2x2 max-pooled
    first. The attended features are scaled by the learnable scalar
    ``gamma`` (initialised to 0, so the layer initially returns its input
    unchanged) and added back to the input.

    Args:
        ch: number of input (and output) channels.
        use_spectral_norm: when true, wrap the four 1x1 convolutions with
            the module-level ``sn`` spectral-norm helper.
    """

    def __init__(self, ch, use_spectral_norm):
        super().__init__()
        if use_spectral_norm:
            spectral_norm = sn
        else:
            spectral_norm = (lambda x: x)
        self.q = spectral_norm(nn.Conv2d(
            ch, ch // 8, kernel_size=1, padding=0, bias=False))
        self.k = spectral_norm(nn.Conv2d(
            ch, ch // 8, kernel_size=1, padding=0, bias=False))
        self.v = spectral_norm(nn.Conv2d(
            ch, ch // 2, kernel_size=1, padding=0, bias=False))
        self.o = spectral_norm(nn.Conv2d(
            ch // 2, ch, kernel_size=1, padding=0, bias=False))
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Return ``gamma * attention(x) + x``; ``y`` is accepted but unused."""
        B, C, H, W = x.size()
        q = self.q(x)
        k = F.max_pool2d(self.k(x), [2, 2])
        v = F.max_pool2d(self.v(x), [2, 2])
        # flatten
        q = q.view(B, C // 8, H * W)    # query
        # The pooled map has (H // 2) * (W // 2) positions, which differs
        # from H * W // 4 when H or W is odd, so let view() infer the length.
        k = k.view(B, C // 8, -1)       # key
        v = v.view(B, C // 2, -1)       # value
        # attention weights
        w = F.softmax(torch.bmm(q.transpose(1, 2), k), -1)
        # attend and project
        o = self.o(torch.bmm(v, w.transpose(1, 2)).view(B, C // 2, H, W))
        return self.gamma * o + x
+ """ + super().__init__() + + # residual + self.bn1 = ConditionalBatchNorm2d(in_channels, cbn_in_dim, cbn_linear) + self.residual1 = nn.Sequential( + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2), + sn(nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1))) + self.bn2 = ConditionalBatchNorm2d(out_channels, cbn_in_dim, cbn_linear) + self.residual2 = nn.Sequential( + nn.ReLU(inplace=True), + sn(nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1))) + + # shortcut + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2), + sn(nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0))) + + def forward(self, x, y): + h = self.residual1(self.bn1(x, y)) + h = self.residual2(self.bn2(h, y)) + return h + self.shortcut(x) + + +class Generator32(nn.Module): + def __init__(self, z_dim=128, n_classes=10, ch=64): + super().__init__() + # channels_multipler = [4, 4, 4, 4] + self.linear = sn(nn.Linear(z_dim, (ch * 4) * 4 * 4)) + self.blocks = nn.ModuleList([ + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 8 x 8 + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 16 x 16 + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 32 x 32 + ]) + self.output_layer = nn.Sequential( + nn.BatchNorm2d(ch * 4), + nn.ReLU(inplace=True), + sn(nn.Conv2d(ch * 4, 3, 3, padding=1)), # 3 x 32 x 32 + nn.Tanh()) + res32_weights_init(self) + + def forward(self, z, y): + h = self.linear(z).view(z.size(0), -1, 4, 4) + for block in self.blocks: + h = block(h, y) + h = self.output_layer(h) + return h + + +class Generator128(nn.Module): + def __init__(self, z_dim=128, n_classes=1000, ch=96, shared_dim=128): + super().__init__() + channels_multipler = [16, 16, 8, 4, 2, 1] + num_slots = len(channels_multipler) + self.chunk_size = (z_dim // num_slots) + z_dim = self.chunk_size * num_slots + cbn_in_dim = (shared_dim + self.chunk_size) + + self.shared_embedding = nn.Embedding(n_classes, shared_dim) + self.linear = sn(nn.Linear(z_dim // num_slots, (ch * 16) * 4 * 4)) + + 
self.blocks = nn.ModuleList([ + GenBlock(ch * 16, ch * 16, cbn_in_dim), # ch*16 x 4 x 4 + GenBlock(ch * 16, ch * 8, cbn_in_dim), # ch*16 x 8 x 8 + GenBlock(ch * 8, ch * 4, cbn_in_dim), # ch*8 x 16 x 16 + nn.ModuleList([ # ch*4 x 32 x 32 + GenBlock(ch * 4, ch * 2, cbn_in_dim), + Attention(ch * 2, True), # ch*2 x 64 x 64 + ]), + GenBlock(ch * 2, ch * 1, cbn_in_dim), # ch*1 x 128 x 128 + ]) + + self.output_layer = nn.Sequential( + nn.BatchNorm2d(ch * 1), + nn.ReLU(inplace=True), + sn(nn.Conv2d(ch * 1, 3, 3, padding=1)), # 3 x 128 x 128 + nn.Tanh()) + # res128_weights_init(self) + + def forward(self, z, y): + y = self.shared_embedding(y) + zs = torch.split(z, self.chunk_size, 1) + ys = [torch.cat([y, item], 1) for item in zs[1:]] + + h = self.linear(zs[0]).view(z.size(0), -1, 4, 4) + for i, block in enumerate(self.blocks): + if isinstance(block, nn.ModuleList): + for module in block: + h = module(h, ys[i]) + else: + h = block(h, ys[i]) + h = self.output_layer(h) + + return h + + +class OptimizedDisblock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.shortcut = nn.Sequential( + nn.AvgPool2d(2), + nn.Conv2d(in_channels, out_channels, 1, padding=0)) + self.residual = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + nn.AvgPool2d(2)) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class DisBlock(nn.Module): + def __init__(self, in_channels, out_channels, down=False): + super().__init__() + shortcut = [] + if in_channels != out_channels or down: + shortcut.append(nn.Conv2d(in_channels, out_channels, 1, 1, 0)) + if down: + shortcut.append(nn.AvgPool2d(2)) + self.shortcut = nn.Sequential(*shortcut) + + residual = [ + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + ] + if down: + residual.append(nn.AvgPool2d(2)) + 
self.residual = nn.Sequential(*residual) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class Discriminator32(nn.Module): + def __init__(self, n_classes=10, ch=64): + super().__init__() + self.fp16 = False + # channels_multipler = [2, 2, 2, 2] + self.blocks = nn.Sequential( + OptimizedDisblock(3, ch * 4), # 3 x 32 x 32 + DisBlock(ch * 4, ch * 4, down=True), # ch*4 x 16 x 16 + DisBlock(ch * 4, ch * 4), # ch*4 x 8 x 8 + DisBlock(ch * 4, ch * 4), # ch*4 x 8 x 8 + nn.ReLU(inplace=True), + ) + + self.linear = nn.Linear(ch * 4, 1) + self.embedding = nn.Embedding(n_classes, ch * 4) + res32_weights_init(self) + + def forward(self, x, y): + h = self.blocks(x).sum(dim=[2, 3]) + h = self.linear(h) + (self.embedding(y) * h).sum(dim=1, keepdim=True) + return h + + +class Discriminator128(nn.Module): + def __init__(self, n_classes=1000, ch=96): + super().__init__() + # channels_multipler = [1, 2, 4, 8, 16, 16] + self.blocks = nn.Sequential( + OptimizedDisblock(3, ch * 1), # 3 x 128 x 128 + Attention(ch, False), # ch*1 x 64 x 64 + DisBlock(ch * 1, ch * 2, down=True), # ch*1 x 32 x 32 + DisBlock(ch * 2, ch * 4, down=True), # ch*2 x 16 x 16 + DisBlock(ch * 4, ch * 8, down=True), # ch*4 x 8 x 8 + DisBlock(ch * 8, ch * 16, down=True), # ch*8 x 4 x 4 + DisBlock(ch * 16, ch * 16), # ch*16 x 4 x 4 + nn.ReLU(inplace=True), # ch*16 x 4 x 4 + ) + + self.linear = nn.Linear(ch * 16, 1) + self.embedding = nn.Embedding(n_classes, ch * 16) + # res128_weights_init(self) + + def forward(self, x, y): + h = self.blocks(x).sum(dim=[2, 3]) + h = self.linear(h) + (self.embedding(y) * h).sum(dim=1, keepdim=True) + return h + + +def res32_weights_init(m): + for name, module in m.named_modules(): + if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)): + torch.nn.init.xavier_uniform_(module.weight) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + +def res128_weights_init(m): + for module in m.modules(): + if 
isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)): + torch.nn.init.orthogonal_(module.weight) diff --git a/sn_testing/SN_models/biggan_module.py b/sn_testing/SN_models/biggan_module.py new file mode 100644 index 0000000..ff4444a --- /dev/null +++ b/sn_testing/SN_models/biggan_module.py @@ -0,0 +1,278 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +sn = partial(torch.nn.utils.spectral_norm, eps=1e-6) + + +class Attention(nn.Module): + """ + SA-GAN: https://arxiv.org/abs/1805.08318 + """ + def __init__(self, ch, use_spectral_norm): + super().__init__() + if use_spectral_norm: + spectral_norm = sn + else: + spectral_norm = (lambda x: x) + self.q = spectral_norm(nn.Conv2d( + ch, ch // 8, kernel_size=1, padding=0, bias=False)) + self.k = spectral_norm(nn.Conv2d( + ch, ch // 8, kernel_size=1, padding=0, bias=False)) + self.v = spectral_norm(nn.Conv2d( + ch, ch // 2, kernel_size=1, padding=0, bias=False)) + self.o = spectral_norm(nn.Conv2d( + ch // 2, ch, kernel_size=1, padding=0, bias=False)) + self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True) + + def forward(self, x, y=None): + B, C, H, W = x.size() + q = self.q(x) + k = F.max_pool2d(self.k(x), [2, 2]) + v = F.max_pool2d(self.v(x), [2, 2]) + # flatten + q = q.view(B, C // 8, H * W) # query + k = k.view(B, C // 8, H * W // 4) # key + v = v.view(B, C // 2, H * W // 4) # value + # attention weights + w = F.softmax(torch.bmm(q.transpose(1, 2), k), -1) + # attend and project + o = self.o(torch.bmm(v, w.transpose(1, 2)).view(B, C // 2, H, W)) + return self.gamma * o + x + + +class ConditionalBatchNorm2d(nn.Module): + def __init__(self, in_channel, cond_size, linear=True): + super().__init__() + if linear: + self.gain = sn(nn.Linear(cond_size, in_channel, bias=False)) + self.bias = sn(nn.Linear(cond_size, in_channel, bias=False)) + else: + self.gain = nn.Embedding(cond_size, in_channel) + self.bias = nn.Embedding(cond_size, in_channel) + 
self.batchnorm2d = nn.BatchNorm2d(in_channel, affine=False) + + def forward(self, x, y): + gain = self.gain(y).view(y.size(0), -1, 1, 1) + 1 + bias = self.bias(y).view(y.size(0), -1, 1, 1) + x = self.batchnorm2d(x) + return x * gain + bias + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels, cbn_in_dim, cbn_linear=True): + """ + cbn_in_dim(int): output size of shared embedding + cbn_linear(bool): use linear layer in conditional batchnorm to + get gain and bias of normalization. Otherwise, + use embedding. + """ + super().__init__() + + # residual + self.bn1 = ConditionalBatchNorm2d(in_channels, cbn_in_dim, cbn_linear) + self.residual1 = nn.Sequential( + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2), + sn(nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1))) + self.bn2 = ConditionalBatchNorm2d(out_channels, cbn_in_dim, cbn_linear) + self.residual2 = nn.Sequential( + nn.ReLU(inplace=True), + sn(nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1))) + + # shortcut + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2), + sn(nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0))) + + def forward(self, x, y): + h = self.residual1(self.bn1(x, y)) + h = self.residual2(self.bn2(h, y)) + return h + self.shortcut(x) + + +class Generator32(nn.Module): + def __init__(self, z_dim=128, n_classes=10, ch=64): + super().__init__() + # channels_multipler = [4, 4, 4, 4] + self.linear = sn(nn.Linear(z_dim, (ch * 4) * 4 * 4)) + self.blocks = nn.ModuleList([ + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 8 x 8 + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 16 x 16 + GenBlock(ch * 4, ch * 4, n_classes, False), # 4ch x 32 x 32 + ]) + self.output_layer = nn.Sequential( + nn.BatchNorm2d(ch * 4), + nn.ReLU(inplace=True), + sn(nn.Conv2d(ch * 4, 3, 3, padding=1)), # 3 x 32 x 32 + nn.Tanh()) + res32_weights_init(self) + + def forward(self, z, y): + h = self.linear(z).view(z.size(0), -1, 4, 4) + for block in 
self.blocks: + h = block(h, y) + h = self.output_layer(h) + return h + + +class Generator128(nn.Module): + def __init__(self, z_dim=128, n_classes=1000, ch=96, shared_dim=128): + super().__init__() + channels_multipler = [16, 16, 8, 4, 2, 1] + num_slots = len(channels_multipler) + self.chunk_size = (z_dim // num_slots) + z_dim = self.chunk_size * num_slots + cbn_in_dim = (shared_dim + self.chunk_size) + + self.shared_embedding = nn.Embedding(n_classes, shared_dim) + self.linear = sn(nn.Linear(z_dim // num_slots, (ch * 16) * 4 * 4)) + + self.blocks = nn.ModuleList([ + GenBlock(ch * 16, ch * 16, cbn_in_dim), # ch*16 x 4 x 4 + GenBlock(ch * 16, ch * 8, cbn_in_dim), # ch*16 x 8 x 8 + GenBlock(ch * 8, ch * 4, cbn_in_dim), # ch*8 x 16 x 16 + nn.ModuleList([ # ch*4 x 32 x 32 + GenBlock(ch * 4, ch * 2, cbn_in_dim), + Attention(ch * 2, True), # ch*2 x 64 x 64 + ]), + GenBlock(ch * 2, ch * 1, cbn_in_dim), # ch*1 x 128 x 128 + ]) + + self.output_layer = nn.Sequential( + nn.BatchNorm2d(ch * 1), + nn.ReLU(inplace=True), + sn(nn.Conv2d(ch * 1, 3, 3, padding=1)), # 3 x 128 x 128 + nn.Tanh()) + # res128_weights_init(self) + + def forward(self, z, y): + y = self.shared_embedding(y) + zs = torch.split(z, self.chunk_size, 1) + ys = [torch.cat([y, item], 1) for item in zs[1:]] + + h = self.linear(zs[0]).view(z.size(0), -1, 4, 4) + for i, block in enumerate(self.blocks): + if isinstance(block, nn.ModuleList): + for module in block: + h = module(h, ys[i]) + else: + h = block(h, ys[i]) + h = self.output_layer(h) + + return h + + +class OptimizedDisblock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.shortcut = sn(nn.Sequential( + nn.AvgPool2d(2), + nn.Conv2d(in_channels, out_channels, 1, padding=0))) + self.residual = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1), + nn.AvgPool2d(2)) + + def forward(self, x): + return self.residual(x) + 
self.shortcut(x) + + +class DisBlock(nn.Module): + def __init__(self, in_channels, out_channels, down=False): + super().__init__() + shortcut = [] + if in_channels != out_channels or down: + shortcut.append(nn.Conv2d(in_channels, out_channels, 1, 1, 0)) + if down: + shortcut.append(nn.AvgPool2d(2)) + self.shortcut = nn.Sequential(*shortcut) + + residual = [ + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + ] + if down: + residual.append(nn.AvgPool2d(2)) + self.residual = nn.Sequential(*residual) + #self.gradnorm = GradNorm(DisBlock(out_channels, out_channels)) + + def forward(self, x): + x = self.residual(x) + self.shortcut(x) + #x = self.gradnorm(x) + return x + + +class Discriminator32(nn.Module): + def __init__(self, n_classes=10, ch=64): + super().__init__() + self.fp16 = False + # channels_multipler = [2, 2, 2, 2] + self.blocks = nn.Sequential( + OptimizedDisblock(3, ch * 4), # 3 x 32 x 32 + DisBlock(ch * 4, ch * 4, down=True), # ch*4 x 16 x 16 + DisBlock(ch * 4, ch * 4), # ch*4 x 8 x 8 + DisBlock(ch * 4, ch * 4), # ch*4 x 8 x 8 + nn.ReLU(inplace=True), + ) + #self.gradnorm = GradNorm(Discriminator32()) + self.linear = nn.Linear(ch * 4, 1) + self.embedding = nn.Embedding(n_classes, ch * 4) + res32_weights_init(self) + + def forward(self, x, y): + h = self.blocks(x).sum(dim=[2, 3]) + #h = self.gradnorm(h) + h = self.linear(h) + (self.embedding(y) * h).sum(dim=1, keepdim=True) + return h + + +class Discriminator128(nn.Module): + def __init__(self, n_classes=1000, ch=96): + super().__init__() + # channels_multipler = [1, 2, 4, 8, 16, 16] + self.blocks = nn.Sequential( + OptimizedDisblock(3, ch * 1), # 3 x 128 x 128 + Attention(ch, False), # ch*1 x 64 x 64 + DisBlock(ch * 1, ch * 2, down=True), # ch*1 x 32 x 32 + sn(Discriminator128(ch * 2)), + DisBlock(ch * 2, ch * 4, down=True), # ch*2 x 16 x 16 + sn(Discriminator128(ch * 4)), + DisBlock(ch * 4, ch * 8, down=True), # ch*4 x 
8 x 8 + sn(Discriminator128(ch * 8)), + DisBlock(ch * 8, ch * 16, down=True), # ch*8 x 4 x 4 + sn(Discriminator128(ch * 16)), + DisBlock(ch * 16, ch * 16), # ch*16 x 4 x 4 + sn(Discriminator128(ch * 16)), + nn.ReLU(inplace=True), # ch*16 x 4 x 4 + ) + self.gradnorm = sn(Discriminator128(ch * 16)) + self.linear = nn.Linear(ch * 16, 1) + self.embedding = nn.Embedding(n_classes, ch * 16) + # res128_weights_init(self) + + def forward(self, x, y): + h = self.blocks(x).sum(dim=[2, 3]) + h = self.gradnorm(h) + h = self.linear(h) + (self.embedding(y) * h).sum(dim=1, keepdim=True) + return h + + +def res32_weights_init(m): + for name, module in m.named_modules(): + if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)): + torch.nn.init.xavier_uniform_(module.weight) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + +def res128_weights_init(m): + for module in m.modules(): + if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)): + torch.nn.init.orthogonal_(module.weight) diff --git a/sn_testing/SN_models/dcgan.py b/sn_testing/SN_models/dcgan.py new file mode 100644 index 0000000..910d443 --- /dev/null +++ b/sn_testing/SN_models/dcgan.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + + +class Generator(nn.Module): + def __init__(self, z_dim, M=4): + super().__init__() + self.M = M + self.linear = nn.Linear(z_dim, M * M * 512) + self.main = nn.Sequential( + nn.BatchNorm2d(512), + nn.ReLU(True), + nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(True), + nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(True), + nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1), + nn.Tanh()) + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, 
(nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): + init.normal_(m.weight, std=0.02) + init.zeros_(m.bias) + + def forward(self, z, *args, **kwargs): + x = self.linear(z) + x = x.view(x.size(0), -1, self.M, self.M) + x = self.main(x) + return x + + +class Discriminator(nn.Module): + def __init__(self, M=32): + super().__init__() + self.M = M + + self.main = nn.Sequential( + # M + nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=True), + # M / 2 + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=True), + # M / 4 + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=True), + # M / 8 + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=True)) + + self.linear = nn.Linear(M // 8 * M // 8 * 512, 1) + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + init.normal_(m.weight, std=0.02) + init.zeros_(m.bias) + + def rescale_weight(self, min_norm=1.0, max_norm=1.33): + a = 1.0 + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + w_norm = m.weight.norm(p=2) + #print(m, w_norm) + w_norm = max(w_norm, min_norm) + w_norm = min(w_norm, max_norm) + a = a * w_norm + m.weight.data.div_(w_norm) + m.bias.data.div_(a) + + def forward(self, x, *args, **kwargs): + x = self.main(x) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) + return x + + +class Generator32(Generator): + def __init__(self, z_dim, *args): + super().__init__(z_dim, M=4) + + +class Generator48(Generator): + def __init__(self, z_dim, *args): + super().__init__(z_dim, M=6) + + +class 
Discriminator32(Discriminator): + def __init__(self, *args): + super().__init__(M=32) + + +class Discriminator48(Discriminator): + def __init__(self, *args): + super().__init__(M=48) diff --git a/sn_testing/SN_models/dcgan_module.py b/sn_testing/SN_models/dcgan_module.py new file mode 100644 index 0000000..30ca968 --- /dev/null +++ b/sn_testing/SN_models/dcgan_module.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +from functools import partial + +sn = partial(torch.nn.utils.spectral_norm, eps=1e-5) + + +class Generator(nn.Module): + def __init__(self, z_dim, M=4): + super().__init__() + self.M = M + self.linear = nn.Linear(z_dim, M * M * 512) + self.main = nn.Sequential( + nn.BatchNorm2d(512), + nn.ReLU(True), + nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(True), + nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(True), + nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1), + nn.Tanh()) + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): + init.normal_(m.weight, std=0.02) + init.zeros_(m.bias) + + def forward(self, z, *args, **kwargs): + x = self.linear(z) + x = x.view(x.size(0), -1, self.M, self.M) + x = self.main(x) + return x + + +class Discriminator(nn.Module): + def __init__(self, M=32): + super().__init__() + self.M = M + + # M + self.block1 = sn(nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + self.block2 = sn(nn.Sequential( + nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + # M / 2 + self.block3 = sn(nn.Sequential( + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=False) 
+ )) + self.block4 = sn(nn.Sequential( + nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + # M / 4 + self.block5 = sn(nn.Sequential( + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + self.block6 = sn(nn.Sequential( + nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + # M / 8 + self.block7 = sn(nn.Sequential( + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.1, inplace=False) + )) + + self.linear = sn(nn.Linear(M // 8 * M // 8 * 512, 1)) + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + init.normal_(m.weight, std=0.02) + init.zeros_(m.bias) + + def rescale_weight(self, min_norm=1.0, max_norm=1.33): + a = 1.0 + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + w_norm = m.weight.norm(p=2) + #print(m, w_norm) + w_norm = max(w_norm, min_norm) + w_norm = min(w_norm, max_norm) + a = a * w_norm + m.weight.data.div_(w_norm) + m.bias.data.div_(a) + + def forward(self, x, *args, **kwargs): + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) + return x + + +class Generator32(Generator): + def __init__(self, z_dim, *args): + super().__init__(z_dim, M=4) + + +class Generator48(Generator): + def __init__(self, z_dim, *args): + super().__init__(z_dim, M=6) + + +class Discriminator32(Discriminator): + def __init__(self, *args): + super().__init__(M=32) + + +class Discriminator48(Discriminator): + def __init__(self, *args): + super().__init__(M=48) diff --git a/sn_testing/SN_models/gradnorm.py b/sn_testing/SN_models/gradnorm.py new file mode 100644 index 0000000..c9b7c40 --- /dev/null +++ b/sn_testing/SN_models/gradnorm.py @@ -0,0 +1,37 @@ +import 
torch +import torch.nn as nn + +def normalize_gradient(net_D, x, **kwargs): + """ + f + f_hat = -------------------- + || grad_f || + | f | + """ + x.requires_grad_(True) + f = net_D(x, **kwargs) + return f + +def get_gradient(net_D, x, **kwargs): + f = net_D(x, **kwargs) + return f + +class GradNorm(nn.Module): + def __init__(self, module, out_channels = None): + super().__init__() + self.module = module + self.out_channels = out_channels + + def forward(self, input): + """ + f + f_hat = -------------------- + || grad_f || + | f | + """ + input.requires_grad_(True) + f = self.module(input) + grad = torch.autograd.grad( + f, [input], torch.ones_like(f), create_graph=True, retain_graph=True)[0] + grad_norm = torch.norm(torch.flatten(grad, start_dim=1), p=2, dim=1) + grad_norm = grad_norm.view(-1, *[1 for _ in range(len(f.shape) - 1)]) + f_hat = (f / (grad_norm + torch.abs(f))) + return f_hat diff --git a/sn_testing/SN_models/resnet.py b/sn_testing/SN_models/resnet.py new file mode 100644 index 0000000..64bd002 --- /dev/null +++ b/sn_testing/SN_models/resnet.py @@ -0,0 +1,334 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + + +class GenBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + # shortcut + self.shortcut = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0) + ) + # residual + self.residual = nn.Sequential( + nn.BatchNorm2d(in_channels), + nn.ReLU(True), + nn.Upsample(scale_factor=2), + nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1), + ) + # initialize weight + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class 
ResGenerator32(nn.Module): + def __init__(self, z_dim, *args): + super().__init__() + self.linear = nn.Linear(z_dim, 4 * 4 * 256) + self.blocks = nn.Sequential( + GenBlock(256, 256), + GenBlock(256, 256), + GenBlock(256, 256), + ) + self.output = nn.Sequential( + nn.BatchNorm2d(256), + nn.ReLU(True), + nn.Conv2d(256, 3, 3, stride=1, padding=1), + nn.Tanh(), + ) + # initialize weight + self.initialize() + + def initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + for m in self.output.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, z, *args, **kwargs): + z = self.linear(z) + z = z.view(-1, 256, 4, 4) + return self.output(self.blocks(z)) + + +class ResGenerator48(nn.Module): + def __init__(self, z_dim, *args): + super().__init__() + self.linear = nn.Linear(z_dim, 6 * 6 * 512) + self.blocks = nn.Sequential( + GenBlock(512, 256), + GenBlock(256, 128), + GenBlock(128, 64), + ) + self.output = nn.Sequential( + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.Conv2d(64, 3, 3, stride=1, padding=1), + nn.Tanh(), + ) + # initialize weight + self.initialize() + + def initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + for m in self.output.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, z, *args, **kwargs): + z = self.linear(z) + z = z.view(-1, 512, 6, 6) + return self.output(self.blocks(z)) + + +class ResGenerator128(nn.Module): + def __init__(self, z_dim): + super().__init__() + self.linear = nn.Linear(z_dim, 4 * 4 * 1024) + + self.blocks = nn.Sequential( + GenBlock(1024, 1024), + GenBlock(1024, 512), + GenBlock(512, 256), + GenBlock(256, 128), + GenBlock(128, 64), + ) + self.output = nn.Sequential( + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.Conv2d(64, 3, 3, stride=1, padding=1), + nn.Tanh(), + ) + # initialize weight + self.initialize() + + def 
initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + for m in self.output.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, z): + inputs = self.linear(z) + inputs = inputs.view(-1, 1024, 4, 4) + return self.output(self.blocks(inputs)) + + +class ResGenerator256(nn.Module): + def __init__(self, z_dim): + super().__init__() + self.linear = nn.Linear(z_dim, 4 * 4 * 1024) + + self.blocks = nn.Sequential( + GenBlock(1024, 1024), + GenBlock(1024, 512), + GenBlock(512, 512), + GenBlock(512, 256), + GenBlock(256, 128), + GenBlock(128, 64), + ) + self.output = nn.Sequential( + nn.BatchNorm2d(64), + nn.ReLU(True), + nn.Conv2d(64, 3, 3, stride=1, padding=1), + nn.Tanh(), + ) + # initialize weight + self.initialize() + + def initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + for m in self.output.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, z): + inputs = self.linear(z) + inputs = inputs.view(-1, 1024, 4, 4) + return self.output(self.blocks(inputs)) + + +class OptimizedDisblock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + # shortcut + self.shortcut = nn.Sequential( + nn.AvgPool2d(2), + nn.Conv2d(in_channels, out_channels, 1, 1, 0)) + # residual + self.residual = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.AvgPool2d(2)) + # initialize weight + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class DisBlock(nn.Module): + def __init__(self, in_channels, out_channels, down=False): + super().__init__() + # shortcut + shortcut = [] + if in_channels 
!= out_channels or down: + shortcut.append( + nn.Conv2d(in_channels, out_channels, 1, 1, 0)) + if down: + shortcut.append(nn.AvgPool2d(2)) + self.shortcut = nn.Sequential(*shortcut) + # residual + residual = [ + nn.ReLU(), + nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.ReLU(True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1), + ] + if down: + residual.append(nn.AvgPool2d(2)) + self.residual = nn.Sequential(*residual) + # initialize weight + self.initialize() + + def initialize(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight) + init.zeros_(m.bias) + + def forward(self, x): + return (self.residual(x) + self.shortcut(x)) + + +class ResDiscriminator32(nn.Module): + def __init__(self, *args): + super().__init__() + self.model = nn.Sequential( + OptimizedDisblock(3, 128), + DisBlock(128, 128, down=True), + DisBlock(128, 128), + DisBlock(128, 128), + nn.ReLU(True), + nn.AdaptiveAvgPool2d((1, 1))) + self.linear = nn.Linear(128, 1) + # initialize weight + self.initialize() + + def initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + + def forward(self, x, *args, **kwargs): + x = self.model(x) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) + return x + + +class ResDiscriminator48(nn.Module): + def __init__(self, *args): + super().__init__() + self.model = nn.Sequential( + OptimizedDisblock(3, 64), + DisBlock(64, 128, down=True), + DisBlock(128, 256, down=True), + DisBlock(256, 512, down=True), + nn.ReLU(True), + nn.AdaptiveAvgPool2d((1, 1))) + self.linear = nn.Linear(512, 1) + # initialize weight + self.initialize() + + def initialize(self): + init.kaiming_normal_(self.linear.weight) + init.zeros_(self.linear.bias) + + def forward(self, x, *args, **kwargs): + x = self.model(x) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) + return x + + +class ResDiscriminator128(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Sequential( + 
import torch
import torch.nn as nn
import torch.nn.init as init


def _init_conv_layers(module):
    """Give every Conv2d inside *module* Kaiming-normal weights, zero bias."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight)
            init.zeros_(m.bias)


def _init_linear(linear):
    """Give *linear* Kaiming-normal weights and a zero bias."""
    init.kaiming_normal_(linear.weight)
    init.zeros_(linear.bias)


def _output_head(in_channels):
    """Build the generator tail: BN -> ReLU -> 3x3 conv to RGB -> Tanh."""
    return nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(True),
        nn.Conv2d(in_channels, 3, 3, stride=1, padding=1),
        nn.Tanh(),
    )


class GenBlock(nn.Module):
    """Residual generator block that doubles the spatial resolution.

    Both branches upsample by a factor of two; their outputs are summed.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # shortcut: upsample, then a 1x1 conv to match the channel count
        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        )
        # residual: BN-ReLU-upsample-conv3x3 followed by BN-ReLU-conv3x3
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
        )
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize all convolutions of this block."""
        _init_conv_layers(self)

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)


class ResGenerator32(nn.Module):
    """ResNet generator: (N, z_dim) latents -> (N, 3, 32, 32) in [-1, 1].

    Extra constructor and forward arguments are accepted and ignored, so
    the class can be driven through the same call convention as a
    conditional generator.
    """

    def __init__(self, z_dim, *args):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * 256)
        # three upsampling blocks: 4 -> 8 -> 16 -> 32
        self.blocks = nn.Sequential(
            GenBlock(256, 256),
            GenBlock(256, 256),
            GenBlock(256, 256),
        )
        self.output = _output_head(256)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the input projection and the output convolution."""
        _init_linear(self.linear)
        _init_conv_layers(self.output)

    def forward(self, z, *args, **kwargs):
        z = self.linear(z)
        z = z.view(-1, 256, 4, 4)
        return self.output(self.blocks(z))


class ResGenerator48(nn.Module):
    """ResNet generator: (N, z_dim) latents -> (N, 3, 48, 48) in [-1, 1].

    Extra constructor and forward arguments are accepted and ignored.
    """

    def __init__(self, z_dim, *args):
        super().__init__()
        self.linear = nn.Linear(z_dim, 6 * 6 * 512)
        # three upsampling blocks: 6 -> 12 -> 24 -> 48
        self.blocks = nn.Sequential(
            GenBlock(512, 256),
            GenBlock(256, 128),
            GenBlock(128, 64),
        )
        self.output = _output_head(64)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the input projection and the output convolution."""
        _init_linear(self.linear)
        _init_conv_layers(self.output)

    def forward(self, z, *args, **kwargs):
        z = self.linear(z)
        z = z.view(-1, 512, 6, 6)
        return self.output(self.blocks(z))


class ResGenerator128(nn.Module):
    """ResNet generator: (N, z_dim) latents -> (N, 3, 128, 128) in [-1, 1].

    Extra constructor and forward arguments are accepted and ignored,
    matching ``ResGenerator32``/``ResGenerator48``.
    """

    def __init__(self, z_dim, *args):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * 1024)

        # five upsampling blocks: 4 -> 8 -> 16 -> 32 -> 64 -> 128
        self.blocks = nn.Sequential(
            GenBlock(1024, 1024),
            GenBlock(1024, 512),
            GenBlock(512, 256),
            GenBlock(256, 128),
            GenBlock(128, 64),
        )
        self.output = _output_head(64)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the input projection and the output convolution."""
        _init_linear(self.linear)
        _init_conv_layers(self.output)

    def forward(self, z, *args, **kwargs):
        inputs = self.linear(z)
        inputs = inputs.view(-1, 1024, 4, 4)
        return self.output(self.blocks(inputs))


class ResGenerator256(nn.Module):
    """ResNet generator: (N, z_dim) latents -> (N, 3, 256, 256) in [-1, 1].

    Extra constructor and forward arguments are accepted and ignored,
    matching ``ResGenerator32``/``ResGenerator48``.
    """

    def __init__(self, z_dim, *args):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * 1024)

        # six upsampling blocks: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256
        self.blocks = nn.Sequential(
            GenBlock(1024, 1024),
            GenBlock(1024, 512),
            GenBlock(512, 512),
            GenBlock(512, 256),
            GenBlock(256, 128),
            GenBlock(128, 64),
        )
        self.output = _output_head(64)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the input projection and the output convolution."""
        _init_linear(self.linear)
        _init_conv_layers(self.output)

    def forward(self, z, *args, **kwargs):
        inputs = self.linear(z)
        inputs = inputs.view(-1, 1024, 4, 4)
        return self.output(self.blocks(inputs))


class OptimizedDisblock(nn.Module):
    """First discriminator block; halves the spatial resolution.

    Unlike ``DisBlock`` the residual branch starts with a convolution
    rather than a ReLU, and the shortcut pools before its 1x1 conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # shortcut
        self.shortcut = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0))
        # residual
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.AvgPool2d(2))
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize all convolutions of this block."""
        _init_conv_layers(self)

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)


class DisBlock(nn.Module):
    """Residual discriminator block, optionally downsampling by two.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
        down: when true, both branches end with a 2x2 average pool.
    """

    def __init__(self, in_channels, out_channels, down=False):
        super().__init__()
        # shortcut: identity unless the shape changes
        shortcut = []
        if in_channels != out_channels or down:
            shortcut.append(
                nn.Conv2d(in_channels, out_channels, 1, 1, 0))
        if down:
            shortcut.append(nn.AvgPool2d(2))
        self.shortcut = nn.Sequential(*shortcut)
        # residual
        residual = [
            # Not in-place: forward() feeds the same `x` to the shortcut
            # after the residual branch has run.
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        ]
        if down:
            residual.append(nn.AvgPool2d(2))
        self.residual = nn.Sequential(*residual)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize all convolutions of this block."""
        _init_conv_layers(self)

    def forward(self, x):
        return (self.residual(x) + self.shortcut(x))


class ResDiscriminator32(nn.Module):
    """ResNet discriminator producing one score per image, shape (N, 1).

    Extra constructor and forward arguments are accepted and ignored.
    """

    def __init__(self, *args):
        super().__init__()
        self.model = nn.Sequential(
            OptimizedDisblock(3, 128),
            DisBlock(128, 128, down=True),
            DisBlock(128, 128),
            DisBlock(128, 128),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d((1, 1)))
        self.linear = nn.Linear(128, 1)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the final linear layer."""
        _init_linear(self.linear)

    def forward(self, x, *args, **kwargs):
        x = self.model(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


class ResDiscriminator48(nn.Module):
    """ResNet discriminator producing one score per image, shape (N, 1).

    Extra constructor and forward arguments are accepted and ignored.
    """

    def __init__(self, *args):
        super().__init__()
        self.model = nn.Sequential(
            OptimizedDisblock(3, 64),
            DisBlock(64, 128, down=True),
            DisBlock(128, 256, down=True),
            DisBlock(256, 512, down=True),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d((1, 1)))
        self.linear = nn.Linear(512, 1)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the final linear layer."""
        _init_linear(self.linear)

    def forward(self, x, *args, **kwargs):
        x = self.model(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


class ResDiscriminator128(nn.Module):
    """ResNet discriminator producing one score per image, shape (N, 1).

    Extra constructor and forward arguments are accepted and ignored,
    matching ``ResDiscriminator32``/``ResDiscriminator48``.
    """

    def __init__(self, *args):
        super().__init__()
        self.model = nn.Sequential(
            OptimizedDisblock(3, 64),
            DisBlock(64, 128, down=True),
            DisBlock(128, 256, down=True),
            DisBlock(256, 512, down=True),
            DisBlock(512, 1024, down=True),
            DisBlock(1024, 1024),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d((1, 1)))
        self.linear = nn.Linear(1024, 1)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the final linear layer."""
        _init_linear(self.linear)

    def forward(self, x, *args, **kwargs):
        x = self.model(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


class ResDiscriminator256(nn.Module):
    """ResNet discriminator producing one score per image, shape (N, 1).

    Extra constructor and forward arguments are accepted and ignored,
    matching ``ResDiscriminator32``/``ResDiscriminator48``.
    """

    def __init__(self, *args):
        super().__init__()
        self.model = nn.Sequential(
            OptimizedDisblock(3, 64),
            DisBlock(64, 128, down=True),
            DisBlock(128, 256, down=True),
            DisBlock(256, 512, down=True),
            DisBlock(512, 512, down=True),
            DisBlock(512, 1024, down=True),
            DisBlock(1024, 1024),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d((1, 1)))
        self.linear = nn.Linear(1024, 1)
        # initialize weight
        self.initialize()

    def initialize(self):
        """Re-initialize the final linear layer."""
        _init_linear(self.linear)

    def forward(self, x, *args, **kwargs):
        x = self.model(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x
+--ema_start=1000 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN-CR_CIFAR10_BIGGAN_0 diff --git a/sn_testing/config/GN-GAN-CR_CIFAR10_CNN.txt b/sn_testing/config/GN-GAN-CR_CIFAR10_CNN.txt new file mode 100644 index 0000000..be68522 --- /dev/null +++ b/sn_testing/config/GN-GAN-CR_CIFAR10_CNN.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=dcgan.32 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=5 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN-CR_CIFAR10_CNN_0 diff --git a/sn_testing/config/GN-GAN-CR_CIFAR10_RES.txt b/sn_testing/config/GN-GAN-CR_CIFAR10_RES.txt new file mode 100644 index 0000000..1cbec9d --- /dev/null +++ b/sn_testing/config/GN-GAN-CR_CIFAR10_RES.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=resnet.32 +--loss=hinge +--total_steps=200000 +--lr_decay_start=0 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0004 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 +--cr=5 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN-CR_CIFAR10_RES_0 diff --git a/sn_testing/config/GN-GAN-CR_STL10_CNN.txt b/sn_testing/config/GN-GAN-CR_STL10_CNN.txt new file mode 100644 index 0000000..64c5822 --- /dev/null +++ b/sn_testing/config/GN-GAN-CR_STL10_CNN.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=5 +--n_classes=1 + 
+--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN-CR_STL10_CNN_0 diff --git a/sn_testing/config/GN-GAN-CR_STL10_CNN_MODULE.txt b/sn_testing/config/GN-GAN-CR_STL10_CNN_MODULE.txt new file mode 100644 index 0000000..64c5822 --- /dev/null +++ b/sn_testing/config/GN-GAN-CR_STL10_CNN_MODULE.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=5 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN-CR_STL10_CNN_0 diff --git a/sn_testing/config/GN-GAN-CR_STL10_RES.txt b/sn_testing/config/GN-GAN-CR_STL10_RES.txt new file mode 100644 index 0000000..59f0a6c --- /dev/null +++ b/sn_testing/config/GN-GAN-CR_STL10_RES.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=resnet.48 +--loss=hinge +--total_steps=200000 +--lr_decay_start=0 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0004 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 +--cr=5 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_size=64 +--sample_step=500 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN-CR_STL10_RES_0 diff --git a/sn_testing/config/GN-GAN_CELEBAHQ128_RES.txt b/sn_testing/config/GN-GAN_CELEBAHQ128_RES.txt new file mode 100644 index 0000000..9218e51 --- /dev/null +++ b/sn_testing/config/GN-GAN_CELEBAHQ128_RES.txt @@ -0,0 +1,22 @@ +--dataset=celebahq.128 +--arch=resnet.128 +--total_steps=100000 +--batch_size_D=64 +--batch_size_G=128 +--accumulation=1 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 + 
+--ema_decay=0.9999 +--ema_start=10000 + +--sample_step=500 +--sample_size=64 +--eval_step=1000 +--save_step=20000 +--num_images=3000 +--fid_stats=./stats/celebahq.3k.128.npz +--logdir=./logs/GN-GAN_CELEBAHQ128_RES_0_BS64 diff --git a/sn_testing/config/GN-GAN_CELEBAHQ256_RES.txt b/sn_testing/config/GN-GAN_CELEBAHQ256_RES.txt new file mode 100644 index 0000000..ea820a6 --- /dev/null +++ b/sn_testing/config/GN-GAN_CELEBAHQ256_RES.txt @@ -0,0 +1,22 @@ +--dataset=celebahq.256 +--arch=resnet.256 +--total_steps=100000 +--batch_size_D=64 +--batch_size_G=128 +--accumulation=1 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 + +--ema_decay=0.9999 +--ema_start=10000 + +--sample_step=500 +--sample_size=64 +--eval_step=1000 +--save_step=20000 +--num_images=10000 +--fid_stats=./stats/celebahq.all.256.npz +--logdir=./logs/GN-GAN_CELEBAHQ256_RES_0_BS64 diff --git a/sn_testing/config/GN-GAN_CHURCH256_RES.txt b/sn_testing/config/GN-GAN_CHURCH256_RES.txt new file mode 100644 index 0000000..c90e41d --- /dev/null +++ b/sn_testing/config/GN-GAN_CHURCH256_RES.txt @@ -0,0 +1,22 @@ +--dataset=lsun_church.256 +--arch=resnet.256 +--total_steps=100000 +--batch_size_D=64 +--batch_size_G=128 +--accumulation=1 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 + +--ema_decay=0.9999 +--ema_start=20000 + +--sample_step=500 +--sample_size=64 +--eval_step=1000 +--save_step=20000 +--num_images=10000 +--fid_stats=./stats/church.train.256.npz +--logdir=./logs/GN-GAN_CHURCH256_RES_0_BS64 diff --git a/sn_testing/config/GN-GAN_CIFAR10_BIGGAN.txt b/sn_testing/config/GN-GAN_CIFAR10_BIGGAN.txt new file mode 100644 index 0000000..c807fb3 --- /dev/null +++ b/sn_testing/config/GN-GAN_CIFAR10_BIGGAN.txt @@ -0,0 +1,27 @@ +--dataset=cifar10.32 +--arch=biggan.32 +--loss=hinge +--total_steps=75000 +--lr_decay_start=75000 +--batch_size_D=50 +--batch_size_G=50 +--num_workers=8 +--lr_D=0.0002 +--lr_G=0.0001 +--betas=0.0 +--betas=0.999 +--n_dis=4 +--z_dim=128 +--cr=0 
+--n_classes=10 + +--ema_decay=0.9999 +--ema_start=1000 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_BIGGAN_0 diff --git a/sn_testing/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt b/sn_testing/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt new file mode 100644 index 0000000..88db8a5 --- /dev/null +++ b/sn_testing/config/GN-GAN_CIFAR10_BIGGAN_MODULE.txt @@ -0,0 +1,27 @@ +--dataset=cifar10.32 +--arch=biggan.32 +--loss=hinge +--total_steps=75000 +--lr_decay_start=75000 +--batch_size_D=50 +--batch_size_G=50 +--num_workers=8 +--lr_D=0.0002 +--lr_G=0.0001 +--betas=0.0 +--betas=0.999 +--n_dis=4 +--z_dim=128 +--cr=0 +--n_classes=10 + +--ema_decay=0.9999 +--ema_start=1000 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_BIGGAN_MODULE_0 diff --git a/sn_testing/config/GN-GAN_CIFAR10_CNN.txt b/sn_testing/config/GN-GAN_CIFAR10_CNN.txt new file mode 100644 index 0000000..734d26f --- /dev/null +++ b/sn_testing/config/GN-GAN_CIFAR10_CNN.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=dcgan.32 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_CNN_0 diff --git a/sn_testing/config/GN-GAN_CIFAR10_CNN_MODULE.txt b/sn_testing/config/GN-GAN_CIFAR10_CNN_MODULE.txt new file mode 100644 index 0000000..3ad5aa5 --- /dev/null +++ b/sn_testing/config/GN-GAN_CIFAR10_CNN_MODULE.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=dcgan.32 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 
+--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_CNN_MODULE_0 diff --git a/sn_testing/config/GN-GAN_CIFAR10_RES.txt b/sn_testing/config/GN-GAN_CIFAR10_RES.txt new file mode 100644 index 0000000..4310265 --- /dev/null +++ b/sn_testing/config/GN-GAN_CIFAR10_RES.txt @@ -0,0 +1,25 @@ +--dataset=cifar10.32 +--arch=resnet.32 +--loss=hinge +--total_steps=200000 +--lr_decay_start=0 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0004 +--lr_G=0.0002 +--n_dis=5 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/cifar10.train.npz +--logdir=./logs/GN-GAN_CIFAR10_RES_0 diff --git a/sn_testing/config/GN-GAN_STL10_CNN.txt b/sn_testing/config/GN-GAN_STL10_CNN.txt new file mode 100644 index 0000000..6c9f938 --- /dev/null +++ b/sn_testing/config/GN-GAN_STL10_CNN.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 +--batch_size_G=128 +--num_workers=10 +--lr_D=0.0002 +--lr_G=0.0002 +--n_dis=1 +--z_dim=128 +--cr=0 +--n_classes=1 + +--ema_decay=0.9999 +--ema_start=0 + +--sample_step=500 +--sample_size=64 +--eval_step=5000 +--save_step=20000 +--num_images=50000 +--fid_stats=./stats/stl10.unlabeled.48.npz +--logdir=./logs/GN-GAN_STL10_CNN_0 diff --git a/sn_testing/config/GN-GAN_STL10_CNN_MODULE.txt b/sn_testing/config/GN-GAN_STL10_CNN_MODULE.txt new file mode 100644 index 0000000..3dcc243 --- /dev/null +++ b/sn_testing/config/GN-GAN_STL10_CNN_MODULE.txt @@ -0,0 +1,25 @@ +--dataset=stl10.48 +--arch=dcgan.48 +--loss=bce +--total_steps=200000 +--lr_decay_start=200000 +--batch_size_D=64 
import io

import lmdb
from PIL import Image
from torchvision import datasets
from torchvision import transforms as T
from torchvision.datasets import VisionDataset


class LMDBDataset(VisionDataset):
    """Image dataset backed by an LMDB file keyed by the decimal index.

    Every record is the encoded bytes of one image stored under the key
    ``str(index).encode()``. Items are returned as ``(image, 0)``.

    Args:
        path: location of the LMDB environment.
        transform: callable applied to the decoded PIL image, or None.
    """

    def __init__(self, path, transform):
        # Let VisionDataset set up `root`/`transforms` so inherited
        # helpers such as __repr__ work on this dataset.
        super().__init__(path, transform=transform)
        self.env = lmdb.open(path, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)

        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(image, 0)`` for *index*.

        Raises:
            IndexError: if no record is stored under *index*.
        """
        with self.env.begin(write=False) as txn:
            imgbytes = txn.get(f'{index}'.encode())

        if imgbytes is None:
            # txn.get() yields None for an absent key; fail with the
            # exception type expected from a sequence-style lookup.
            raise IndexError(f'index {index} not found in LMDB dataset')

        img = Image.open(io.BytesIO(imgbytes))

        if self.transform is not None:
            img = self.transform(img)

        return img, 0


def get_dataset(name, in_memory=True):
    """Build a training dataset from a ``[name].[resolution]`` spec.

    Images are resized to the requested square resolution, randomly
    flipped horizontally, converted to tensors and normalized with mean
    and std 0.5 per channel.

    Args:
        name: dataset spec in the format ``[name].[resolution]``,
            e.g. ``cifar10.32`` or ``celebahq.256``.
        in_memory: accepted for compatibility; not used by this function.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if the dataset name is unknown, or if *name* is not
            of the form ``[name].[integer]``.
    """
    name, img_size = name.split('.')
    img_size = int(img_size)

    transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Builders are lazy so only the requested dataset is constructed
    # (and, where applicable, downloaded).
    builders = {
        'cifar10': lambda: datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform),
        'stl10': lambda: datasets.STL10(
            './data', split='unlabeled', download=True, transform=transform),
        'celebahq': lambda: LMDBDataset(
            f'./data/celebahq/{img_size}', transform=transform),
        'lsun_church': lambda: datasets.LSUNClass(
            './data/lsun/church/', transform, (lambda x: 0)),
        'lsun_bedroom': lambda: datasets.LSUNClass(
            './data/lsun/bedroom', transform, (lambda x: 0)),
        'lsun_horse': lambda: datasets.LSUNClass(
            './data/lsun/horse', transform, (lambda x: 0)),
    }
    try:
        build = builders[name]
    except KeyError:
        raise ValueError(f'Unknown dataset {name}') from None
    return build()


if __name__ == '__main__':
    # Pack every *.jpg under `path` into an LMDB at `out`, keyed 0..N-1.
    import argparse
    import os
    from glob import glob
    from tqdm import tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('out', type=str)
    args = parser.parse_args()

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        with env.begin(write=True) as txn:
            files = glob(os.path.join(args.path, '*.jpg'))
            try:
                # Prefer numeric order when all file stems are integers.
                files = sorted(
                    files,
                    key=lambda f: int(os.path.splitext(os.path.basename(f))[0])
                )
                print("Sort by file number")
            except ValueError:
                files = sorted(files)
                print("Sort by file path")
            for i, file in enumerate(tqdm(files, dynamic_ncols=True)):
                key = f'{i}'.encode()
                # Close each image file promptly instead of leaking the
                # handle until garbage collection.
                with open(file, 'rb') as f:
                    img = f.read()
                txn.put(key, img)
import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEWithLogits(nn.Module):
    """Binary cross-entropy GAN loss on raw logits.

    Called with both predictions it returns the discriminator triple
    ``(total, real_term, fake_term)``; called with one prediction it
    returns the loss of that prediction against an all-ones target.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred_real, pred_fake=None):
        ones_term = self.bce(pred_real, torch.ones_like(pred_real))
        if pred_fake is None:
            return ones_term
        zeros_term = self.bce(pred_fake, torch.zeros_like(pred_fake))
        return ones_term + zeros_term, ones_term, zeros_term


class HingeLoss(nn.Module):
    """Hinge GAN loss.

    With both predictions: ``(total, mean(relu(1 - real)),
    mean(relu(1 + fake)))``. With one prediction: its negated mean.
    """

    def forward(self, pred_real, pred_fake=None):
        if pred_fake is None:
            return -pred_real.mean()
        real_term = F.relu(1 - pred_real).mean()
        fake_term = F.relu(1 + pred_fake).mean()
        return real_term + fake_term, real_term, fake_term


class Wasserstein(nn.Module):
    """Wasserstein GAN loss.

    With both predictions: ``(mean(fake) - mean(real), mean(real),
    mean(fake))``. With one prediction: its negated mean.
    """

    def forward(self, pred_real, pred_fake=None):
        if pred_fake is None:
            return -pred_real.mean()
        real_mean = pred_real.mean()
        fake_mean = pred_fake.mean()
        return -real_mean + fake_mean, real_mean, fake_mean


class BCE(nn.Module):
    """Binary cross-entropy GAN loss on predictions mapped by ``(p + 1) / 2``.

    The affine map is applied before ``nn.BCELoss``; the return
    convention is the same as ``BCEWithLogits``.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, pred_real, pred_fake=None):
        ones_term = self.bce((pred_real + 1) / 2, torch.ones_like(pred_real))
        if pred_fake is None:
            return ones_term
        zeros_term = self.bce(
            (pred_fake + 1) / 2, torch.zeros_like(pred_fake))
        return ones_term + zeros_term, ones_term, zeros_term
import math

import torch
from torch.optim import Optimizer
from torch import Tensor
from typing import List


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Apply one Adam update in place to every tensor in ``params``.

    ``exp_avgs`` is only read when ``beta1 != 0`` and ``exp_avg_sqs`` only
    when ``beta2 != 0``; with a zero beta the raw gradient (or its square)
    is used directly and no running average is touched.
    See :class:`~torch.optim.Adam` for the algorithm.
    """

    for i, param in enumerate(params):
        step = state_steps[i]
        correction1 = 1 - beta1 ** step
        correction2 = 1 - beta2 ** step

        grad = grads[i]
        if weight_decay != 0:
            # L2 penalty folded into the gradient
            grad = grad.add(param, alpha=weight_decay)

        # First moment: running average, or the gradient itself
        if beta1 != 0:
            first_moment = exp_avgs[i].mul_(beta1).add_(grad, alpha=1 - beta1)
        else:
            first_moment = grad

        # Second moment: running average, or the squared gradient itself
        if beta2 != 0:
            second_moment = exp_avg_sqs[i].mul_(beta2).addcmul_(
                grad, grad, value=1 - beta2)
        else:
            second_moment = grad * grad

        if amsgrad:
            # Track the element-wise maximum of the second moment so far
            # and normalize by that instead of the current value.
            running_max = max_exp_avg_sqs[i]
            torch.maximum(running_max, second_moment, out=running_max)
            second_moment = running_max

        denom = (second_moment.sqrt() / math.sqrt(correction2)).add_(eps)
        param.addcdiv_(first_moment, denom, value=-(lr / correction1))


class Adam(Optimizer):
    r"""Adam optimizer that skips moment buffers whose beta is zero.

    When ``betas[0] == 0`` no ``exp_avg`` state is allocated, and when
    ``betas[1] == 0`` no ``exp_avg_sq`` state is allocated; the update
    then uses the gradient (or its square) directly. ``weight_decay`` is
    applied as an L2 penalty added to the gradient.

    Args:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients of the running
            averages of the gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator for numerical
            stability (default: 1e-8)
        weight_decay (float, optional): L2 penalty coefficient (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant
            from `On the Convergence of Adam and Beyond`_ (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        super().__init__(params, dict(
            lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay, amsgrad=amsgrad))

    def __setstate__(self, state):
        super().__setstate__(state)
        # Restored groups may lack the key; fill in the default.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the
                model and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            use_amsgrad = group['amsgrad']
            track_first = beta1 != 0
            track_second = beta2 != 0

            params_with_grad = []
            grads = []
            # A None list signals to adam() that the buffer is unused.
            exp_avgs = [] if track_first else None
            exp_avg_sqs = [] if track_second else None
            max_exp_avg_sqs = []
            state_steps = []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if not state:
                    # Lazy state initialization
                    state['step'] = 0
                    if track_first:
                        # Running average of gradient values
                        state['exp_avg'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                    if track_second:
                        # Running average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                    if use_amsgrad:
                        # Running maximum of the squared-gradient average
                        state['max_exp_avg_sq'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)

                if track_first:
                    exp_avgs.append(state['exp_avg'])
                if track_second:
                    exp_avg_sqs.append(state['exp_avg_sq'])
                if use_amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # Count this update, then record the incremented value.
                state['step'] += 1
                state_steps.append(state['step'])

            adam(params_with_grad,
                 grads,
                 exp_avgs,
                 exp_avg_sqs,
                 max_exp_avg_sqs,
                 state_steps,
                 use_amsgrad,
                 beta1,
                 beta2,
                 group['lr'],
                 group['weight_decay'],
                 group['eps'])
        return loss
a/sn_testing/stats/cifar10.train.npz b/sn_testing/stats/cifar10.train.npz new file mode 100644 index 0000000..0d9b0cd Binary files /dev/null and b/sn_testing/stats/cifar10.train.npz differ diff --git a/sn_testing/stats/horse.train.256.npz b/sn_testing/stats/horse.train.256.npz new file mode 100644 index 0000000..a5ea713 Binary files /dev/null and b/sn_testing/stats/horse.train.256.npz differ diff --git a/sn_testing/stats/stl10.unlabeled.48.npz b/sn_testing/stats/stl10.unlabeled.48.npz new file mode 100644 index 0000000..a7965bf Binary files /dev/null and b/sn_testing/stats/stl10.unlabeled.48.npz differ diff --git a/sn_testing/train.py b/sn_testing/train.py new file mode 100644 index 0000000..7990f11 --- /dev/null +++ b/sn_testing/train.py @@ -0,0 +1,380 @@ +import os +import json +import math + +import torch +import torch.optim as optim +from absl import flags, app +from torchvision import transforms +from torchvision.utils import make_grid, save_image +from tensorboardX import SummaryWriter +from tqdm import trange +from pytorch_image_generation_metrics import get_inception_score_and_fid + +from datasets import get_dataset +from losses import HingeLoss, BCEWithLogits, Wasserstein +from SN_models import resnet, dcgan, biggan + +from utils import ema, save_images, infiniteloop, set_seed, module_no_grad + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +def normalize_gradient(net_D, x, **kwargs): #this should disable normalize grad from doing anything + x.requires_grad_(True) + f = net_D(x, **kwargs) + return f + + + +net_G_models = { + 'dcgan.32': dcgan.Generator32, + 'dcgan.48': dcgan.Generator48, + 'resnet.32': resnet.ResGenerator32, + 'resnet.48': resnet.ResGenerator48, + 'biggan.32': biggan.Generator32, +} + +net_D_models = { + 'dcgan.32': dcgan.Discriminator32, + 'dcgan.48': dcgan.Discriminator48, + 'resnet.32': resnet.ResDiscriminator32, + 'resnet.48': resnet.ResDiscriminator48, + 'biggan.32': biggan.Discriminator32, +} + 
+loss_fns = { + 'hinge': HingeLoss, + 'bce': BCEWithLogits, + 'wass': Wasserstein, +} + + +datasets = ['cifar10.32', 'stl10.48'] + + +FLAGS = flags.FLAGS +# run mode: resume training, or load a checkpoint to evaluate / export samples +flags.DEFINE_bool('resume', False, 'resume from checkpoint') +flags.DEFINE_bool('eval', False, 'load model and evaluate it') +flags.DEFINE_string('save', "", 'load model and save sample images to dir') +# model, optimisation and training schedule +flags.DEFINE_enum('dataset', 'cifar10.32', datasets, "select dataset") +flags.DEFINE_enum('arch', 'resnet.32', net_G_models.keys(), "architecture") +flags.DEFINE_enum('loss', 'hinge', loss_fns.keys(), "loss function") +flags.DEFINE_integer('total_steps', 200000, "total number of training steps") +flags.DEFINE_integer('lr_decay_start', 0, 'apply linearly decay to lr') +flags.DEFINE_integer('batch_size_D', 64, "batch size for discriminator") +flags.DEFINE_integer('batch_size_G', 128, "batch size for generator") +flags.DEFINE_integer('num_workers', 10, "dataloader workers") +flags.DEFINE_float('lr_D', 4e-4, "Discriminator learning rate") +flags.DEFINE_float('lr_G', 2e-4, "Generator learning rate") +flags.DEFINE_multi_float('betas', [0.0, 0.9], "for Adam") +flags.DEFINE_integer('n_dis', 5, "update Generator every this steps") +flags.DEFINE_integer('z_dim', 128, "latent space dimension") +flags.DEFINE_float('cr', 0, "weight for consistency regularization") +flags.DEFINE_integer('seed', 0, "random seed") +# conditional generation (class labels are drawn from range(n_classes)) +flags.DEFINE_integer('n_classes', 1, 'the number of classes in dataset') +# exponential moving average of the generator weights +flags.DEFINE_float('ema_decay', 0.9999, "ema decay rate") +flags.DEFINE_integer('ema_start', 0, "start step for ema") +# sampling, evaluation, checkpointing and logging +flags.DEFINE_integer('sample_step', 500, "sample image every this steps") +flags.DEFINE_integer('sample_size', 64, "sampling size of images") +flags.DEFINE_integer('eval_step', 5000, "evaluate FID and Inception Score") +flags.DEFINE_integer('save_step', 20000, "save model every this step") +flags.DEFINE_integer('num_images', 50000, '# images for evaluation') 
+flags.DEFINE_string('fid_stats', './stats/cifar10.train.npz', 'FID cache') +flags.DEFINE_string('logdir', './logs/GN-GAN_CIFAR10_RES_0', 'log folder') + + +device = torch.device('cuda') + + +def generate_images(net_G): + images = [] + with torch.no_grad(): + for _ in trange(0, FLAGS.num_images, FLAGS.batch_size_G, + ncols=0, leave=False): + z = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = (net_G(z, y) + 1) / 2 + images.append(fake) + images = torch.cat(images, dim=0) + return images[:FLAGS.num_images] + + +def eval_save(): + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ckpt = torch.load(os.path.join(FLAGS.logdir, 'best_model.pt')) + net_G.load_state_dict(ckpt['net_G']) + + images = generate_images(net_G=net_G) + if FLAGS.eval: + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + print("IS: %6.3f(%.3f), FID: %7.3f" % (IS, IS_std, FID)) + if FLAGS.save: + save_images(images, FLAGS.save, verbose=True) + + +def evaluate(net_G): + images = generate_images(net_G=net_G) + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + del images + return (IS, IS_std), FID + + +def consistency_loss(net_D, real, y_real, pred_real, + transform=transforms.Compose([ + transforms.Lambda(lambda x: (x + 1) / 2), + transforms.ToPILImage(mode='RGB'), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0.2, 0.2)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])): + aug_real = real.detach().clone() + for idx, img in enumerate(aug_real): + aug_real[idx] = transform(img) + aug_real = aug_real.to(device) + pred_aug = normalize_gradient(net_D, aug_real, y=y_real) + loss = ((pred_aug - pred_real) ** 2).mean() + return loss + + +def train(): + dataset = get_dataset(FLAGS.dataset) + dataloader = torch.utils.data.DataLoader( + 
dataset=dataset, + batch_size=FLAGS.batch_size_D * FLAGS.n_dis, + shuffle=True, + num_workers=FLAGS.num_workers, + drop_last=True) + looper = infiniteloop(dataloader) + + # model + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ema_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + net_D = net_D_models[FLAGS.arch](FLAGS.n_classes).to(device) + + # ema + ema(net_G, ema_G, decay=0) + + # loss + loss_fn = loss_fns[FLAGS.loss]() + + # optimizer + optim_G = optim.Adam(net_G.parameters(), lr=FLAGS.lr_G, betas=FLAGS.betas) + optim_D = optim.Adam(net_D.parameters(), lr=FLAGS.lr_D, betas=FLAGS.betas) + + # scheduler + def decay_rate(step): + period = max(FLAGS.total_steps - FLAGS.lr_decay_start, 1) + return 1 - max(step - FLAGS.lr_decay_start, 0) / period + sched_G = optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=decay_rate) + sched_D = optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=decay_rate) + + D_size = 0 + for param in net_D.parameters(): + D_size += param.data.nelement() + G_size = 0 + for param in net_G.parameters(): + G_size += param.data.nelement() + #print('D params: %d, G params: %d' % (D_size, G_size)) + + writer = SummaryWriter(FLAGS.logdir) + if FLAGS.resume: + ckpt = torch.load(os.path.join(FLAGS.logdir, 'model.pt')) + net_G.load_state_dict(ckpt['net_G']) + net_D.load_state_dict(ckpt['net_D']) + ema_G.load_state_dict(ckpt['ema_G']) + optim_G.load_state_dict(ckpt['optim_G']) + optim_D.load_state_dict(ckpt['optim_D']) + sched_G.load_state_dict(ckpt['sched_G']) + sched_D.load_state_dict(ckpt['sched_D']) + fixed_z = ckpt['fixed_z'] + fixed_y = ckpt['fixed_y'] + # start value + start = ckpt['step'] + 1 + best_IS, best_FID = ckpt['best_IS'], ckpt['best_FID'] + del ckpt + else: + # sample fixed z + fixed_z = torch.randn(FLAGS.sample_size, FLAGS.z_dim).to(device) + fixed_y = torch.randint( + FLAGS.n_classes, (FLAGS.sample_size,)).to(device) + # start value + start, best_IS, best_FID = 1, 0, 999 + + 
os.makedirs(os.path.join(FLAGS.logdir, 'sample')) + with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f: + f.write(FLAGS.flags_into_string()) + real = next(iter(dataloader))[0][:FLAGS.sample_size] + writer.add_image('real_sample', make_grid((real + 1) / 2)) + writer.flush() + + with trange(start, FLAGS.total_steps + 1, ncols=0, + initial=start - 1, total=FLAGS.total_steps) as pbar: + for step in pbar: + loss_sum = 0 + loss_real_sum = 0 + loss_fake_sum = 0 + loss_cr_sum = 0 + + x, y = next(looper) + x = iter(torch.split(x, FLAGS.batch_size_D)) + y = iter(torch.split(y, FLAGS.batch_size_D)) + # Discriminator + for _ in range(FLAGS.n_dis): + optim_D.zero_grad() + x_real, y_real = next(x).to(device), next(y).to(device) + + with torch.no_grad(): + z_ = torch.randn( + FLAGS.batch_size_D, FLAGS.z_dim).to(device) + y_fake = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_D,)).to(device) + x_fake = net_G(z_, y_fake).detach() + x_real_fake = torch.cat([x_real, x_fake], dim=0) + y_real_fake = torch.cat([y_real, y_fake], dim=0) + pred = normalize_gradient(net_D, x_real_fake, y=y_real_fake) + pred_real, pred_fake = torch.split( + pred, [x_real.shape[0], x_fake.shape[0]]) + + loss, loss_real, loss_fake = loss_fn(pred_real, pred_fake) + if FLAGS.cr > 0: + loss_cr = consistency_loss( + net_D, x_real, y_real, pred_real).to(device) + else: + loss_cr = torch.tensor(0.).to(device) + loss_all = loss + FLAGS.cr * loss_cr + loss_all.backward() + optim_D.step() + + loss_sum += loss.item() + loss_real_sum += loss_real.item() + loss_fake_sum += loss_fake.item() + loss_cr_sum += loss_cr.item() + + loss = loss_sum / FLAGS.n_dis + loss_real = loss_real_sum / FLAGS.n_dis + loss_fake = loss_fake_sum / FLAGS.n_dis + loss_cr = loss_cr_sum / FLAGS.n_dis + + writer.add_scalar('loss', loss, step) + writer.add_scalar('loss_real', loss_real, step) + writer.add_scalar('loss_fake', loss_fake, step) + writer.add_scalar('loss_cr', loss_cr, step) + + print('loss:', loss, 'loss_real:', 
loss_real, 'loss_fake:', loss_fake, 'loss_cr', loss_cr, 'step:', step) + + pbar.set_postfix( + loss_real='%.3f' % loss_real, + loss_fake='%.3f' % loss_fake) + + # Generator + with module_no_grad(net_D): + optim_G.zero_grad() + z_ = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y_ = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = net_G(z_, y_) + pred_fake = normalize_gradient(net_D, fake, y=y_) + loss = loss_fn(pred_fake) + loss.backward() + optim_G.step() + + # ema + if step < FLAGS.ema_start: + decay = 0 + else: + decay = FLAGS.ema_decay + ema(net_G, ema_G, decay) + + # scheduler + sched_G.step() + sched_D.step() + + # sample from fixed z + if step == 1 or step % FLAGS.sample_step == 0: + with torch.no_grad(): + fake_net = net_G(fixed_z, fixed_y) + fake_ema = ema_G(fixed_z, fixed_y) + grid_net = (make_grid(fake_net) + 1) / 2 + grid_ema = (make_grid(fake_ema) + 1) / 2 + writer.add_image('sample_ema', grid_ema, step) + writer.add_image('sample', grid_net, step) + save_image( + grid_ema, + os.path.join(FLAGS.logdir, 'sample', '%d.png' % step)) + + # evaluate IS, FID and save model + if step == 1 or step % FLAGS.eval_step == 0: + (IS, IS_std), FID = evaluate(net_G) + (IS_ema, IS_std_ema), FID_ema = evaluate(ema_G) + if not math.isnan(FID) and not math.isnan(best_FID): + save_as_best = (FID < best_FID) + else: + save_as_best = (IS > best_IS) + if save_as_best: + best_IS = IS + best_FID = FID + ckpt = { + 'net_G': net_G.state_dict(), + 'net_D': net_D.state_dict(), + 'ema_G': ema_G.state_dict(), + 'optim_G': optim_G.state_dict(), + 'optim_D': optim_D.state_dict(), + 'sched_G': sched_G.state_dict(), + 'sched_D': sched_D.state_dict(), + 'fixed_y': fixed_y, + 'fixed_z': fixed_z, + 'best_IS': best_IS, + 'best_FID': best_FID, + 'step': step, + } + if step == 1 or step % FLAGS.save_step == 0: + torch.save( + ckpt, os.path.join(FLAGS.logdir, '%06d.pt' % step)) + if save_as_best: + torch.save( + ckpt, os.path.join(FLAGS.logdir, 
'best_model.pt')) + torch.save(ckpt, os.path.join(FLAGS.logdir, 'model.pt')) + metrics = { + 'IS': IS, + 'IS_std': IS_std, + 'FID': FID, + 'IS_EMA': IS_ema, + 'IS_std_EMA': IS_std_ema, + 'FID_EMA': FID_ema, + } + for name, value in metrics.items(): + writer.add_scalar(name, value, step) + writer.flush() + with open(os.path.join(FLAGS.logdir, 'eval.txt'), 'a') as f: + metrics['step'] = step + f.write(json.dumps(metrics) + "\n") + k = len(str(FLAGS.total_steps)) + pbar.write( + f"{step:{k}d}/{FLAGS.total_steps} " + f"IS: {IS:6.3f}({IS_std:.3f}), " + f"FID: {FID:.3f}, " + f"IS_EMA: {IS_ema:6.3f}({IS_std_ema:.3f}), " + f"FID_EMA: {FID_ema:.3f}") + writer.close() + + +def main(argv): + set_seed(FLAGS.seed) + if FLAGS.eval or FLAGS.save: + eval_save() + else: + train() + + +if __name__ == '__main__': + app.run(main) diff --git a/sn_testing/train_model_sn.py b/sn_testing/train_model_sn.py new file mode 100644 index 0000000..d966840 --- /dev/null +++ b/sn_testing/train_model_sn.py @@ -0,0 +1,367 @@ +import os +import json +import math + +import torch +import torch.optim as optim +from absl import flags, app +from torchvision import transforms +from torchvision.utils import make_grid, save_image +from tensorboardX import SummaryWriter +from tqdm import trange +from pytorch_image_generation_metrics import get_inception_score_and_fid + +from datasets import get_dataset +from losses import HingeLoss, BCEWithLogits, Wasserstein + +from SN_models import dcgan_module +from SN_models import resnet_module + +from SN_models import gradnorm + +from utils import ema, save_images, infiniteloop, set_seed, module_no_grad + + +net_G_models = { + 'dcgan.32': dcgan_module.Generator32, + 'dcgan.48': dcgan_module.Generator48, +} + +net_D_models = { + 'dcgan.32': dcgan_module.Discriminator32, + 'dcgan.48': dcgan_module.Discriminator48, +} + +loss_fns = { + 'hinge': HingeLoss, + 'bce': BCEWithLogits, + 'wass': Wasserstein, +} + + +datasets = ['cifar10.32', 'stl10.48'] + + +FLAGS = flags.FLAGS 
+# resume +flags.DEFINE_bool('resume', False, 'resume from checkpoint') +flags.DEFINE_bool('eval', False, 'load model and evaluate it') +flags.DEFINE_string('save', "", 'load model and save sample images to dir') +# model and training ('arch' default must be a key of net_G_models: absl validates enum defaults when the flag is defined) +flags.DEFINE_enum('dataset', 'cifar10.32', datasets, "select dataset") +flags.DEFINE_enum('arch', 'dcgan.32', net_G_models.keys(), "architecture") +flags.DEFINE_enum('loss', 'hinge', loss_fns.keys(), "loss function") +flags.DEFINE_integer('total_steps', 200000, "total number of training steps") +flags.DEFINE_integer('lr_decay_start', 0, 'apply linearly decay to lr') +flags.DEFINE_integer('batch_size_D', 64, "batch size for discriminator") +flags.DEFINE_integer('batch_size_G', 128, "batch size for generator") +flags.DEFINE_integer('num_workers', 10, "dataloader workers") +flags.DEFINE_float('lr_D', 4e-4, "Discriminator learning rate") +flags.DEFINE_float('lr_G', 2e-4, "Generator learning rate") +flags.DEFINE_multi_float('betas', [0.0, 0.9], "for Adam") +flags.DEFINE_integer('n_dis', 5, "update Generator every this steps") +flags.DEFINE_integer('z_dim', 128, "latent space dimension") +flags.DEFINE_float('cr', 0, "weight for consistency regularization") +flags.DEFINE_integer('seed', 0, "random seed") +# conditional +flags.DEFINE_integer('n_classes', 1, 'the number of classes in dataset') +# ema +flags.DEFINE_float('ema_decay', 0.9999, "ema decay rate") +flags.DEFINE_integer('ema_start', 0, "start step for ema") +# logging +flags.DEFINE_integer('sample_step', 500, "sample image every this steps") +flags.DEFINE_integer('sample_size', 64, "sampling size of images") +flags.DEFINE_integer('eval_step', 5000, "evaluate FID and Inception Score") +flags.DEFINE_integer('save_step', 20000, "save model every this step") +flags.DEFINE_integer('num_images', 50000, '# images for evaluation') +flags.DEFINE_string('fid_stats', './stats/cifar10.train.npz', 'FID cache') +flags.DEFINE_string('logdir', './logs/GN-GAN_CIFAR10_RES_0', 'log folder') + + 
+device = torch.device('cuda') + + +def generate_images(net_G): + images = [] + with torch.no_grad(): + for _ in trange(0, FLAGS.num_images, FLAGS.batch_size_G, + ncols=0, leave=False): + z = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = (net_G(z, y) + 1) / 2 + images.append(fake.cuda()) + images = torch.cat(images, dim=0) + return images[:FLAGS.num_images] + + +def eval_save(): + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ckpt = torch.load(os.path.join(FLAGS.logdir, 'best_model.pt')) + net_G.load_state_dict(ckpt['net_G']) + + images = generate_images(net_G=net_G) + if FLAGS.eval: + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + print("IS: %6.3f(%.3f), FID: %7.3f" % (IS, IS_std, FID)) + if FLAGS.save: + save_images(images, FLAGS.save, verbose=True) + + +def evaluate(net_G): + images = generate_images(net_G=net_G) + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + del images + return (IS, IS_std), FID + + +def consistency_loss(net_D, real, y_real, pred_real, + transform=transforms.Compose([ + transforms.Lambda(lambda x: (x + 1) / 2), + transforms.ToPILImage(mode='RGB'), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0.2, 0.2)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])): + aug_real = real.detach().clone().cuda() + for idx, img in enumerate(aug_real): + aug_real[idx] = transform(img) + aug_real = aug_real.to(device) + pred_aug = gradnorm.get_gradient(net_D, aug_real, y=y_real) + loss = ((pred_aug - pred_real) ** 2).mean() + return loss + + +def train(): + dataset = get_dataset(FLAGS.dataset) + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=FLAGS.batch_size_D * FLAGS.n_dis, + shuffle=True, + num_workers=FLAGS.num_workers, + drop_last=True) + looper = 
infiniteloop(dataloader) + + # model + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ema_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + net_D = net_D_models[FLAGS.arch](FLAGS.n_classes).to(device) + + # ema + ema(net_G, ema_G, decay=0) + + # loss + loss_fn = loss_fns[FLAGS.loss]() + + # optimizer + optim_G = optim.Adam(net_G.parameters(), lr=FLAGS.lr_G, betas=FLAGS.betas) + optim_D = optim.Adam(net_D.parameters(), lr=FLAGS.lr_D, betas=FLAGS.betas) + + # scheduler + def decay_rate(step): + period = max(FLAGS.total_steps - FLAGS.lr_decay_start, 1) + return 1 - max(step - FLAGS.lr_decay_start, 0) / period + sched_G = optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=decay_rate) + sched_D = optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=decay_rate) + + D_size = 0 + for param in net_D.parameters(): + D_size += param.data.nelement() + G_size = 0 + for param in net_G.parameters(): + G_size += param.data.nelement() + #print('D params: %d, G params: %d' % (D_size, G_size)) + + writer = SummaryWriter(FLAGS.logdir) + if FLAGS.resume: + ckpt = torch.load(os.path.join(FLAGS.logdir, 'model.pt')) + net_G.load_state_dict(ckpt['net_G']) + net_D.load_state_dict(ckpt['net_D']) + ema_G.load_state_dict(ckpt['ema_G']) + optim_G.load_state_dict(ckpt['optim_G']) + optim_D.load_state_dict(ckpt['optim_D']) + sched_G.load_state_dict(ckpt['sched_G']) + sched_D.load_state_dict(ckpt['sched_D']) + fixed_z = ckpt['fixed_z'] + fixed_y = ckpt['fixed_y'] + # start value + start = ckpt['step'] + 1 + best_IS, best_FID = ckpt['best_IS'], ckpt['best_FID'] + del ckpt + else: + # sample fixed z + fixed_z = torch.randn(FLAGS.sample_size, FLAGS.z_dim).to(device) + fixed_y = torch.randint( + FLAGS.n_classes, (FLAGS.sample_size,)).to(device) + # start value + start, best_IS, best_FID = 1, 0, 999 + + os.makedirs(os.path.join(FLAGS.logdir, 'sample')) + with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f: + f.write(FLAGS.flags_into_string()) + real 
= next(iter(dataloader))[0][:FLAGS.sample_size] + writer.add_image('real_sample', make_grid((real + 1) / 2)) + writer.flush() + + with trange(start, FLAGS.total_steps + 1, ncols=0, + initial=start - 1, total=FLAGS.total_steps) as pbar: + for step in pbar: + loss_sum = 0 + loss_real_sum = 0 + loss_fake_sum = 0 + loss_cr_sum = 0 + + x, y = next(looper) + x = iter(torch.split(x, FLAGS.batch_size_D)) + y = iter(torch.split(y, FLAGS.batch_size_D)) + # Discriminator + for _ in range(FLAGS.n_dis): + optim_D.zero_grad() + x_real, y_real = next(x).to(device), next(y).to(device) + + with torch.no_grad(): + z_ = torch.randn( + FLAGS.batch_size_D, FLAGS.z_dim).to(device) + y_fake = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_D,)).to(device) + x_fake = net_G(z_, y_fake).detach() + x_real_fake = torch.cat([x_real, x_fake], dim=0) + y_real_fake = torch.cat([y_real, y_fake], dim=0) + pred = gradnorm.get_gradient(net_D, x_real_fake, y=y_real_fake) + pred_real, pred_fake = torch.split( + pred, [x_real.shape[0], x_fake.shape[0]]) + + loss, loss_real, loss_fake = loss_fn(pred_real, pred_fake) + if FLAGS.cr > 0: + loss_cr = consistency_loss( + net_D, x_real, y_real, pred_real) + else: + loss_cr = torch.tensor(0.) 
+ loss_all = loss + FLAGS.cr * loss_cr + loss_all.backward() + optim_D.step() + + loss_sum += loss.cuda().item() + loss_real_sum += loss_real.cuda().item() + loss_fake_sum += loss_fake.cuda().item() + loss_cr_sum += loss_cr.cuda().item() + + loss = loss_sum / FLAGS.n_dis + loss_real = loss_real_sum / FLAGS.n_dis + loss_fake = loss_fake_sum / FLAGS.n_dis + loss_cr = loss_cr_sum / FLAGS.n_dis + + writer.add_scalar('loss', loss, step) + writer.add_scalar('loss_real', loss_real, step) + writer.add_scalar('loss_fake', loss_fake, step) + writer.add_scalar('loss_cr', loss_cr, step) + + pbar.set_postfix( + loss_real='%.3f' % loss_real, + loss_fake='%.3f' % loss_fake) + + # Generator + with module_no_grad(net_D): + optim_G.zero_grad() + z_ = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y_ = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = net_G(z_, y_) + pred_fake = gradnorm.get_gradient(net_D, fake, y=y_) + loss = loss_fn(pred_fake) + loss.backward() + optim_G.step() + + # ema + if step < FLAGS.ema_start: + decay = 0 + else: + decay = FLAGS.ema_decay + ema(net_G, ema_G, decay) + + # scheduler + sched_G.step() + sched_D.step() + + # sample from fixed z + if step == 1 or step % FLAGS.sample_step == 0: + with torch.no_grad(): + fake_net = net_G(fixed_z, fixed_y).cuda() + fake_ema = ema_G(fixed_z, fixed_y).cuda() + grid_net = (make_grid(fake_net) + 1) / 2 + grid_ema = (make_grid(fake_ema) + 1) / 2 + writer.add_image('sample_ema', grid_ema, step) + writer.add_image('sample', grid_net, step) + save_image( + grid_ema, + os.path.join(FLAGS.logdir, 'sample', '%d.png' % step)) + + # evaluate IS, FID and save model + if step == 1 or step % FLAGS.eval_step == 0: + (IS, IS_std), FID = evaluate(net_G) + (IS_ema, IS_std_ema), FID_ema = evaluate(ema_G) + if not math.isnan(FID) and not math.isnan(best_FID): + save_as_best = (FID < best_FID) + else: + save_as_best = (IS > best_IS) + if save_as_best: + best_IS = IS + best_FID = FID + ckpt = { + 
'net_G': net_G.state_dict(), + 'net_D': net_D.state_dict(), + 'ema_G': ema_G.state_dict(), + 'optim_G': optim_G.state_dict(), + 'optim_D': optim_D.state_dict(), + 'sched_G': sched_G.state_dict(), + 'sched_D': sched_D.state_dict(), + 'fixed_y': fixed_y, + 'fixed_z': fixed_z, + 'best_IS': best_IS, + 'best_FID': best_FID, + 'step': step, + } + if step == 1 or step % FLAGS.save_step == 0: + torch.save( + ckpt, os.path.join(FLAGS.logdir, '%06d.pt' % step)) + if save_as_best: + torch.save( + ckpt, os.path.join(FLAGS.logdir, 'best_model.pt')) + torch.save(ckpt, os.path.join(FLAGS.logdir, 'model.pt')) + metrics = { + 'IS': IS, + 'IS_std': IS_std, + 'FID': FID, + 'IS_EMA': IS_ema, + 'IS_std_EMA': IS_std_ema, + 'FID_EMA': FID_ema, + } + for name, value in metrics.items(): + writer.add_scalar(name, value, step) + writer.flush() + with open(os.path.join(FLAGS.logdir, 'eval.txt'), 'a') as f: + metrics['step'] = step + f.write(json.dumps(metrics) + "\n") + k = len(str(FLAGS.total_steps)) + pbar.write( + f"{step:{k}d}/{FLAGS.total_steps} " + f"IS: {IS:6.3f}({IS_std:.3f}), " + f"FID: {FID:.3f}, " + f"IS_EMA: {IS_ema:6.3f}({IS_std_ema:.3f}), " + f"FID_EMA: {FID_ema:.3f}") + writer.close() + + +def main(argv): + set_seed(FLAGS.seed) + if FLAGS.eval or FLAGS.save: + eval_save() + else: + train() + + +if __name__ == '__main__': + app.run(main) diff --git a/sn_testing/utils.py b/sn_testing/utils.py new file mode 100644 index 0000000..74512e4 --- /dev/null +++ b/sn_testing/utils.py @@ -0,0 +1,53 @@ +import os +import random +from contextlib import contextmanager + +import torch +import numpy as np +from torchvision.utils import save_image +from tqdm import tqdm + + +device = torch.device('cuda') + + +def save_images(images, output_dir, verbose=False): + os.makedirs(output_dir, exist_ok=True) + for i, image in enumerate(tqdm(images, dynamic_ncols=True, leave=False, + disable=(not verbose), desc="save_images")): + save_image(image, os.path.join(output_dir, '%d.png' % i)) + + +def 
infiniteloop(dataloader): + while True: + for x, y in iter(dataloader): + yield x, y + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + #torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + +def ema(source, target, decay): + source_dict = source.state_dict() + target_dict = target.state_dict() + for key in source_dict.keys(): + target_dict[key].data.copy_( + target_dict[key].data * decay + + source_dict[key].data * (1 - decay)) + + +@contextmanager +def module_no_grad(m: torch.nn.Module): + requires_grad_dict = dict() + for name, param in m.named_parameters(): + requires_grad_dict[name] = param.requires_grad + param.requires_grad_(False) + yield m + for name, param in m.named_parameters(): + param.requires_grad_(requires_grad_dict[name]) diff --git a/stats/bedroom.train.256.npz b/stats/bedroom.train.256.npz new file mode 100644 index 0000000..69bec6a Binary files /dev/null and b/stats/bedroom.train.256.npz differ diff --git a/stats/celebahq.3k.128.npz b/stats/celebahq.3k.128.npz new file mode 100644 index 0000000..d366b7d Binary files /dev/null and b/stats/celebahq.3k.128.npz differ diff --git a/stats/celebahq.all.1024.npz b/stats/celebahq.all.1024.npz new file mode 100644 index 0000000..1508119 Binary files /dev/null and b/stats/celebahq.all.1024.npz differ diff --git a/stats/celebahq.all.256.npz b/stats/celebahq.all.256.npz new file mode 100644 index 0000000..9a9392d Binary files /dev/null and b/stats/celebahq.all.256.npz differ diff --git a/stats/church.train.256.npz b/stats/church.train.256.npz new file mode 100644 index 0000000..4d5d49d Binary files /dev/null and b/stats/church.train.256.npz differ diff --git a/stats/cifar10.test.npz b/stats/cifar10.test.npz new file mode 100644 index 0000000..fbe8b66 Binary files /dev/null and b/stats/cifar10.test.npz differ diff --git a/stats/cifar10.train.npz b/stats/cifar10.train.npz new file mode 100644 index 
0000000..0d9b0cd Binary files /dev/null and b/stats/cifar10.train.npz differ diff --git a/stats/horse.train.256.npz b/stats/horse.train.256.npz new file mode 100644 index 0000000..a5ea713 Binary files /dev/null and b/stats/horse.train.256.npz differ diff --git a/stats/stl10.unlabeled.48.npz b/stats/stl10.unlabeled.48.npz new file mode 100644 index 0000000..a7965bf Binary files /dev/null and b/stats/stl10.unlabeled.48.npz differ diff --git a/test.py b/test.py new file mode 100644 index 0000000..56aae18 --- /dev/null +++ b/test.py @@ -0,0 +1,4 @@ +import torch +print("torch:", torch.__version__) +print("cuda runtime (torch):", torch.version.cuda) +print("cuDNN version:", torch.backends.cudnn.version()) \ No newline at end of file diff --git a/train.py b/train.py index 403ca4b..4416eba 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ from torchvision.utils import make_grid, save_image from tensorboardX import SummaryWriter from tqdm import trange -from pytorch_gan_metrics import get_inception_score_and_fid +from pytorch_image_generation_metrics import get_inception_score_and_fid from datasets import get_dataset from losses import HingeLoss, BCEWithLogits, Wasserstein @@ -17,6 +17,8 @@ from models.gradnorm import normalize_gradient from utils import ema, save_images, infiniteloop, set_seed, module_no_grad +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True net_G_models = { 'dcgan.32': dcgan.Generator32, @@ -57,7 +59,7 @@ flags.DEFINE_integer('lr_decay_start', 0, 'apply linearly decay to lr') flags.DEFINE_integer('batch_size_D', 64, "batch size for discriminator") flags.DEFINE_integer('batch_size_G', 128, "batch size for generator") -flags.DEFINE_integer('num_workers', 8, "dataloader workers") +flags.DEFINE_integer('num_workers', 10, "dataloader workers") flags.DEFINE_float('lr_D', 4e-4, "Discriminator learning rate") flags.DEFINE_float('lr_G', 2e-4, "Generator learning rate") flags.DEFINE_multi_float('betas', [0.0, 0.9], "for Adam") 
@@ -80,7 +82,7 @@ flags.DEFINE_string('logdir', './logs/GN-GAN_CIFAR10_RES_0', 'log folder') -device = torch.device('cuda:0') +device = torch.device('cuda') def generate_images(net_G): @@ -92,7 +94,7 @@ def generate_images(net_G): y = torch.randint( FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) fake = (net_G(z, y) + 1) / 2 - images.append(fake.cpu()) + images.append(fake) images = torch.cat(images, dim=0) return images[:FLAGS.num_images] @@ -128,7 +130,7 @@ def consistency_loss(net_D, real, y_real, pred_real, transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])): - aug_real = real.detach().clone().cpu() + aug_real = real.detach().clone() for idx, img in enumerate(aug_real): aug_real[idx] = transform(img) aug_real = aug_real.to(device) @@ -239,17 +241,17 @@ def decay_rate(step): loss, loss_real, loss_fake = loss_fn(pred_real, pred_fake) if FLAGS.cr > 0: loss_cr = consistency_loss( - net_D, x_real, y_real, pred_real) + net_D, x_real, y_real, pred_real).to(device) else: - loss_cr = torch.tensor(0.) 
+ loss_cr = torch.tensor(0.).to(device) loss_all = loss + FLAGS.cr * loss_cr loss_all.backward() optim_D.step() - loss_sum += loss.cpu().item() - loss_real_sum += loss_real.cpu().item() - loss_fake_sum += loss_fake.cpu().item() - loss_cr_sum += loss_cr.cpu().item() + loss_sum += loss.item() + loss_real_sum += loss_real.item() + loss_fake_sum += loss_fake.item() + loss_cr_sum += loss_cr.item() loss = loss_sum / FLAGS.n_dis loss_real = loss_real_sum / FLAGS.n_dis @@ -291,8 +293,8 @@ def decay_rate(step): # sample from fixed z if step == 1 or step % FLAGS.sample_step == 0: with torch.no_grad(): - fake_net = net_G(fixed_z, fixed_y).cpu() - fake_ema = ema_G(fixed_z, fixed_y).cpu() + fake_net = net_G(fixed_z, fixed_y) + fake_ema = ema_G(fixed_z, fixed_y) grid_net = (make_grid(fake_net) + 1) / 2 grid_ema = (make_grid(fake_ema) + 1) / 2 writer.add_image('sample_ema', grid_ema, step) diff --git a/train_ddp.py b/train_ddp.py index 1951953..615a60e 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -19,7 +19,7 @@ from models.gradnorm import normalize_gradient from utils import ema, module_no_grad, set_seed from optim import Adam -from pytorch_gan_metrics import ( +from pytorch_image_generation_metrics import ( get_inception_score_and_fid_from_directory, get_inception_score_and_fid) @@ -55,7 +55,7 @@ flags.DEFINE_integer('batch_size_D', 64, "batch size for discriminator") flags.DEFINE_integer('batch_size_G', 128, "batch size for generator") flags.DEFINE_integer('accumulation', 1, 'batch num to accumulate gradient') -flags.DEFINE_integer('num_workers', 8, "dataloader workers") +flags.DEFINE_integer('num_workers', 10, "dataloader workers") flags.DEFINE_float('lr_D', 2e-4, "Discriminator learning rate") flags.DEFINE_float('lr_G', 2e-4, "Generator learning rate") flags.DEFINE_multi_float('betas', [0.0, 0.9], "for Adam") @@ -87,7 +87,7 @@ def image_generator(net_G): fake = (net_G(z) + 1) / 2 fake_list = [torch.empty_like(fake) for _ in range(world_size)] dist.all_gather(fake_list, 
fake) - fake = torch.cat(fake_list, dim=0).cpu() + fake = torch.cat(fake_list, dim=0).cuda() yield fake[:FLAGS.num_images - idx] del fake, fake_list @@ -96,7 +96,7 @@ def eval_save(rank, world_size): device = torch.device('cuda:%d' % rank) ckpt = torch.load( - os.path.join(FLAGS.logdir, 'best_model.pt'), map_location='cpu') + os.path.join(FLAGS.logdir, 'best_model.pt'), map_location=device) net_G = net_G_models[FLAGS.arch](FLAGS.z_dim).to(device) net_G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net_G) net_G = DDP(net_G, device_ids=[rank], output_device=rank) @@ -176,7 +176,7 @@ def infiniteloop(dataloader, sampler, step=0): def train(rank, world_size): - device = torch.device('cuda:%d' % rank) + device = torch.device('cuda:%d' % rank) local_batch_size_D = FLAGS.batch_size_D // world_size local_batch_size_G = FLAGS.batch_size_G // world_size @@ -220,7 +220,7 @@ def train(rank, world_size): if FLAGS.resume: ckpt = torch.load( - os.path.join(FLAGS.logdir, 'model.pt'), map_location='cpu') + os.path.join(FLAGS.logdir, 'model.pt'), map_location=device) net_G.load_state_dict(ckpt['net_G']) net_D.load_state_dict(ckpt['net_D']) ema_G.load_state_dict(ckpt['ema_G']) @@ -337,8 +337,8 @@ def train(rank, world_size): dist.all_gather(fake_ema_list, fake_ema) dist.all_gather(fake_net_list, fake_net) if rank == 0: - fake_ema = torch.cat(fake_ema_list, dim=0).cpu() - fake_net = torch.cat(fake_net_list, dim=0).cpu() + fake_ema = torch.cat(fake_ema_list, dim=0).cuda() + fake_net = torch.cat(fake_net_list, dim=0).cuda() grid_ema = make_grid(fake_ema) grid_net = make_grid(fake_ema) writer.add_image('sample_ema', grid_ema, step) diff --git a/train_module.py b/train_module.py new file mode 100644 index 0000000..90e057e --- /dev/null +++ b/train_module.py @@ -0,0 +1,373 @@ +import os +import json +import math + +import torch +import torch.optim as optim +from absl import flags, app +from torchvision import transforms +from torchvision.utils import make_grid, save_image +from 
tensorboardX import SummaryWriter +from tqdm import trange +from pytorch_image_generation_metrics import get_inception_score_and_fid + +from datasets import get_dataset +from losses import HingeLoss, BCEWithLogits, Wasserstein + +from models import dcgan_module +from models import resnet_module +from models import biggan_module + +from models.gradnorm import get_gradient +from utils import ema, save_images, infiniteloop, set_seed, module_no_grad + + +net_G_models = { + 'dcgan.32': dcgan_module.Generator32, + 'dcgan.48': dcgan_module.Generator48, + 'resnet.32': resnet_module.ResGenerator32, + 'resnet.48': resnet_module.ResGenerator48, + 'biggan.32': biggan_module.Generator32, +} + +net_D_models = { + 'dcgan.32': dcgan_module.Discriminator32, + 'dcgan.48': dcgan_module.Discriminator48, + 'resnet.32': resnet_module.ResDiscriminator32, + 'resnet.48': resnet_module.ResDiscriminator48, + 'biggan.32': biggan_module.Discriminator32, +} + +loss_fns = { + 'hinge': HingeLoss, + 'bce': BCEWithLogits, + 'wass': Wasserstein, +} + + +datasets = ['cifar10.32', 'stl10.48'] + + +FLAGS = flags.FLAGS +# resume +flags.DEFINE_bool('resume', False, 'resume from checkpoint') +flags.DEFINE_bool('eval', False, 'load model and evaluate it') +flags.DEFINE_string('save', "", 'load model and save sample images to dir') +# model and training +flags.DEFINE_enum('dataset', 'cifar10.32', datasets, "select dataset") +flags.DEFINE_enum('arch', 'resnet.32', net_G_models.keys(), "architecture") +flags.DEFINE_enum('loss', 'hinge', loss_fns.keys(), "loss function") +flags.DEFINE_integer('total_steps', 200000, "total number of training steps") +flags.DEFINE_integer('lr_decay_start', 0, 'apply linearly decay to lr') +flags.DEFINE_integer('batch_size_D', 64, "batch size for discriminator") +flags.DEFINE_integer('batch_size_G', 128, "batch size for generator") +flags.DEFINE_integer('num_workers', 10, "dataloader workers") +flags.DEFINE_float('lr_D', 4e-4, "Discriminator learning rate") 
+flags.DEFINE_float('lr_G', 2e-4, "Generator learning rate") +flags.DEFINE_multi_float('betas', [0.0, 0.9], "for Adam") +flags.DEFINE_integer('n_dis', 5, "update Generator every this steps") +flags.DEFINE_integer('z_dim', 128, "latent space dimension") +flags.DEFINE_float('cr', 0, "weight for consistency regularization") +flags.DEFINE_integer('seed', 0, "random seed") +# conditional +flags.DEFINE_integer('n_classes', 1, 'the number of classes in dataset') +# ema +flags.DEFINE_float('ema_decay', 0.9999, "ema decay rate") +flags.DEFINE_integer('ema_start', 0, "start step for ema") +# logging +flags.DEFINE_integer('sample_step', 500, "sample image every this steps") +flags.DEFINE_integer('sample_size', 64, "sampling size of images") +flags.DEFINE_integer('eval_step', 5000, "evaluate FID and Inception Score") +flags.DEFINE_integer('save_step', 20000, "save model every this step") +flags.DEFINE_integer('num_images', 50000, '# images for evaluation') +flags.DEFINE_string('fid_stats', './stats/cifar10.train.npz', 'FID cache') +flags.DEFINE_string('logdir', './logs/GN-GAN_CIFAR10_RES_0', 'log folder') + + +device = torch.device('cuda') + + +def generate_images(net_G): + images = [] + with torch.no_grad(): + for _ in trange(0, FLAGS.num_images, FLAGS.batch_size_G, + ncols=0, leave=False): + z = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = (net_G(z, y) + 1) / 2 + images.append(fake.cuda()) + images = torch.cat(images, dim=0) + return images[:FLAGS.num_images] + + +def eval_save(): + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ckpt = torch.load(os.path.join(FLAGS.logdir, 'best_model.pt')) + net_G.load_state_dict(ckpt['net_G']) + + images = generate_images(net_G=net_G) + if FLAGS.eval: + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + print("IS: %6.3f(%.3f), FID: %7.3f" % (IS, IS_std, FID)) + if FLAGS.save: + 
save_images(images, FLAGS.save, verbose=True) + + +def evaluate(net_G): + images = generate_images(net_G=net_G) + (IS, IS_std), FID = get_inception_score_and_fid( + images, FLAGS.fid_stats, verbose=True) + del images + return (IS, IS_std), FID + + +def consistency_loss(net_D, real, y_real, pred_real, + transform=transforms.Compose([ + transforms.Lambda(lambda x: (x + 1) / 2), + transforms.ToPILImage(mode='RGB'), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0.2, 0.2)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ])): + aug_real = real.detach().clone().cuda() + for idx, img in enumerate(aug_real): + aug_real[idx] = transform(img) + aug_real = aug_real.to(device) + pred_aug = get_gradient(net_D, aug_real, y=y_real) + loss = ((pred_aug - pred_real) ** 2).mean() + return loss + + +def train(): + dataset = get_dataset(FLAGS.dataset) + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=FLAGS.batch_size_D * FLAGS.n_dis, + shuffle=True, + num_workers=FLAGS.num_workers, + drop_last=True) + looper = infiniteloop(dataloader) + + # model + net_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + ema_G = net_G_models[FLAGS.arch](FLAGS.z_dim, FLAGS.n_classes).to(device) + net_D = net_D_models[FLAGS.arch](FLAGS.n_classes).to(device) + + # ema + ema(net_G, ema_G, decay=0) + + # loss + loss_fn = loss_fns[FLAGS.loss]() + + # optimizer + optim_G = optim.Adam(net_G.parameters(), lr=FLAGS.lr_G, betas=FLAGS.betas) + optim_D = optim.Adam(net_D.parameters(), lr=FLAGS.lr_D, betas=FLAGS.betas) + + # scheduler + def decay_rate(step): + period = max(FLAGS.total_steps - FLAGS.lr_decay_start, 1) + return 1 - max(step - FLAGS.lr_decay_start, 0) / period + sched_G = optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=decay_rate) + sched_D = optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=decay_rate) + + D_size = 0 + for param in net_D.parameters(): + D_size += param.data.nelement() + G_size 
= 0 + for param in net_G.parameters(): + G_size += param.data.nelement() + print('D params: %d, G params: %d' % (D_size, G_size)) + + writer = SummaryWriter(FLAGS.logdir) + if FLAGS.resume: + ckpt = torch.load(os.path.join(FLAGS.logdir, 'model.pt')) + net_G.load_state_dict(ckpt['net_G']) + net_D.load_state_dict(ckpt['net_D']) + ema_G.load_state_dict(ckpt['ema_G']) + optim_G.load_state_dict(ckpt['optim_G']) + optim_D.load_state_dict(ckpt['optim_D']) + sched_G.load_state_dict(ckpt['sched_G']) + sched_D.load_state_dict(ckpt['sched_D']) + fixed_z = ckpt['fixed_z'] + fixed_y = ckpt['fixed_y'] + # start value + start = ckpt['step'] + 1 + best_IS, best_FID = ckpt['best_IS'], ckpt['best_FID'] + del ckpt + else: + # sample fixed z + fixed_z = torch.randn(FLAGS.sample_size, FLAGS.z_dim).to(device) + fixed_y = torch.randint( + FLAGS.n_classes, (FLAGS.sample_size,)).to(device) + # start value + start, best_IS, best_FID = 1, 0, 999 + + os.makedirs(os.path.join(FLAGS.logdir, 'sample')) + with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f: + f.write(FLAGS.flags_into_string()) + real = next(iter(dataloader))[0][:FLAGS.sample_size] + writer.add_image('real_sample', make_grid((real + 1) / 2)) + writer.flush() + + with trange(start, FLAGS.total_steps + 1, ncols=0, + initial=start - 1, total=FLAGS.total_steps) as pbar: + for step in pbar: + loss_sum = 0 + loss_real_sum = 0 + loss_fake_sum = 0 + loss_cr_sum = 0 + + x, y = next(looper) + x = iter(torch.split(x, FLAGS.batch_size_D)) + y = iter(torch.split(y, FLAGS.batch_size_D)) + # Discriminator + for _ in range(FLAGS.n_dis): + optim_D.zero_grad() + x_real, y_real = next(x).to(device), next(y).to(device) + + with torch.no_grad(): + z_ = torch.randn( + FLAGS.batch_size_D, FLAGS.z_dim).to(device) + y_fake = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_D,)).to(device) + x_fake = net_G(z_, y_fake).detach() + x_real_fake = torch.cat([x_real, x_fake], dim=0) + y_real_fake = torch.cat([y_real, y_fake], dim=0) + pred = 
get_gradient(net_D, x_real_fake, y=y_real_fake) + pred_real, pred_fake = torch.split( + pred, [x_real.shape[0], x_fake.shape[0]]) + + loss, loss_real, loss_fake = loss_fn(pred_real, pred_fake) + if FLAGS.cr > 0: + loss_cr = consistency_loss( + net_D, x_real, y_real, pred_real) + else: + loss_cr = torch.tensor(0.) + loss_all = loss + FLAGS.cr * loss_cr + loss_all.backward() + optim_D.step() + + loss_sum += loss.cuda().item() + loss_real_sum += loss_real.cuda().item() + loss_fake_sum += loss_fake.cuda().item() + loss_cr_sum += loss_cr.cuda().item() + + loss = loss_sum / FLAGS.n_dis + loss_real = loss_real_sum / FLAGS.n_dis + loss_fake = loss_fake_sum / FLAGS.n_dis + loss_cr = loss_cr_sum / FLAGS.n_dis + + writer.add_scalar('loss', loss, step) + writer.add_scalar('loss_real', loss_real, step) + writer.add_scalar('loss_fake', loss_fake, step) + writer.add_scalar('loss_cr', loss_cr, step) + + pbar.set_postfix( + loss_real='%.3f' % loss_real, + loss_fake='%.3f' % loss_fake) + + # Generator + with module_no_grad(net_D): + optim_G.zero_grad() + z_ = torch.randn(FLAGS.batch_size_G, FLAGS.z_dim).to(device) + y_ = torch.randint( + FLAGS.n_classes, (FLAGS.batch_size_G,)).to(device) + fake = net_G(z_, y_) + pred_fake = get_gradient(net_D, fake, y=y_) + loss = loss_fn(pred_fake) + loss.backward() + optim_G.step() + + # ema + if step < FLAGS.ema_start: + decay = 0 + else: + decay = FLAGS.ema_decay + ema(net_G, ema_G, decay) + + # scheduler + sched_G.step() + sched_D.step() + + # sample from fixed z + if step == 1 or step % FLAGS.sample_step == 0: + with torch.no_grad(): + fake_net = net_G(fixed_z, fixed_y).cuda() + fake_ema = ema_G(fixed_z, fixed_y).cuda() + grid_net = (make_grid(fake_net) + 1) / 2 + grid_ema = (make_grid(fake_ema) + 1) / 2 + writer.add_image('sample_ema', grid_ema, step) + writer.add_image('sample', grid_net, step) + save_image( + grid_ema, + os.path.join(FLAGS.logdir, 'sample', '%d.png' % step)) + + # evaluate IS, FID and save model + if step == 1 or step % 
FLAGS.eval_step == 0: + (IS, IS_std), FID = evaluate(net_G) + (IS_ema, IS_std_ema), FID_ema = evaluate(ema_G) + if not math.isnan(FID) and not math.isnan(best_FID): + save_as_best = (FID < best_FID) + else: + save_as_best = (IS > best_IS) + if save_as_best: + best_IS = IS + best_FID = FID + ckpt = { + 'net_G': net_G.state_dict(), + 'net_D': net_D.state_dict(), + 'ema_G': ema_G.state_dict(), + 'optim_G': optim_G.state_dict(), + 'optim_D': optim_D.state_dict(), + 'sched_G': sched_G.state_dict(), + 'sched_D': sched_D.state_dict(), + 'fixed_y': fixed_y, + 'fixed_z': fixed_z, + 'best_IS': best_IS, + 'best_FID': best_FID, + 'step': step, + } + if step == 1 or step % FLAGS.save_step == 0: + torch.save( + ckpt, os.path.join(FLAGS.logdir, '%06d.pt' % step)) + if save_as_best: + torch.save( + ckpt, os.path.join(FLAGS.logdir, 'best_model.pt')) + torch.save(ckpt, os.path.join(FLAGS.logdir, 'model.pt')) + metrics = { + 'IS': IS, + 'IS_std': IS_std, + 'FID': FID, + 'IS_EMA': IS_ema, + 'IS_std_EMA': IS_std_ema, + 'FID_EMA': FID_ema, + } + for name, value in metrics.items(): + writer.add_scalar(name, value, step) + writer.flush() + with open(os.path.join(FLAGS.logdir, 'eval.txt'), 'a') as f: + metrics['step'] = step + f.write(json.dumps(metrics) + "\n") + k = len(str(FLAGS.total_steps)) + pbar.write( + f"{step:{k}d}/{FLAGS.total_steps} " + f"IS: {IS:6.3f}({IS_std:.3f}), " + f"FID: {FID:.3f}, " + f"IS_EMA: {IS_ema:6.3f}({IS_std_ema:.3f}), " + f"FID_EMA: {FID_ema:.3f}") + writer.close() + + +def main(argv): + set_seed(FLAGS.seed) + if FLAGS.eval or FLAGS.save: + eval_save() + else: + train() + + +if __name__ == '__main__': + app.run(main) diff --git a/utils.py b/utils.py index f5a9ad4..74512e4 100644 --- a/utils.py +++ b/utils.py @@ -8,7 +8,7 @@ from tqdm import tqdm -device = torch.device('cuda:0') +device = torch.device('cuda') def save_images(images, output_dir, verbose=False): @@ -28,7 +28,7 @@ def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - 
torch.cuda.manual_seed(seed) + #torch.cuda.manual_seed(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False