当前位置: 首页 > 资讯

【原创】cat-generator 基于生成对抗网络生成喵星人的头像

论智       2017-11-07

cat-generator是一个很有意思的项目,基于生成对抗网络生成喵星人的头像。

cats

256张随机生成的喵星人头像


生成对抗网络,Generative Adversarial Network,简称GAN,是由Goodfellow等人提出的非监督式学习的一种方法。


生成对抗网络由一个生成网络与一个判别网络组成。生成网络的输出结果需要尽量模仿训练集中的真实样本,尽可能地欺骗判别网络。判别网络则尽可能将生成网络的输出从真实样本中分辨出来。两个网络相互对抗、不断调整参数,最终使判别网络无法将生成网络的输出从真实样本中分辨出来,达到以假乱真的效果。

生成对抗网络原理示意

网络架构


cat-generator实际使用的架构,除了生成网络和判别网络外,还包括验证网络。


生成网络基本上就是一个拉普拉斯金字塔。生成图像的大小随着层数的增加而增大。


local model = nn.Sequential()

-- 4x4

-- noiseDim是一个100维矢量,值取自-1和+1之间

model:add(nn.Linear(noiseDim, 512*4*4))

model:add(nn.PReLU(nil, nil, true))

model:add(nn.View(512, 4, 4))


-- 4x4 -> 8x8

model:add(nn.SpatialUpSamplingNearest(2))

model:add(cudnn.SpatialConvolution(512, 512, 3, 3, 1, 1, (3-1)/2, (3-1)/2))

model:add(nn.SpatialBatchNormalization(512))

model:add(nn.PReLU(nil, nil, true))


-- 8x8 -> 16x16

model:add(nn.SpatialUpSamplingNearest(2))

model:add(cudnn.SpatialConvolution(512, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2))

model:add(nn.SpatialBatchNormalization(256))

model:add(nn.PReLU(nil, nil, true))

-- 16x16 -> 32x32

model:add(nn.SpatialUpSamplingNearest(2))

model:add(cudnn.SpatialConvolution(256, 128, 5, 5, 1, 1, (5-1)/2, (5-1)/2))

model:add(nn.SpatialBatchNormalization(128))

model:add(nn.PReLU(nil, nil, true))


-- dimensions[1]共4维,包括3位颜色信息,和1位灰度信息。

model:add(cudnn.SpatialConvolution(128, dimensions[1], 3, 3, 1, 1, (3-1)/2, (3-1)/2))

model:add(nn.Sigmoid())

判别网络是一个多分支的卷积网络。它首先使用一个空间转换函数去除旋转信息,4个分支中的3个也包含反转、平移、拉伸的空间转换函数,以便专注学习图像的特定区域。第4个分支尝试分析整张图片。

判别网络

验证网络是一个标准的卷积网络。和判别网络类似,验证网络给生成网络生成的图像评分,判断“伪造”的图像有多“假”。验证网络可以识别一些非常假的图片,从而提高整个生成对抗网络的训练效率。


local model = nn.Sequential()

local activation = nn.LeakyReLU


model:add(nn.SpatialConvolution(dimensions[1], 128, 3, 3, 1, 1, (3-1)/2))

model:add(activation())

model:add(nn.SpatialMaxPooling(2, 2))

model:add(nn.SpatialConvolution(128, 128, 3, 3, 1, 1, (3-1)/2))

model:add(nn.SpatialBatchNormalization(128))

model:add(activation())

model:add(nn.SpatialMaxPooling(2, 2))

model:add(nn.Dropout())


model:add(nn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2))

model:add(activation())

model:add(nn.SpatialConvolution(256, 256, 3, 3, 1, 1, (3-1)/2))

model:add(nn.SpatialBatchNormalization(256))

model:add(activation())

model:add(nn.SpatialMaxPooling(2, 2))

model:add(nn.SpatialDropout())

local imgSize = 0.25 * 0.25 * 0.25 * dimensions[2] * dimensions[3]

model:add(nn.View(256 * imgSize))

model:add(nn.Linear(256 * imgSize, 1024))

model:add(nn.BatchNormalization(1024))

model:add(activation())

model:add(nn.Dropout())


model:add(nn.Linear(1024, 1024))

model:add(nn.BatchNormalization(1024))

model:add(activation())

model:add(nn.Dropout())


model:add(nn.Linear(1024, 2))

model:add(nn.SoftMax())

快速上手


想自己生成喵星人头像?安装依赖先:


Torch以及相应的库(大多数库应该已经是默认安装的)

nn (luarocks install nn)

pl (luarocks install pl)

paths (luarocks install paths)

image (luarocks install image)

optim (luarocks install optim)

cutorch (luarocks install cutorch)

cunn (luarocks install cunn)

cudnn (luarocks install cudnn)

dpnn (luarocks install dpnn)

stn

display

Python 2.7 以及对应的库

scipy

numpy

scikit-image

另外,请确保你的显卡支持CUDA及cudnn3(显存4GB或以上)。如果你的显卡不符合要求,可以参考这篇指南选购新显卡。


然后克隆cats-generator仓库:


git clone https://github.com/aleju/cat-generator.git

然后下载喵星人数据并解压,解压后的目录应该包含CAT_00到CAT_06的子目录。


转换喵星人数据格式:


cd cat-generator/dataset

python generate_dataset.py --path="/喵星人/数据/目录"

转换大概需要2小时。


万事俱备,可以开始训练了。


启动服务:


th -ldisplay.start

在浏览器访问 http://localhost:8000/


训练验证网络:


th train_v.lua

当看到saving network to <path>提示后,手动停止训练。


预训练生成网络:


th pretrain_g.lua

和上一步一样,看到提示后手动停止。


训练整个网络:


th train.lua

如果效果不佳的话,尝试加上--D_iterations=2参数。


如果希望得到灰度图像,则给上面的步骤分别加上--colorSpace="y"

灰度图像

1024张生成的灰度图像


机器人网原创文章,未经授权禁止转载。详情见转载须知

本文来自机器人网,如若转载,请注明出处:https://www.jqr.com/news/008511