当前位置:网站首页>动手学深度学习_风格迁移
动手学深度学习_风格迁移
2022-08-09 16:44:00 【CV小Rookie】
风格迁移(style transfer)是让一张图片内容不发生改变,但样式改为另一张图片效果。
这里所使用的风格迁移并不是基于 GAN 的,而是基于卷积神经网络的风格迁移方法(当然现在主流的风格迁移是基于 GAN 的,感兴趣的可以了解一下 之间写过的一些 GAN 的介绍)
首先,我们初始化合成图像,例如将其初始化为内容图像。 该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数。 然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。 这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或风格特征。
接下来,我们通过前向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。
风格迁移常用的损失函数由3部分组成:
(i)内容损失使合成图像与内容图像在内容特征上接近;
(ii)风格损失使合成图像与风格图像在风格特征上接近;
(iii)全变分损失则有助于减少合成图像中的噪点。
最后,当模型训练结束时,我们输出风格迁移的模型参数,即得到最终的合成图像。
搭建网络
使用基于 ImageNet 数据集预训练的 VGG-19 模型来抽取图像特征
pretrained_net = torchvision.models.vgg19(pretrained=True)
抽取特征,我们可以选择VGG网络中某些层的输出。 一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
net = nn.Sequential(*[pretrained_net.features[i] for i in
range(max(content_layers + style_layers) + 1)])
# 保存风格层和内容层的输出
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
# get_contents函数对内容图像抽取内容特征
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).to(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
# get_styles函数对风格图像抽取风格特征
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).to(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
损失函数
由三部分组成:内容损失,风格损失,全变分损失
我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 一种常见的去噪方法是全变分去噪(total variation denoising): 假设 表示坐标
处的像素值,降低全变分损失
能够尽可能使邻近的像素值相似。
# 内容损失
def content_loss(Y_hat, Y):
# 我们从动态计算梯度的树中分离目标:
# 这是一个规定的值,而不是一个变量。
return torch.square(Y_hat - Y.detach()).mean()
# 风格损失
def gram(X):
num_channels, n = X.shape[1], X.numel() // X.shape[1]
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
# 全变分损失
def tv_loss(Y_hat):
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# 分别计算内容损失、风格损失和全变分损失
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# 对所有损失求和
l = sum(10 * styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
初始化合成图像
与之前训练的网络不同,之前的网络训练的参数是每层网络里面的权重。但是在卷积神经网络的风格迁移中,唯一需要更新的变量是最后需要合成的图像。我们可以定义一个简单的模型SynthesizedImage
,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数。
class SynthesizedImage(nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
self.weight = nn.Parameter(torch.rand(*img_shape))
def forward(self):
return self.weight
定义get_inits
函数。该函数创建了合成图像的模型实例,并将其初始化为图像X
。风格图像在各个风格层的格拉姆矩阵styles_Y_gram
将在训练前预先计算好。
def get_inits(X, device, lr, styles_Y):
gen_img = SynthesizedImage(X.shape).to(device)
gen_img.weight.data.copy_(X.data)
trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
训练模型
在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[10, num_epochs],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
for epoch in range(num_epochs):
trainer.zero_grad()
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
l.backward()
trainer.step()
scheduler.step()
if (epoch + 1) % 10 == 0:
animator.axes[1].imshow(postprocess(X))
animator.add(epoch + 1, [float(sum(contents_l)),
float(sum(styles_l)), float(tv_l)])
return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
边栏推荐
猜你喜欢
低代码平台和专业开发人员——完美搭档?
What is control board custom development?
Logic unauthorized and horizontal and vertical unauthorized payment tampering, verification code bypass, interface
Functions and Features of Smart Home Control System
.NET 6学习笔记(4)——解决VS2022中Nullable警告
Ark: Survival Evolved Open Server Port Mapping Tutorial
The senior told me that the MySQL of the big factory is connected through SSH
逻辑越权和水平垂直越权支付篡改,验证码绕过,接口
融云 x N 世界:构建无限用户实时交互的「元宇宙会场」
3 Feature Binning Methods!
随机推荐
110+ public professional datasets summarized
低代码平台和专业开发人员——完美搭档?
MASA Stack 第三期社区例会
开篇-开启全新的.NET现代应用开发体验
Jenkins使用pipeline部署服务到远程服务器
AI基础环境搭建和设置总文
如何仿造一个websocket请求?
How to choose a good SaaS knowledge base tool?
What is control board custom development?
论如何提升学习的能力
MySQL索引的B+树到底有多高?
【ROS2原理9】 QoS - 截止日期、活跃度和寿命
Fees and inquiry methods of futures account opening exchanges
MySQL的索引你了解吗
怎样选择一个好的SaaS知识库工具?
uniapp电影购票选座系统源码
【代码审计】——PHP项目类RCE及文件包含下载删除
【教程3】疯壳·ARM功能手机-整板资源介绍
本机号码一键登录原理
The senior told me that the MySQL of the big factory is connected through SSH