import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
# Create the output directory for generated images
if not os.path.exists(dir_path):
    os.mkdir(dir_path)
def to_img(x):
    """Map the generator output from [-1, 1] back to [0, 1], since the generator ends with Tanh."""
    out = 0.5 * (x + 1)
    return out
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
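# `ImageDataset` comes from the author's own package and is not shown in this
# article. The hypothetical stand-in below only illustrates what it is assumed
# to provide: 96 x 96 RGB tensors scaled to [-1, 1] (matching the generator's
# Tanh output range). It is defined for illustration and is not used elsewhere
# in this script.
from PIL import Image
from torch.utils.data import Dataset

class FolderImageDataset(Dataset):
    def __init__(self, root='./faces'):
        self.files = [os.path.join(root, f) for f in os.listdir(root)]
        self.transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),  # scales pixels to [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # maps to [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.transform(Image.open(self.files[idx]).convert('RGB'))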
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            # The input is an nz-dimensional noise vector (nz = z_dimension = 100),
            # treated as a 1 x 1 x nz feature map
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # Output of the previous block: (512) x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # Output of the previous block: (256) x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # Output of the previous block: (128) x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # Output of the previous block: (64) x 32 x 32
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # the output range is -1 to 1, hence Tanh rather than Sigmoid
            # Final output shape: 3 x 96 x 96
        )

    def forward(self, x):
        x = self.gen(x)
        return x
    def weight_init(self):
        # DCGAN-style weight initialization; important for stable WGAN training.
        # Applied to every sub-module so that the Conv and BatchNorm layers get
        # the intended starting weights.
        for m in self.modules():
            class_name = m.__class__.__name__
            if class_name.find('Conv') != -1:
                m.weight.data.normal_(0, 0.02)
            elif class_name.find('Norm') != -1:
                m.weight.data.normal_(1.0, 0.02)
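# Optional sanity check (not part of the original code): a batch of noise vectors
# reshaped to (N, z_dimension, 1, 1) should come out of the generator as N images
# of shape 3 x 96 x 96, matching the layer-by-layer shape comments above.
def check_generator_shape():
    g = Generator()
    z = torch.randn(2, z_dimension, 1, 1)
    out = g(z)
    assert out.shape == (2, 3, 96, 96)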
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3, 64, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (64) x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (128) x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (256) x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (512) x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
            # No final Sigmoid: a WGAN critic outputs an unbounded score, not a probability
        )

    def forward(self, x):
        x = self.dis(x)
        return x
    def weight_init(self):
        # DCGAN-style weight initialization; important for stable WGAN training.
        for m in self.modules():
            class_name = m.__class__.__name__
            if class_name.find('Conv') != -1:
                m.weight.data.normal_(0, 0.02)
            elif class_name.find('Norm') != -1:
                m.weight.data.normal_(1.0, 0.02)
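# Optional sanity check (not part of the original code): a batch of 3 x 96 x 96
# images is reduced to a single unbounded score per image, so the critic's
# output shape is (N, 1).
def check_discriminator_shape():
    d = Discriminator()
    x = torch.randn(2, 3, 96, 96)
    assert d(x).shape == (2, 1)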
def save(model, filename="model.pt", out_dir="out/"):
    if model is not None:
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        torch.save({'model': model.state_dict()}, out_dir + filename)
    else:
        print("[ERROR] Please build a model first!")
import QuickModelBuilder as builder  # progress-bar helper (the author's own utility, provides MyTqdm used below)
if __name__ == '__main__':
    one = torch.FloatTensor([1]).cuda()  # the training loop assumes a CUDA device is available
    mone = -1 * one
    is_print = True
    # Build the two networks
    D = Discriminator()
    G = Generator()
    D.weight_init()
    G.weight_init()
    if torch.cuda.is_available():
        D = D.cuda()
        G = G.cuda()
    lr = 2e-4
    d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr)
    g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr)
    d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
    g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
    fake_img = None
    # ############################ Training loop ############################
    for epoch in range(num_epoch):
        pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
        for i, img in enumerate(dataloader):
            num_img = img.size(0)
            real_img = img.cuda()  # move the real batch onto the GPU
            # Unfreeze the critic's parameters for its update
            for param in D.parameters():
                param.requires_grad = True
            # ######################## Train the critic ########################
            # Two parts: score the real images, then score the generated images
            d_optimizer.zero_grad()  # clear gradients before backpropagation
            real_out = D(real_img)  # critic scores for the real images
            d_loss_real = real_out.mean(0).view(1)
            d_loss_real.backward(one)
            # Scores for generated images
            z = torch.randn(num_img, z_dimension).cuda()  # sample random noise
            z = z.reshape(num_img, z_dimension, 1, 1)
            fake_img = G(z).detach()  # detach: no gradients should flow back into G here
            fake_out = D(fake_img)  # critic scores for the fake images
            d_loss_fake = fake_out.mean(0).view(1)
            d_loss_fake.backward(mone)
            d_loss = d_loss_fake - d_loss_real  # difference of mean critic scores, logged below
            d_optimizer.step()  # update the critic
            # After every critic update, clip each weight to the fixed range
            # [-c, c] with c = 0.01 (WGAN weight clipping, enforcing the Lipschitz constraint)
            for p in D.parameters():
                p.data.clamp_(-0.01, 0.01)
            # ######################## Train the generator ########################
            # Freeze the critic: g_optimizer only updates G anyway, but this avoids
            # accumulating unnecessary gradients in D
            for param in D.parameters():
                param.requires_grad = False
            g_optimizer.zero_grad()  # clear gradients
            z = torch.randn(num_img, z_dimension).cuda()
            z = z.reshape(num_img, z_dimension, 1, 1)
            fake_img = G(z)  # generate fake images from fresh noise
            output = D(fake_img)  # critic scores for the generated images
            # A vanilla GAN would use a BCE loss against "real" labels here;
            # the WGAN generator simply optimizes the mean critic score
            g_loss = torch.mean(output).view(1)
            # Backpropagate and update the generator
            g_loss.backward(one)
            g_optimizer.step()
            # Report running losses and scores
            pbar.set_right_info(d_loss=d_loss.data.item(),
                                g_loss=g_loss.data.item(),
                                real_scores=real_out.data.mean().item(),
                                fake_scores=fake_out.data.mean().item(),
                                )
            pbar.update()
        # Save a grid of generated samples once per epoch
        try:
            fake_images = to_img(fake_img.detach().cpu())
            save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
        except Exception:
            pass
        if is_print:
            is_print = False
            real_images = to_img(real_img.cpu())
            save_image(real_images, dir_path + '/real_images.png')
        pbar.finish()
        d_scheduler.step()
        g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")
This concludes this article on implementing a WGAN for anime face generation with PyTorch.