您当前的位置:首页 > 电脑百科 > 网络技术 > 网络软件

5分钟入门GANS:原理解释和keras代码实现

时间:2020-08-24 10:14:25  来源:  作者:

介绍

生成对抗网络通常也称为GANs,用于生成图像而不需要很少或没有输入。GANs允许我们生成由神经网络生成的图像。在我们深入讨论这个理论之前,我想向您展示GANs构建您兴奋感的能力。把马变成斑马(反之亦然)。

5分钟入门GANS:原理解释和keras代码实现

 


5分钟入门GANS:原理解释和keras代码实现

 

历史

生成式对抗网络(GANs)是由Ian Goodfellow (GANs的GAN Father)等人于2014年在其题为"生成式对抗网络"的论文中提出的。它是一种可替代的自适应变分编码器(VAEs)学习图像的潜在空间,以生成合成图像。它的目的是创造逼真的人工图像,几乎无法与真实的图像区分。

GAN的直观解释

生成器和鉴别器网络:

生成器网络的目的是将随机图像初始化并解码成一个合成图像。

鉴别器网络的目的是获取这个输入,并预测这个图像是来自真实的数据集还是合成的。

正如我们刚才看到的,这实际上就是GANs,两个相互竞争的对抗网络。

GAN的训练过程

GANS的训练是出了名的困难。在CNN中,我们使用梯度下降来改变权重以减少损失。

然而,在GANs中,每一次重量的变化都会改变整个动态系统的平衡。

在GAN的网络中,我们不是在寻求将损失最小化,而是在我们对立的两个网络之间找到一种平衡。

我们将过程总结如下

1. 输入随机生成的噪声图像到我们的生成器网络中生成样本图像。

1. 我们从真实数据中提取一些样本图像,并将其与一些生成的图像混合在一起。

1. 将这些混合图像输入到我们的鉴别器中,鉴别器将对这个混合集进行训练并相应地更新它的权重。

1. 然后我们制作更多的假图像,并将它们输入到鉴别器中,但是我们将它们标记为真实的。这样做是为了训练生成器。我们在这个阶段冻结了鉴别器的权值(鉴别器学习停止),并且我们使用来自鉴别器的反馈来更新生成器的权值。这就是我们如何教我们的生成器(制作更好的合成图像)和鉴别器更好地识别赝品的方法。

流程图如下

5分钟入门GANS:原理解释和keras代码实现

 

对于本文,我们将使用MNIST数据集生成手写数字。GAN的架构是:

5分钟入门GANS:原理解释和keras代码实现

 

使用KERAS实现GANS

首先,我们加载所有必要的库

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

K.set_image_dim_ordering('th')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

现在我们加载数据集。这里使用MNIST数据集,所以不需要单独下载和处理。

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

接下来,我们定义生成器和鉴别器的结构

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)#generator
generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)#discriminator
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

现在我们把发生器和鉴别器结合起来同时训练。

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

三个函数,每20个epoch绘制并保存结果,并保存模型。

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

训练函数

def train(epochs=1, batchSize=128):
    batchCount = X_train.shape[0] / batchSize
    print 'Epochs:', epochs
    print 'Batch size:', batchSize
    print 'Batches per epoch:', batchCount

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            yDis[:batchSize] = 0.9

            # Train discriminator
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.Append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

至此一个简单的GAN已经完成了,完整的代码在这里找到

github/bhaveshgoyal27/mediumblogs/blob/master/KerasMNISTGAN.py

作者:Bhavesh Goyal

deephub翻译组



Tags:GANS   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,如有任何标注错误或版权侵犯请与我们联系(Email:2595517585@qq.com),我们将及时更正、删除,谢谢。
▌相关推荐
介绍生成对抗网络通常也称为GANs,用于生成图像而不需要很少或没有输入。GANs允许我们生成由神经网络生成的图像。在我们深入讨论这个理论之前,我想向您展示GANs构建您兴奋感的...【详细内容】
2020-08-24  Tags: GANS  点击:(93)  评论:(0)  加入收藏
▌简易百科推荐
说到远程控制,首先你会想到的是什么?是TeamViewer 还是向日葵?抑或是QQ远程还是anydesk?对,就在不久前,我们熟知的都是以上的产品,但是只2020年开始,一款新的远控产品ToDesk进入到我...【详细内容】
2021-12-27  网管世界    Tags:ToDesk   点击:(3)  评论:(0)  加入收藏
# 1. nps-npc1.1 简介nps是一款轻量级、高性能、功能强大的内网穿透代理服务器。目前支持tcp、udp流量转发,可支持任何tcp、udp上层协议(访问内网网站、本地支付接口调试、ssh...【详细内容】
2021-12-22  大数据推荐杂谈    Tags:内网穿透   点击:(8)  评论:(0)  加入收藏
“磨刀不误砍柴工”。 优秀的工具有助于提高工作效率,安全工程师也需要优秀的安全软件来提高工作效率。 在具体的工作场景中,有很多种选择,这里有10种开源的免费安全工具,不仅可...【详细内容】
2021-11-23  山东云管家官方    Tags:安全工具   点击:(33)  评论:(0)  加入收藏
火绒安全软件是一款小巧精悍、独立纯粹的国产安全软件.有很多网友都下载安装了火绒安全软件使用.那么火绒安全软件怎么样呢,火绒安全软件好用吗?下面小编就给大家分析下详解...【详细内容】
2021-11-03  装机吧    Tags:火绒   点击:(34)  评论:(0)  加入收藏
背景上次给大家介绍了实现基础的运维系统功能—webssh,今日书接上回,继续给大家介绍一个web远程ssh终端录像回放功能。 一、思路网上查了一下资料,搜索了一下关于实现webs...【详细内容】
2021-10-13  小堂运维笔记    Tags:ssh终端   点击:(40)  评论:(0)  加入收藏
QuickPing快速Ping扫描器QuickPing,哪些地址已经使用,哪些可用,图形界面非常直观,而且可以导出列表,该软件体积很小,可以快速的知道网段内哪些主机已经开启,ping成功的即显示出不同...【详细内容】
2021-10-11  海南弱电李工    Tags:网管   点击:(66)  评论:(0)  加入收藏
1、每个项目根据现场的网络环境不同,需要定义不同的IP地址,通过此工具可以快速配置。而且有助于做项目实施资料。2、以前连接过的wifi密码自带记忆功能,通过检索对应的WiFi名字...【详细内容】
2021-10-08  IT游侠    Tags:局域网管理   点击:(49)  评论:(0)  加入收藏
01概述无论是开发还是测试,在工作中经常会遇到需要抓包的时候。本篇文章主要介绍如何在各个平台下,高效的抓包。目前的抓包软件总体可以分为两类: 一种是设置代理抓取http包,比...【详细内容】
2021-09-28  小码哥聊软件测试    Tags:网络抓包   点击:(100)  评论:(0)  加入收藏
Fiddler 简介Fiddler 是位于客户端和服务器端的 HTTP 代理 目前最常用的 http 抓包工具之一 功能非常强大,是 Web 调试的利器关注+转发+私信【软件测试】领取Fiddler安装包和...【详细内容】
2021-09-28  土豆聊软件测试    Tags:抓包工具   点击:(63)  评论:(0)  加入收藏
前言上次有写过一篇《20张图深度详解MAC地址表、ARP表、路由表》的文章,里面有提到了MAC地址表。那么什么是MAC地址表?MAC地址表有什么作用?MAC地址表里面包含了哪些要素?今天...【详细内容】
2021-09-09  网络工程师笔记    Tags:MAC地址表   点击:(76)  评论:(0)  加入收藏
相关文章
    无相关信息
最新更新
栏目热门
栏目头条