当前位置:网站首页>学习笔记7-深度神经网络优化
学习笔记7-深度神经网络优化
2022-04-23 06:19:00 【什么时候才能像大佬一样厉害】
批量归一化(BatchNormalization)
对输入的标准化(浅层模型)
处理后的任意一个特征在数据集中所有样本上的均值为0、标准差为1。
标准化处理输入数据使各个特征的分布相近
批量归一化(深度模型)
利用小批量上的均值和标准差,不断调整神经网络中间输出,从而使整个神经网络在各层的中间输出的数值更稳定。
1.对全连接层做批量归一化
位置:全连接层中的仿射变换和激活函数之间。
全连接:
x = W u + b o u t p u t = ϕ ( x ) \boldsymbol{x} = \boldsymbol{W\boldsymbol{u} + \boldsymbol{b}} \\ output =\phi(\boldsymbol{x}) x=Wu+boutput=ϕ(x)
批量归一化:
o u t p u t = ϕ ( BN ( x ) ) output=\phi(\text{BN}(\boldsymbol{x})) output=ϕ(BN(x))
y ( i ) = BN ( x ( i ) ) \boldsymbol{y}^{(i)} = \text{BN}(\boldsymbol{x}^{(i)}) y(i)=BN(x(i))
μ B ← 1 m ∑ i = 1 m x ( i ) , \boldsymbol{\mu}_\mathcal{B} \leftarrow \frac{1}{m}\sum_{i = 1}^{m} \boldsymbol{x}^{(i)}, μB←m1i=1∑mx(i),
σ B 2 ← 1 m ∑ i = 1 m ( x ( i ) − μ B ) 2 , \boldsymbol{\sigma}_\mathcal{B}^2 \leftarrow \frac{1}{m} \sum_{i=1}^{m}(\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B})^2, σB2←m1i=1∑m(x(i)−μB)2,
x ^ ( i ) ← x ( i ) − μ B σ B 2 + ϵ , \hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}}, x^(i)←σB2+ϵx(i)−μB,
这⾥ϵ > 0是个很小的常数,保证分母大于0
y ( i ) ← γ ⊙ x ^ ( i ) + β . {\boldsymbol{y}}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}. y(i)←γ⊙x^(i)+β.
引入可学习参数:拉伸参数γ和偏移参数β。若 γ = σ B 2 + ϵ \boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon} γ=σB2+ϵ和 β = μ B \boldsymbol{\beta} = \boldsymbol{\mu}_\mathcal{B} β=μB,批量归一化无效。
2.对卷积层做批量归⼀化
位置:卷积计算之后、应⽤激活函数之前。
如果卷积计算输出多个通道,我们需要对这些通道的输出分别做批量归一化,且每个通道都拥有独立的拉伸和偏移参数。
计算:对单通道,batchsize=m,卷积计算输出=pxq
对该通道中m×p×q个元素同时做批量归一化,使用相同的均值和方差。
3.预测时的批量归⼀化
训练:以batch为单位,对每个batch计算均值和方差。
预测:用移动平均估算整个训练数据集的样本均值和方差。
从零实现
import time
import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import sys
sys.path.append("/home/kesci/input/")
import d2lzh1981 as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
# 判断当前模式是训练模式还是预测模式
if not is_training:
# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# 使用全连接层的情况,计算特征维上的均值和方差
mean = X.mean(dim=0)
var = ((X - mean) ** 2).mean(dim=0)
else:
# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持
# X的形状以便后面可以做广播运算
mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
# 训练模式下用当前的均值和方差做标准化
X_hat = (X - mean) / torch.sqrt(var + eps)
# 更新移动平均的均值和方差
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * X_hat + beta # 拉伸和偏移
return Y, moving_mean, moving_var
class BatchNorm(nn.Module):
def __init__(self, num_features, num_dims):
super(BatchNorm, self).__init__()
if num_dims == 2:
shape = (1, num_features) #全连接层输出神经元
else:
shape = (1, num_features, 1, 1) #通道数
# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成0和1
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
# 不参与求梯度和迭代的变量,全在内存上初始化成0
self.moving_mean = torch.zeros(shape)
self.moving_var = torch.zeros(shape)
def forward(self, X):
# 如果X不在内存上,将moving_mean和moving_var复制到X所在显存上
if self.moving_mean.device != X.device:
self.moving_mean = self.moving_mean.to(X.device)
self.moving_var = self.moving_var.to(X.device)
# 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成false
Y, self.moving_mean, self.moving_var = batch_norm(self.training,
X, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return Y
net = nn.Sequential(
nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
BatchNorm(6, num_dims=4),
nn.Sigmoid(),
nn.MaxPool2d(2, 2), # kernel_size, stride
nn.Conv2d(6, 16, 5),
BatchNorm(16, num_dims=4),
nn.Sigmoid(),
nn.MaxPool2d(2, 2),
d2l.FlattenLayer(),
nn.Linear(16*4*4, 120),
BatchNorm(120, num_dims=2),
nn.Sigmoid(),
nn.Linear(120, 84),
BatchNorm(84, num_dims=2),
nn.Sigmoid(),
nn.Linear(84, 10)
)
print(net)
#batch_size = 256
##cpu要调小batchsize
batch_size=16
def load_data_fashion_mnist(batch_size, resize=None, root='/home/kesci/input/FashionMNIST2065'):
"""Download the fashion mnist dataset and then load into memory."""
trans = []
if resize:
trans.append(torchvision.transforms.Resize(size=resize))
trans.append(torchvision.transforms.ToTensor())
transform = torchvision.transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)
return train_iter, test_iter
train_iter, test_iter = load_data_fashion_mnist(batch_size)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
版权声明
本文为[什么时候才能像大佬一样厉害]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_36016038/article/details/104492753
边栏推荐
- Lead the industry trend with intelligent production! American camera intelligent video production platform unveiled at 2021 world Ultra HD Video Industry Development Conference
- PyTorch 11. Regularization
- golang实现MD5,SHA256,bcrypt加密
- hql求一个范围内最大值
- 江宁医院DMR系统解决方案
- H5案例开发
- colab
- 各类日期转化的utils
- 可视化常见问题解决方案(九)背景颜色问题
- 防汛救灾应急通信系统
猜你喜欢
随机推荐
美摄助力百度“度咔剪辑”,让知识创作更容易
可视化之路(十)分割画布函数详解
van-uploader上传图片实现过程、使用原生input实现上传图片
美摄科技受邀LVSon2020大会 分享《AI合成虚拟人物的技术框架与挑战》
华为云MVP邮件
开发板如何ping通百度
enforce fail at inline_ container. cc:222
直观理解熵
Metro wireless intercom system
美摄科技云剪辑,助力哔哩哔哩使用体验再升级
Us photo cloud editing helps BiliBili upgrade its experience
枫桥学院开元名庭酒店DMR系统解决方案
关于短视频平台框架搭建与技术选型探讨
Lead the industry trend with intelligent production! American camera intelligent video production platform unveiled at 2021 world Ultra HD Video Industry Development Conference
通过sparksql读取presto中的数据存到clickhouse
Flexible blind patch of ad hoc network | Beifeng oil and gas field survey solution
JDBC连接池
Discussion on frame construction and technology selection of short video platform
技术小白的第一篇(表达自己)
H5案例开发