首页 > 分享 > Pytorch生成对抗网络(GAN)官方入门教程

Pytorch生成对抗网络(GAN)官方入门教程

目录

引言(Introduction)

生成对抗网络(Generative Adversarial Networks)

什么是GAN?(What is a GAN?)

什么是DCGAN?(What is a DCGAN?)

输入(Inputs)

数据(Data)

实现(Implementation)

权重初始化(Weight Initialization)

生成器(Generator)

判别器(Discriminator)

损失函数和优化器(Loss Functions and Optimizers)

训练(Training)

下一步(Where to Go Next)

原文链接:https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

引言(Introduction)

本教程将通过一个示例介绍DCGAN(Deep Convolutional Generative Adversarial Networks)。我们将训练一个生成对抗网络(GAN),在展示许多名人的真实照片后产生新的名人。这里的代码实现来自 pytorch/examples,本文档对代码实现进行透彻的解释,并阐明此模型如何以及为什么有效果。但别担心,理解GANs不需要有先验知识,但它可能需要你花一些时间来研究幕后到底发生了什么。另外,因为时间的缘故,有一个或两个GPU也会有帮助。让我们开始吧。

生成对抗网络(Generative Adversarial Networks)

什么是GAN?(What is a GAN?)

GANs是一个教学DL(Deep Learning)模型的框架,使得DL模型可以捕获训练数据的分布,这样我们就可以在相同的数据分布中生成新的数据。GANs是Goodfellow 在2014年发明的,并在 Generative Adversarial Nets论文中首次提出。它由两个不同的模型组成,一个生成器和一个判别器。生成器的目标是生成类似于训练图片的图片,判别器的目标是,输入一张图片,判断输入的图片是真图片还是生成器产生的假图片。在训练过程中,生成器不断的生成更好的假图片试图骗过判别器,而判别器则在努力成为更好的鉴别者,正确的对真假图片进行分类。这个游戏的平衡点就是生成器产生的图片就好像是从训练图片中取出的一样,判别器总是有50%的置信度鉴别生成器的图片是真或是假。

现在,让我们定义一些在整个教程中使用的符号,从判别器(discriminator)开始。设x表示图像数据。D(x)表示判别器,它的输出是x来自训练数据而不是生成器的概率(标量)。这里,我们处理的是CHW(channel,height,width)为3*64*64大小的图像。直观的说,当x来自训练数据时D(x)的值应该是高的,当x来自生成器时D(x)的值应该是低的。你也可以把D(x)看作是传统的二元分类器。

对于生成器(generator )的符号,设z是从标准正态分布采样的隐向量(此处的隐没有什么特别高深晦涩难懂的意思,就像前馈神经网络的隐藏层一样,表示没有物理含义的变量或空间,一般不具备可解释性),G(z)表示将隐向量z映射到数据空间的生成函数。G的目标是估算训练数据的分布(pdata),以便从估计的分布(pg)中生成假样本。

所以,D(G(z))是生成器G的输出是真实图片的概率(标量)。正如 Goodfellow的论文中所描述的:D和G在玩一个极大极小博弈:D试图最大化它能正确分类真赝品的概率 (logD(x)),而G试图最小化D预测其输出是假的概率 (log(1−D(G(x))))。从论文中可以看出,GAN的损失函数为:

理论上,这个极大极小博弈的解决方案是pg=pdata,判别器随机猜测输入图片是真是假。然而,GANs的收敛理论仍在积极研究中,而现实中的模型通常不能做到收敛。

什么是DCGAN?(What is a DCGAN?)

DCGAN是上述DAN的直接扩展,不同之处在于它在判别器和生成器中分别使用了卷积和卷积转置层。它是由Radford 等人在 Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks论文中首先提出的。其中的判别器由convolution层,batch norm层,和LeakyReLU激活函数组成。输入是一个3*64*64的图片数据,输出是一个概率(标量),即输入来自真实数据的分布。其中的生成器由convolutional-transpose层,batch norm层,和ReLU激活函数组成。输入是一个隐向量——z,来自标准正态分布,输出是一个3*64*64的GRB图片。卷积转置层可以将隐向量转换成图像的形状。在论文中,作者还提供了一些关于如何设置优化器、如何计算损失函数以及如何初始化模型权重的提示,这些将在接下来的部分中进行解释。

from __future__ import print_function

import argparse

import os

import random

import torch

import torch.nn as nn

import torch.nn.parallel

import torch.backends.cudnn as cudnn

import torch.optim as optim

import torch.utils.data

import torchvision.datasets as dset

import torchvision.transforms as transforms

import torchvision.utils as vutils

import numpy as np

import matplotlib.pyplot as plt

import matplotlib.animation as animation

from IPython.display import HTML

manualSeed = 999

print("Random Seed: ", manualSeed)

random.seed(manualSeed)

torch.manual_seed(manualSeed)

输出:

Random Seed: 999

输入(Inputs)

让我们为接下来的运行定义一些输入:

dataroot - 数据集存放路径. 我们将在下一节中深入讨论workers - 多进程加载数据所用的进程数batch_size - 训练时batch的大小.  DCGAN 论文中使用的是 128image_size -训练图片的尺寸. 这里默认是 64x64.如果需要另一种尺寸,则必须更改 D 和G 的结构. 参阅here 了解更多详细信息。nc - 输入图片的通道数. 这里是3nz - 隐向量的维度(即来自标准正态分布的隐向量的维度)(也即高斯噪声的维度)ngf - 生成器的特征图数量(即进行最后一次卷积转置层时,out_channels为3时的in_channels)ndf - 判别器的特征图数量(即进行第一次卷积时,in_channels为3时的out通道数)num_epochs - 训练模型的迭代次数。长时间的训练可能会带来更好的结果,但也需要更长的时间lr - 训练时的学习率. 在DCGAN 论文中, 这个数值是 0.0002beta1 - Adam 优化器的beta1参数. 在论文中,此数值是0.5ngpu - 可用GPU的数量. 如果为 0, 代码将使用CPU训练. 如果大于0,将使用此数值的GPU进行训练

dataroot = "data/celeba"

workers = 2

batch_size = 128

image_size = 64

nc = 3

nz = 100

ngf = 64

ndf = 64

num_epochs = 5

lr = 0.0002

beta1 = 0.5

ngpu = 1

'

数据(Data)

在本教程中,我们将使用 Celeb-A Faces dataset 数据集,该数据集可以在链接站点或 Google Drive中下载。数据集下载之后是一个名为img_align_celeba.zip的文件。当你下载完成之后,创建一个celeba目录并将zip文件解压到这个目录。然后,将上一节提到的dataroot 输入的值设置为我们刚刚创建的celeba目录。生成的目录结构应为:

/path/to/celeba

-> img_align_celeba

-> 188242.jpg

-> 173822.jpg

-> 284702.jpg

-> 537394.jpg

...

这是非常重要的一步,因为我们将使用ImageFolder这个数据集类,它要求在这个数据集的根目录下必须要有子目录。现在,我们可以创建数据集,创建dataloader,设置device,最后可视化一些训练数据。

dataset = dset.ImageFolder(root=dataroot,

transform=transforms.Compose([

transforms.Resize(image_size),

transforms.CenterCrop(image_size),

transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,

shuffle=True)

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

real_batch = next(iter(dataloader))

plt.figure(figsize=(8,8))

plt.axis("off")

plt.title("Training Images")

plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

实现(Implementation)

输入参数和数据集都准备好了,现在可以进入实现环节了。我们将会从权重的初始化策略开始,然后详细讨论生成器,判别器,损失函数和训练过程。

权重初始化(Weight Initialization)

在DCGAN的论文中,作者指定所有模型的初始化权重是一个均值为0,标准差为0.02的正态分布。weights_init函数的输入是一个初始化的模型,然后按此标准重新初始化模型的卷积层、卷积转置层和BN层的权重。模型初始化后应立即应用此函数。(这个文章中,我有的时候用的权重,有时候用参数,这两个名词是等价的)

# custom weights initialization called on netG and netD

def weights_init(m):

classname = m.__class__.__name__

if classname.find('Conv') != -1:

nn.init.normal_(m.weight.data, 0.0, 0.02)

elif classname.find('BatchNorm') != -1:

nn.init.normal_(m.weight.data, 1.0, 0.02)

nn.init.constant_(m.bias.data, 0)

生成器(Generator)

生成器G, 用于将隐向量 (z)映射到数据空间。 由于我们的数据是图片,也就是通过隐向量z生成一张与训练图片大小相同的RGB图片 (比如 3x64x64). 在实践中,这是通过一系列的ConvTranspose2d,BatchNorm2d,ReLU完成的。 生成器的输出,通过tanh激活函数把数据映射到[−1,1]。值得注意的是,在卷积转置层之后紧跟BN层,这是DCGAN论文的重要贡献。这些层(即BN层)有助于训练过程中梯度的流动。DCGAN论文中的生成器如下图所示。

dcgan_generator

注意,我们在输入(Inputs)小节设置的参数 (nz, ngf, and nc) 影响着生成器G的架构。 nz 是隐向量z的长度, ngf 为生成器的特征图大小,nc 是输出图片(若为RGB图像,则设置为3)的通道数。 生成器的代码如下:

class Generator(nn.Module):

def __init__(self, ngpu):

super(Generator, self).__init__()

self.ngpu = ngpu

self.main = nn.Sequential(

nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),

nn.BatchNorm2d(ngf * 8),

nn.ReLU(True),

nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),

nn.BatchNorm2d(ngf * 4),

nn.ReLU(True),

nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),

nn.BatchNorm2d(ngf * 2),

nn.ReLU(True),

nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf, kernel_size=4, stride=2, padding=1, bias=False),

nn.BatchNorm2d(ngf),

nn.ReLU(True),

nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),

nn.Tanh()

)

"""

上卷积层可理解为是卷积层的逆运算。

拿最后一个上卷积层举例。若卷积的输入是(nc) x 64 x 64时,

经过Hout=(Hin+2*Padding-kernel_size)/stride+1=(64+2*1-4)/2+1=32,输出为(out_channels) x 32 x 32

此处上卷积层为卷积层的输入输出的倒置:

即输入通道数为out_channels,输出通道数为3;输入图片大小为(out_channels) x 32 x 32,输出图片的大小为(nc) x 64 x 64

"""

def forward(self, input):

return self.main(input)

现在,我们可以实例化生成器,并应用weights_init方法。打印并查看生成器的结构。

netG = Generator(ngpu).to(device)

if (device.type == 'cuda') and (ngpu > 1):

netG = nn.DataParallel(netG, list(range(ngpu)))

netG.apply(weights_init)

print(netG)

输出如下:

Generator(

(main): Sequential(

(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)

(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(2): ReLU(inplace=True)

(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(5): ReLU(inplace=True)

(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(8): ReLU(inplace=True)

(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(11): ReLU(inplace=True)

(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(13): Tanh()

)

)

判别器(Discriminator)

如前所述,判别器D是一个二分类网络,它将图片作为输入,输出其为真的标量概率。这里,D的输入是一个3*64*64的图片,通过一系列的 Conv2d, BatchNorm2d,和 LeakyReLU 层对其进行处理,最后通过Sigmoid 激活函数输出最终概率。如有必要,你可以使用更多层对其扩展。DCGAN 论文提到使用跨步卷积而不是池化进行降采样是一个很好的实践,因为它可以让网络自己学习池化方法。BatchNorm2d层和LeakyReLU层也促进了梯度的健康流动,这对生成器G和判别器D的学习过程都是至关重要的。

判别器代码

class Discriminator(nn.Module):

def __init__(self, ngpu):

super(Discriminator, self).__init__()

self.ngpu = ngpu

self.main = nn.Sequential(

nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),

nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 2),

nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 4),

nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 8),

nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),

nn.Sigmoid()

)

def forward(self, input):

return self.main(input)

现在,我们可以实例化判别器,并应用weights_init方法。打印并查看判别器的结构。

netD = Discriminator(ngpu).to(device)

if (device.type == 'cuda') and (ngpu > 1):

netD = nn.DataParallel(netD, list(range(ngpu)))

netD.apply(weights_init)

print(netD)

输出如下:

Discriminator(

(main): Sequential(

(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(1): LeakyReLU(negative_slope=0.2, inplace=True)

(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(4): LeakyReLU(negative_slope=0.2, inplace=True)

(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(7): LeakyReLU(negative_slope=0.2, inplace=True)

(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(10): LeakyReLU(negative_slope=0.2, inplace=True)

(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)

(12): Sigmoid()

)

)

损失函数和优化器(Loss Functions and Optimizers)

有了生成器D和判别器G,我们可以为其指定损失函数和优化器来进行学习。这里将使用Binary Cross Entropy损失函数 (BCELoss)。其在PyTorch中的定义为:

注意这个损失函数需要你提供两个log组件 (比如 log(D(x))和log(1−D(G(z))))。我们可以指定BCE的哪个部分使用输入y标签。这将会在接下来的训练小节中讲到,但是明白我们可以仅仅通过改变y标签来指定使用哪个log部分是非常重要的(比如GT标签)。

接下来,我们定义真实标签为1,假标签为0。这些标签用来计算生成器D和判别器G的损失,这也是原始GAN论文的惯例。最后,我们将设置两个独立的优化器,一个用于生成器G,另一个判别器D。如DCGAN 论文所述,两个Adam优化器学习率都为0.0002,Beta1都为0.5。为了记录生成器的学习过程,我们将会生成一批符合高斯分布的固定的隐向量(即fixed_noise)。在训练过程中,我们将周期性地把固定噪声作为生成器G的输入,通过输出看到由噪声生成的图像。

criterion = nn.BCELoss()

fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1.

fake_label = 0.

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

训练(Training)

最后,我们已经定义了GAN网络的所有结构,可以开始训练它了。请注意,训练GAN有点像一种艺术形式,因为不正确的超参数会导致模式崩溃,却不会提示超参数错误的信息。这里,我们将遵循Goodfellow’s论文的算法1,同时遵循 ganhacks中的一些最佳实践。也就是说,我们将会“为真假数据构造不同的mini-batches数据”,同时调整判别器G的目标函数以最大化logD(G(z))。训练分为两个部分。第一部分更新判别器,第二部分更新生成器。

第一部分——训练判别器(Part 1 - Train the Discriminator)

回想一下,判别器的训练目的是最大化输入正确分类的概率。从Goodfellow的角度来看,我们希望“通过随机梯度的变化来更新鉴别器”。实际上,我们想要最大化log(D(x))+log(1−D(G(z)))。为了区别mini-batch,ganhacks建议分两步计算。第一步,我们将会构造一个来自训练数据的真图片batch,作为判别器D的输入,计算其损失loss(log(D(x)),调用backward方法计算梯度。第二步,我们将会构造一个来自生成器G的假图片batch,作为判别器D的输入,计算其损失loss(log(1−D(G(z))),调用backward方法累计梯度。最后,调用判别器D优化器的step方法更新一次模型(即判别器D)的参数。

第二部分——训练生成器(Part 2 - Train the Generator)

如原论文所述,我们希望通过最小化log(1−D(G(z)))训练生成器G来创造更好的假图片。作为解决方案,我们希望最大化log(D(G(z)))。通过以下方法来实现这一点:使用判别器D来分类在第一部分G的输出图片,计算损失函数的时候用真实标签(记做GT),调用backward方法更新生成器G的梯度,最后调用生成器G优化器的step方法更新一次模型(即生成器G)的参数。使用真实标签作为GT来计算损失函数看起来有悖常理,但是这允许我们可以使用BCELoss的log(x)部分而不是log(1−x)部分,这正是我们想要的。

最后,我们将做一些统计报告。以展示每个迭代完成之后我们的固定噪声通过生成器G产生的图片信息。训练过程中统计数据报告如下:

Loss_D - 真假batch图片输入判别器后,所产生的损失总和((log(D(x)) + log(D(G(z))))).Loss_G - 生成器损失总和(log(D(G(z))))D(x) - 真batch图片输入判别器后,所产生的的平均值(即平均概率)。这个值理论上应该接近1,然后随着生成器的改善,它会收敛到0.5左右。D(G(z)) - 假batch图片输入判别器后,所产生的平均值(即平均概率)。第一个值在判别器D更新之前,第二个值在判别器D更新之后。这两个值应该从接近0开始,随着G的改善收敛到0.5。

注意: 这一步可能会运行时间久一些。这取决于你跑了多少Epochs和你的数据集中有多少数据。

img_list = []

G_losses = []

D_losses = []

iters = 0

print("Starting Training Loop...")

for epoch in range(num_epochs):

for i, data in enumerate(dataloader, 0):

netD.zero_grad()

real_cpu = data[0].to(device)

b_size = real_cpu.size(0)

label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

output = netD(real_cpu).view(-1)

errD_real = criterion(output, label)

errD_real.backward()

D_x = output.mean().item()

noise = torch.randn(b_size, nz, 1, 1, device=device)

fake = netG(noise)

label.fill_(fake_label)

output = netD(fake.detach()).view(-1)

errD_fake = criterion(output, label)

errD_fake.backward()

D_G_z1 = output.mean().item()

errD = errD_real + errD_fake

optimizerD.step()

netG.zero_grad()

label.fill_(real_label)

output = netD(fake).view(-1)

errG = criterion(output, label)

errG.backward()

D_G_z2 = output.mean().item()

optimizerG.step()

if i % 50 == 0:

print('[%d/%d][%d/%d]tLoss_D: %.4ftLoss_G: %.4ftD(x): %.4ftD(G(z)): %.4f / %.4f'

% (epoch, num_epochs, i, len(dataloader),

errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

G_losses.append(errG.item())

D_losses.append(errD.item())

if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):

with torch.no_grad():

fake = netG(fixed_noise).detach().cpu()

img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

iters += 1

输出:

Starting Training Loop...

[0/5][0/1583]Loss_D: 1.7834Loss_G: 5.0952D(x): 0.5564D(G(z)): 0.5963 / 0.0094

[0/5][50/1583]Loss_D: 0.2582Loss_G: 28.5604D(x): 0.8865D(G(z)): 0.0000 / 0.0000

[0/5][100/1583]Loss_D: 0.9311Loss_G: 13.3240D(x): 0.9443D(G(z)): 0.4966 / 0.0000

[0/5][150/1583]Loss_D: 0.7385Loss_G: 8.8132D(x): 0.9581D(G(z)): 0.4625 / 0.0004

[0/5][200/1583]Loss_D: 0.4796Loss_G: 6.5862D(x): 0.9888D(G(z)): 0.3271 / 0.0047

[0/5][250/1583]Loss_D: 0.7410Loss_G: 5.4159D(x): 0.8282D(G(z)): 0.3274 / 0.0082

[0/5][300/1583]Loss_D: 0.4622Loss_G: 3.7107D(x): 0.7776D(G(z)): 0.1251 / 0.0375

[0/5][350/1583]Loss_D: 1.0642Loss_G: 6.3149D(x): 0.9374D(G(z)): 0.5391 / 0.0061

[0/5][400/1583]Loss_D: 0.3848Loss_G: 6.3376D(x): 0.9153D(G(z)): 0.2209 / 0.0036

[0/5][450/1583]Loss_D: 0.2790Loss_G: 4.3376D(x): 0.8896D(G(z)): 0.1256 / 0.0217

[0/5][500/1583]Loss_D: 1.2478Loss_G: 8.1121D(x): 0.9361D(G(z)): 0.5578 / 0.0016

[0/5][550/1583]Loss_D: 0.3393Loss_G: 4.0673D(x): 0.8257D(G(z)): 0.0496 / 0.0323

[0/5][600/1583]Loss_D: 0.8083Loss_G: 2.5396D(x): 0.6232D(G(z)): 0.0484 / 0.1265

[0/5][650/1583]Loss_D: 0.3682Loss_G: 4.3142D(x): 0.8227D(G(z)): 0.1114 / 0.0217

[0/5][700/1583]Loss_D: 0.4788Loss_G: 6.2379D(x): 0.8594D(G(z)): 0.2307 / 0.0037

[0/5][750/1583]Loss_D: 0.4767Loss_G: 5.3962D(x): 0.8935D(G(z)): 0.2463 / 0.0092

[0/5][800/1583]Loss_D: 0.8085Loss_G: 2.3573D(x): 0.5934D(G(z)): 0.0769 / 0.1357

[0/5][850/1583]Loss_D: 0.3595Loss_G: 3.9025D(x): 0.7769D(G(z)): 0.0563 / 0.0381

[0/5][900/1583]Loss_D: 0.3235Loss_G: 4.7795D(x): 0.9224D(G(z)): 0.1785 / 0.0163

[0/5][950/1583]Loss_D: 0.3426Loss_G: 3.1228D(x): 0.8257D(G(z)): 0.0847 / 0.0795

[0/5][1000/1583]Loss_D: 0.6667Loss_G: 7.3167D(x): 0.9556D(G(z)): 0.3751 / 0.0019

[0/5][1050/1583]Loss_D: 0.2840Loss_G: 5.0387D(x): 0.9268D(G(z)): 0.1642 / 0.0143

[0/5][1100/1583]Loss_D: 0.4534Loss_G: 3.8780D(x): 0.7535D(G(z)): 0.0697 / 0.0391

[0/5][1150/1583]Loss_D: 0.5040Loss_G: 2.9283D(x): 0.7452D(G(z)): 0.1167 / 0.0839

[0/5][1200/1583]Loss_D: 0.6478Loss_G: 4.0913D(x): 0.6595D(G(z)): 0.0263 / 0.0358

[0/5][1250/1583]Loss_D: 1.2299Loss_G: 7.8236D(x): 0.9850D(G(z)): 0.5941 / 0.0013

[0/5][1300/1583]Loss_D: 0.3228Loss_G: 4.9211D(x): 0.8882D(G(z)): 0.1488 / 0.0140

[0/5][1350/1583]Loss_D: 0.4208Loss_G: 4.1520D(x): 0.8254D(G(z)): 0.1638 / 0.0260

[0/5][1400/1583]Loss_D: 0.5751Loss_G: 3.9585D(x): 0.7692D(G(z)): 0.1902 / 0.0329

[0/5][1450/1583]Loss_D: 1.6244Loss_G: 0.5350D(x): 0.3037D(G(z)): 0.0159 / 0.6617

[0/5][1500/1583]Loss_D: 0.3676Loss_G: 3.2653D(x): 0.8076D(G(z)): 0.0825 / 0.0710

[0/5][1550/1583]Loss_D: 0.2759Loss_G: 4.4156D(x): 0.9010D(G(z)): 0.1370 / 0.0178

[1/5][0/1583]Loss_D: 1.0879Loss_G: 7.8641D(x): 0.8737D(G(z)): 0.5376 / 0.0008

[1/5][50/1583]Loss_D: 0.2761Loss_G: 4.4716D(x): 0.9008D(G(z)): 0.1267 / 0.0231

[1/5][100/1583]Loss_D: 0.3438Loss_G: 4.0343D(x): 0.8389D(G(z)): 0.1162 / 0.0308

[1/5][150/1583]Loss_D: 0.4937Loss_G: 4.8593D(x): 0.7951D(G(z)): 0.1819 / 0.0162

[1/5][200/1583]Loss_D: 0.3973Loss_G: 3.2078D(x): 0.8671D(G(z)): 0.1916 / 0.0587

[1/5][250/1583]Loss_D: 0.4521Loss_G: 4.5155D(x): 0.9006D(G(z)): 0.2441 / 0.0222

[1/5][300/1583]Loss_D: 0.4423Loss_G: 5.3907D(x): 0.8635D(G(z)): 0.2039 / 0.0125

[1/5][350/1583]Loss_D: 0.6447Loss_G: 2.5607D(x): 0.6177D(G(z)): 0.0195 / 0.1284

[1/5][400/1583]Loss_D: 0.4079Loss_G: 4.2563D(x): 0.8621D(G(z)): 0.1949 / 0.0268

[1/5][450/1583]Loss_D: 0.9649Loss_G: 8.0302D(x): 0.9727D(G(z)): 0.5302 / 0.0010

[1/5][500/1583]Loss_D: 0.7693Loss_G: 5.9895D(x): 0.9070D(G(z)): 0.4331 / 0.0053

[1/5][550/1583]Loss_D: 0.4522Loss_G: 2.6169D(x): 0.7328D(G(z)): 0.0634 / 0.1113

[1/5][600/1583]Loss_D: 0.4039Loss_G: 3.4861D(x): 0.8436D(G(z)): 0.1738 / 0.0494

[1/5][650/1583]Loss_D: 0.4434Loss_G: 3.0261D(x): 0.7756D(G(z)): 0.1299 / 0.0777

[1/5][700/1583]Loss_D: 1.5401Loss_G: 8.3636D(x): 0.9705D(G(z)): 0.7050 / 0.0011

[1/5][750/1583]Loss_D: 0.3899Loss_G: 4.3379D(x): 0.7379D(G(z)): 0.0231 / 0.0248

[1/5][800/1583]Loss_D: 0.9547Loss_G: 5.6122D(x): 0.9520D(G(z)): 0.5318 / 0.0074

[1/5][850/1583]Loss_D: 0.3714Loss_G: 3.2116D(x): 0.7770D(G(z)): 0.0752 / 0.0700

[1/5][900/1583]Loss_D: 0.2717Loss_G: 4.0063D(x): 0.8673D(G(z)): 0.1058 / 0.0272

[1/5][950/1583]Loss_D: 0.2652Loss_G: 3.7649D(x): 0.8381D(G(z)): 0.0540 / 0.0361

[1/5][1000/1583]Loss_D: 0.9463Loss_G: 1.6266D(x): 0.5189D(G(z)): 0.0913 / 0.2722

[1/5][1050/1583]Loss_D: 0.7117Loss_G: 3.7363D(x): 0.8544D(G(z)): 0.3578 / 0.0397

[1/5][1100/1583]Loss_D: 0.5164Loss_G: 4.0939D(x): 0.8865D(G(z)): 0.2904 / 0.0252

[1/5][1150/1583]Loss_D: 0.3745Loss_G: 3.1891D(x): 0.8262D(G(z)): 0.1358 / 0.0645

[1/5][1200/1583]Loss_D: 0.4583Loss_G: 2.9545D(x): 0.7866D(G(z)): 0.1453 / 0.0778

[1/5][1250/1583]Loss_D: 0.5870Loss_G: 4.4096D(x): 0.9473D(G(z)): 0.3706 / 0.0208

[1/5][1300/1583]Loss_D: 0.5159Loss_G: 4.1076D(x): 0.8640D(G(z)): 0.2738 / 0.0240

[1/5][1350/1583]Loss_D: 0.6005Loss_G: 1.8590D(x): 0.6283D(G(z)): 0.0418 / 0.2032

[1/5][1400/1583]Loss_D: 0.3646Loss_G: 3.4323D(x): 0.7712D(G(z)): 0.0653 / 0.0534

[1/5][1450/1583]Loss_D: 0.6245Loss_G: 2.2462D(x): 0.6515D(G(z)): 0.0905 / 0.1514

[1/5][1500/1583]Loss_D: 0.6055Loss_G: 1.7674D(x): 0.7026D(G(z)): 0.1682 / 0.2169

[1/5][1550/1583]Loss_D: 0.5181Loss_G: 3.2728D(x): 0.7926D(G(z)): 0.2048 / 0.0549

[2/5][0/1583]Loss_D: 0.9580Loss_G: 5.1154D(x): 0.9605D(G(z)): 0.5535 / 0.0105

[2/5][50/1583]Loss_D: 0.9947Loss_G: 1.7223D(x): 0.4860D(G(z)): 0.0563 / 0.2477

[2/5][100/1583]Loss_D: 0.7023Loss_G: 4.1781D(x): 0.9083D(G(z)): 0.4116 / 0.0239

[2/5][150/1583]Loss_D: 0.3496Loss_G: 2.7264D(x): 0.8871D(G(z)): 0.1795 / 0.0982

[2/5][200/1583]Loss_D: 0.6805Loss_G: 3.8157D(x): 0.8900D(G(z)): 0.3851 / 0.0312

[2/5][250/1583]Loss_D: 0.6193Loss_G: 3.8180D(x): 0.8557D(G(z)): 0.3286 / 0.0303

[2/5][300/1583]Loss_D: 0.6480Loss_G: 1.4683D(x): 0.6157D(G(z)): 0.0640 / 0.2844

[2/5][350/1583]Loss_D: 0.7498Loss_G: 4.1299D(x): 0.8922D(G(z)): 0.4244 / 0.0256

[2/5][400/1583]Loss_D: 0.7603Loss_G: 4.2291D(x): 0.9512D(G(z)): 0.4604 / 0.0213

[2/5][450/1583]Loss_D: 0.4833Loss_G: 4.0068D(x): 0.9348D(G(z)): 0.3095 / 0.0257

[2/5][500/1583]Loss_D: 1.2311Loss_G: 0.7107D(x): 0.3949D(G(z)): 0.0496 / 0.5440

[2/5][550/1583]Loss_D: 0.9657Loss_G: 1.5119D(x): 0.4513D(G(z)): 0.0338 / 0.2821

[2/5][600/1583]Loss_D: 0.5351Loss_G: 3.4546D(x): 0.8889D(G(z)): 0.3018 / 0.0449

[2/5][650/1583]Loss_D: 0.8761Loss_G: 1.2051D(x): 0.5292D(G(z)): 0.1193 / 0.3583

[2/5][700/1583]Loss_D: 1.0206Loss_G: 4.5741D(x): 0.8599D(G(z)): 0.5140 / 0.0159

[2/5][750/1583]Loss_D: 1.0886Loss_G: 5.4749D(x): 0.9770D(G(z)): 0.6093 / 0.0067

[2/5][800/1583]Loss_D: 0.6539Loss_G: 3.5203D(x): 0.9074D(G(z)): 0.3962 / 0.0390

[2/5][850/1583]Loss_D: 0.8633Loss_G: 1.0995D(x): 0.5701D(G(z)): 0.1401 / 0.3842

[2/5][900/1583]Loss_D: 0.3703Loss_G: 2.2482D(x): 0.8183D(G(z)): 0.1302 / 0.1329

[2/5][950/1583]Loss_D: 0.6592Loss_G: 1.6081D(x): 0.6040D(G(z)): 0.0818 / 0.2523

[2/5][1000/1583]Loss_D: 0.7449Loss_G: 1.0548D(x): 0.5975D(G(z)): 0.1375 / 0.4085

[2/5][1050/1583]Loss_D: 0.5783Loss_G: 2.3644D(x): 0.6435D(G(z)): 0.0531 / 0.1357

[2/5][1100/1583]Loss_D: 0.6123Loss_G: 2.2695D(x): 0.7269D(G(z)): 0.2083 / 0.1343

[2/5][1150/1583]Loss_D: 0.6263Loss_G: 1.8714D(x): 0.6661D(G(z)): 0.1407 / 0.1914

[2/5][1200/1583]Loss_D: 0.4233Loss_G: 3.0119D(x): 0.8533D(G(z)): 0.2039 / 0.0692

[2/5][1250/1583]Loss_D: 0.8826Loss_G: 3.3618D(x): 0.7851D(G(z)): 0.3971 / 0.0502

[2/5][1300/1583]Loss_D: 0.6201Loss_G: 2.1584D(x): 0.6418D(G(z)): 0.0977 / 0.1536

[2/5][1350/1583]Loss_D: 0.9558Loss_G: 3.8876D(x): 0.8561D(G(z)): 0.5001 / 0.0302

[2/5][1400/1583]Loss_D: 0.4369Loss_G: 2.3479D(x): 0.7959D(G(z)): 0.1588 / 0.1214

[2/5][1450/1583]Loss_D: 0.5086Loss_G: 2.1034D(x): 0.6758D(G(z)): 0.0586 / 0.1575

[2/5][1500/1583]Loss_D: 0.6513Loss_G: 3.5801D(x): 0.8535D(G(z)): 0.3429 / 0.0455

[2/5][1550/1583]Loss_D: 0.6975Loss_G: 2.5560D(x): 0.7379D(G(z)): 0.2784 / 0.1031

[3/5][0/1583]Loss_D: 2.2846Loss_G: 1.7977D(x): 0.1771D(G(z)): 0.0111 / 0.2394

[3/5][50/1583]Loss_D: 1.6111Loss_G: 5.7904D(x): 0.9581D(G(z)): 0.7350 / 0.0063

[3/5][100/1583]Loss_D: 0.8553Loss_G: 1.0540D(x): 0.5229D(G(z)): 0.1020 / 0.3945

[3/5][150/1583]Loss_D: 0.7402Loss_G: 2.6338D(x): 0.7668D(G(z)): 0.3277 / 0.0959

[3/5][200/1583]Loss_D: 0.9278Loss_G: 2.9689D(x): 0.8913D(G(z)): 0.4787 / 0.0769

[3/5][250/1583]Loss_D: 2.6573Loss_G: 6.4810D(x): 0.9684D(G(z)): 0.8799 / 0.0035

[3/5][300/1583]Loss_D: 0.5435Loss_G: 1.9416D(x): 0.7118D(G(z)): 0.1454 / 0.1801

[3/5][350/1583]Loss_D: 1.2350Loss_G: 4.6877D(x): 0.9595D(G(z)): 0.6444 / 0.0147

[3/5][400/1583]Loss_D: 0.9264Loss_G: 0.9139D(x): 0.4825D(G(z)): 0.0715 / 0.4526

[3/5][450/1583]Loss_D: 0.8967Loss_G: 4.4258D(x): 0.9155D(G(z)): 0.5074 / 0.0174

[3/5][500/1583]Loss_D: 0.6874Loss_G: 2.4529D(x): 0.7775D(G(z)): 0.3171 / 0.1097

[3/5][550/1583]Loss_D: 0.5821Loss_G: 3.0756D(x): 0.8681D(G(z)): 0.3161 / 0.0609

[3/5][600/1583]Loss_D: 0.7164Loss_G: 1.5045D(x): 0.5652D(G(z)): 0.0428 / 0.2868

[3/5][650/1583]Loss_D: 0.6290Loss_G: 2.1863D(x): 0.7952D(G(z)): 0.2829 / 0.1442

[3/5][700/1583]Loss_D: 0.6270Loss_G: 1.2824D(x): 0.6481D(G(z)): 0.1184 / 0.3234

[3/5][750/1583]Loss_D: 0.7011Loss_G: 1.3549D(x): 0.5861D(G(z)): 0.0926 / 0.3017

[3/5][800/1583]Loss_D: 0.6912Loss_G: 1.4927D(x): 0.5919D(G(z)): 0.0741 / 0.2728

[3/5][850/1583]Loss_D: 0.6385Loss_G: 2.9333D(x): 0.8418D(G(z)): 0.3338 / 0.0723

[3/5][900/1583]Loss_D: 0.7835Loss_G: 4.4475D(x): 0.9290D(G(z)): 0.4703 / 0.0151

[3/5][950/1583]Loss_D: 0.6294Loss_G: 2.3463D(x): 0.7388D(G(z)): 0.2414 / 0.1202

[3/5][1000/1583]Loss_D: 0.6288Loss_G: 1.5448D(x): 0.6575D(G(z)): 0.1389 / 0.2581

[3/5][1050/1583]Loss_D: 0.6292Loss_G: 3.4867D(x): 0.8741D(G(z)): 0.3549 / 0.0433

[3/5][1100/1583]Loss_D: 0.7644Loss_G: 1.7661D(x): 0.5457D(G(z)): 0.0408 / 0.2076

[3/5][1150/1583]Loss_D: 0.4918Loss_G: 3.1858D(x): 0.8576D(G(z)): 0.2563 / 0.0527

[3/5][1200/1583]Loss_D: 1.1773Loss_G: 4.5200D(x): 0.8192D(G(z)): 0.5536 / 0.0183

[3/5][1250/1583]Loss_D: 0.6889Loss_G: 1.8073D(x): 0.6909D(G(z)): 0.2230 / 0.1969

[3/5][1300/1583]Loss_D: 0.9721Loss_G: 1.0578D(x): 0.4541D(G(z)): 0.0570 / 0.4080

[3/5][1350/1583]Loss_D: 0.5301Loss_G: 2.3562D(x): 0.7453D(G(z)): 0.1670 / 0.1222

[3/5][1400/1583]Loss_D: 0.5464Loss_G: 2.5304D(x): 0.8018D(G(z)): 0.2438 / 0.1020

[3/5][1450/1583]Loss_D: 0.5987Loss_G: 2.2034D(x): 0.6195D(G(z)): 0.0601 / 0.1477

[3/5][1500/1583]Loss_D: 1.4470Loss_G: 4.2791D(x): 0.9006D(G(z)): 0.6537 / 0.0221

[3/5][1550/1583]Loss_D: 0.7917Loss_G: 3.3235D(x): 0.8287D(G(z)): 0.4002 / 0.0489

[4/5][0/1583]Loss_D: 0.7682Loss_G: 1.2445D(x): 0.5371D(G(z)): 0.0538 / 0.3386

[4/5][50/1583]Loss_D: 0.9274Loss_G: 0.9439D(x): 0.4905D(G(z)): 0.1004 / 0.4476

[4/5][100/1583]Loss_D: 0.9571Loss_G: 0.7391D(x): 0.4619D(G(z)): 0.0511 / 0.5431

[4/5][150/1583]Loss_D: 1.4795Loss_G: 0.7522D(x): 0.3092D(G(z)): 0.0387 / 0.5307

[4/5][200/1583]Loss_D: 0.5203Loss_G: 1.8662D(x): 0.7279D(G(z)): 0.1425 / 0.1895

[4/5][250/1583]Loss_D: 0.8140Loss_G: 1.9120D(x): 0.5155D(G(z)): 0.0606 / 0.1939

[4/5][300/1583]Loss_D: 0.5813Loss_G: 2.5807D(x): 0.7674D(G(z)): 0.2255 / 0.1008

[4/5][350/1583]Loss_D: 0.5209Loss_G: 2.8571D(x): 0.8125D(G(z)): 0.2389 / 0.0743

[4/5][400/1583]Loss_D: 0.4505Loss_G: 2.7965D(x): 0.8221D(G(z)): 0.2014 / 0.0805

[4/5][450/1583]Loss_D: 0.4919Loss_G: 2.4360D(x): 0.8148D(G(z)): 0.2163 / 0.1100

[4/5][500/1583]Loss_D: 0.5861Loss_G: 1.8476D(x): 0.7139D(G(z)): 0.1733 / 0.1968

[4/5][550/1583]Loss_D: 0.3823Loss_G: 2.7134D(x): 0.8286D(G(z)): 0.1591 / 0.0833

[4/5][600/1583]Loss_D: 0.8388Loss_G: 4.0517D(x): 0.9135D(G(z)): 0.4704 / 0.0238

[4/5][650/1583]Loss_D: 1.1851Loss_G: 3.8484D(x): 0.9364D(G(z)): 0.6310 / 0.0301

[4/5][700/1583]Loss_D: 0.6797Loss_G: 1.6355D(x): 0.6011D(G(z)): 0.0880 / 0.2444

[4/5][750/1583]Loss_D: 0.6017Loss_G: 1.8937D(x): 0.7011D(G(z)): 0.1684 / 0.1909

[4/5][800/1583]Loss_D: 0.6368Loss_G: 1.7310D(x): 0.6652D(G(z)): 0.1495 / 0.2195

[4/5][850/1583]Loss_D: 0.7758Loss_G: 0.8409D(x): 0.5400D(G(z)): 0.0775 / 0.4691

[4/5][900/1583]Loss_D: 0.5234Loss_G: 1.7439D(x): 0.6728D(G(z)): 0.0839 / 0.2216

[4/5][950/1583]Loss_D: 0.6529Loss_G: 3.4036D(x): 0.9078D(G(z)): 0.3899 / 0.0443

[4/5][1000/1583]Loss_D: 0.6068Loss_G: 2.1435D(x): 0.7773D(G(z)): 0.2603 / 0.1434

[4/5][1050/1583]Loss_D: 0.9208Loss_G: 2.4387D(x): 0.7600D(G(z)): 0.4164 / 0.1163

[4/5][1100/1583]Loss_D: 0.6253Loss_G: 1.8932D(x): 0.6321D(G(z)): 0.0981 / 0.1835

[4/5][1150/1583]Loss_D: 0.6524Loss_G: 2.7757D(x): 0.7961D(G(z)): 0.2996 / 0.0823

[4/5][1200/1583]Loss_D: 0.5320Loss_G: 2.8334D(x): 0.8048D(G(z)): 0.2383 / 0.0781

[4/5][1250/1583]Loss_D: 0.8212Loss_G: 1.3884D(x): 0.5531D(G(z)): 0.1236 / 0.3016

[4/5][1300/1583]Loss_D: 0.4568Loss_G: 2.6822D(x): 0.8278D(G(z)): 0.2067 / 0.0912

[4/5][1350/1583]Loss_D: 0.6665Loss_G: 1.3834D(x): 0.6517D(G(z)): 0.1532 / 0.2904

[4/5][1400/1583]Loss_D: 0.4927Loss_G: 1.8337D(x): 0.7101D(G(z)): 0.1022 / 0.1965

[4/5][1450/1583]Loss_D: 2.2483Loss_G: 0.2021D(x): 0.1705D(G(z)): 0.0452 / 0.8293

[4/5][1500/1583]Loss_D: 0.5997Loss_G: 2.0054D(x): 0.6909D(G(z)): 0.1507 / 0.1733

[4/5][1550/1583]Loss_D: 1.0521Loss_G: 4.8488D(x): 0.9193D(G(z)): 0.5659 / 0.0120

结果(Results)

最后,让我们看看我们是如何做到对抗生成的。这里,我们将会从三个不同的方面展示。首先,我们将看下D和G在训练过程中损失是如何变化的。第二,我们将会把训练过程中每个Epoch结束,固定噪声在G的输出图片可视化。第三,我们将会看到真图片和来G产生的假图片的对比。

训练过程中的对抗损失(Loss versus training iteration)

下面是生成器和判别器的损失对比图。

可视化生成器的进度(Visualization of G’s progression)

还记得我们是如何在训练时保存固定噪声在生成器G的输出的。现在,我们可以通过动画展示其训练过程。按下play按钮来开启动画。(注意,想要看动画,需在Jupyter Notebook环境下运行代码。因为 HTML(animator.to_jshtml()) 将动图在 Jupyter Notebook 里展示。)

真假图片(Real Images vs. Fake Images)

最后,让我们把真假图片并排(左侧真实图片,右侧假),对比看下。

real_batch = next(iter(dataloader))

plt.figure(figsize=(15,15))

plt.subplot(1,2,1)

plt.axis("off")

plt.title("Real Images")

plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

plt.subplot(1,2,2)

plt.axis("off")

plt.title("Fake Images")

plt.imshow(np.transpose(img_list[-1],(1,2,0)))

plt.show()

下一步(Where to Go Next)

我们已经到达旅程的终点了,不过这里有几个地方你可以去:

训练更长时间来看results有什么变化修改这个模型:不同的数据集,或不同的图片大小,或模型的结构试试更酷的GAN项目here创建GANs来生成music

相关知识

GAN生成对抗网络:花卉生成
基于生成对抗网络的植物景观生成设计——以花境平面图生成为例
【大虾送书第二期】《Python机器学习:基于PyTorch和Scikit
精准农业的智能化:大模型在作物监测与产量预测中的应用
基于pytorch搭建ResNet神经网络用于花类识别
基于pytorch搭建ResNet神经网络用于花类识别持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」
ResNet残差网络在PyTorch中的实现——训练花卉分类器
使用PyTorch实现对花朵的分类
图像分类:AlexNet网络、五分类 flower 数据集、pytorch
基于pytorch搭建VGGNet神经网络用于花类识别

网址: Pytorch生成对抗网络(GAN)官方入门教程 https://m.huajiangbk.com/newsview1159286.html

所属分类:花卉
上一篇: 2024新版计算机网络视频教程6
下一篇: 神经网络画图?PPT就够了