Learning Notes 7: Deep Neural Network Optimization
2022-04-23 10:39:00 【When can I be as powerful as a big man】
Batch normalization (BatchNormalization)
Standardization of inputs (shallow models)
After processing, each feature has mean 0 and standard deviation 1 across all samples in the dataset.
Standardizing the input data makes the distribution of each feature similar.
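As a hedged illustration, standardizing each feature (column) of a data matrix to mean 0 and standard deviation 1 might look like this (the matrix X here is made up; note that torch.std uses the unbiased estimator by default):

import torch

# Hypothetical data: 4 samples, 3 features on very different scales
X = torch.tensor([[1.0, 20.0, 300.0],
                  [2.0, 10.0, 100.0],
                  [3.0, 30.0, 200.0],
                  [4.0, 40.0, 400.0]])
X_std = (X - X.mean(dim=0)) / X.std(dim=0)  # standardize each feature
print(X_std.mean(dim=0))  # ~0 for every feature
print(X_std.std(dim=0))   # ~1 for every feature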
Batch normalization (deep models)
Batch normalization uses the mean and standard deviation computed on a mini-batch to continually adjust the intermediate outputs of the neural network, so that the intermediate output values of every layer in the network are more stable.
1. Batch normalization for the fully connected layer
Location: between the affine transformation and the activation function in the fully connected layer.
Fully connected layer:
$$\boldsymbol{x} = \boldsymbol{W}\boldsymbol{u} + \boldsymbol{b}, \qquad \text{output} = \phi(\boldsymbol{x})$$
Batch normalization:
$$\text{output} = \phi(\text{BN}(\boldsymbol{x}))$$
$$\boldsymbol{y}^{(i)} = \text{BN}(\boldsymbol{x}^{(i)})$$
$$\boldsymbol{\mu}_\mathcal{B} \leftarrow \frac{1}{m}\sum_{i=1}^{m} \boldsymbol{x}^{(i)},$$
$$\boldsymbol{\sigma}_\mathcal{B}^2 \leftarrow \frac{1}{m}\sum_{i=1}^{m} \left(\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}\right)^2,$$
$$\hat{\boldsymbol{x}}^{(i)} \leftarrow \frac{\boldsymbol{x}^{(i)} - \boldsymbol{\mu}_\mathcal{B}}{\sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}},$$
Here $\epsilon > 0$ is a very small constant that ensures the denominator is greater than 0.
$$\boldsymbol{y}^{(i)} \leftarrow \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}}^{(i)} + \boldsymbol{\beta}.$$
This introduces learnable parameters: the scale (stretch) parameter $\boldsymbol{\gamma}$ and the shift (offset) parameter $\boldsymbol{\beta}$. If $\boldsymbol{\gamma} = \sqrt{\boldsymbol{\sigma}_\mathcal{B}^2 + \epsilon}$ and $\boldsymbol{\beta} = \boldsymbol{\mu}_\mathcal{B}$, batch normalization has no effect: it exactly recovers the original input $\boldsymbol{x}^{(i)}$.
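A minimal numeric sketch of these equations in PyTorch (the batch below is made up for illustration; eps matches the 1e-5 used in the implementation later in this note):

import torch

x = torch.randn(8, 4)                     # a mini-batch: m = 8 samples, 4 features
mu = x.mean(dim=0)                        # per-feature mean over the batch
var = ((x - mu) ** 2).mean(dim=0)         # per-feature (biased) variance
eps = 1e-5
x_hat = (x - mu) / torch.sqrt(var + eps)  # normalized: per-feature mean ~0, std ~1
gamma, beta = torch.ones(4), torch.zeros(4)
y = gamma * x_hat + beta                  # scale and shift (here an identity map)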
2. Batch normalization for the convolutional layer
Location: after the convolution computation and before applying the activation function.
If the convolution outputs multiple channels, the output of each channel is batch-normalized separately, and each channel has its own scale and shift parameters.
Computation: for a single channel with batch size m and a convolution output of size p×q, batch normalization is applied to the m×p×q elements of that channel at the same time, using the same mean and variance (see the sketch below).
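A hedged sketch of the per-channel computation on a 4-D input of shape (m, C, p, q); torch.mean over a tuple of dims is equivalent to the chained means used in the implementation below:

import torch

X = torch.randn(2, 3, 4, 4)  # m = 2, C = 3 channels, p = q = 4
# Mean and variance per channel, over the m*p*q elements of each channel
mean = X.mean(dim=(0, 2, 3), keepdim=True)                 # shape (1, 3, 1, 1)
var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # shape (1, 3, 1, 1)
X_hat = (X - mean) / torch.sqrt(var + 1e-5)                # broadcasts back to X's shape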
3. Batch normalization at prediction time
Training: the mean and variance are computed batch by batch, once per mini-batch.
Prediction: moving averages accumulated during training are used to estimate the sample mean and variance of the whole training dataset.
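Concretely, the moving averages are updated once per training batch and then reused at prediction time; a minimal sketch (momentum = 0.9 matches the implementation below):

# After computing mean/var on the current training batch:
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
# At prediction time, normalize with the accumulated estimates:
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)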
Implementation from scratch
import time
import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import sys
sys.path.append("/home/kesci/input/")
import d2lzh1981 as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Decide whether the current mode is training mode or prediction mode
    if not is_training:
        # In prediction mode, directly use the mean and variance obtained
        # from the moving averages passed in
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected layer: compute the mean and variance
            # per feature, over the batch dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2-D convolutional layer: compute the mean and variance per
            # channel (axis=1). Keep X's number of dimensions so that the
            # broadcast operation can be done later
            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # In training mode, standardize with the mean and variance of the current batch
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the moving averages of the mean and variance
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean, moving_var
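A quick sanity check of this function on a hypothetical 2-D batch (not part of the original note):

X = torch.randn(8, 4)
gamma, beta = torch.ones(4), torch.zeros(4)
mm, mv = torch.zeros(4), torch.zeros(4)
Y, mm, mv = batch_norm(True, X, gamma, beta, mm, mv, eps=1e-5, momentum=0.9)
print(Y.mean(dim=0))  # per-feature mean close to 0
print(mm, mv)         # moving averages nudged toward the batch statistics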
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)  # Number of outputs of the fully connected layer
        else:
            shape = (1, num_features, 1, 1)  # Number of channels
        # Scale and shift parameters, which participate in gradient
        # computation and iteration; initialized to 1 and 0 respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Variables that do not participate in gradient computation and
        # iteration; both initialized to 0 in main memory
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)

    def forward(self, X):
        # If X is not in main memory, copy moving_mean and moving_var
        # to the device (e.g. GPU memory) where X lives
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated moving_mean and moving_var. A Module instance's
        # training attribute defaults to True; calling .eval() sets it to False
        Y, self.moving_mean, self.moving_var = batch_norm(self.training,
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
net = nn.Sequential(
    nn.Conv2d(1, 6, 5),  # in_channels, out_channels, kernel_size
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2),  # kernel_size, stride
    nn.Conv2d(6, 16, 5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2),
    d2l.FlattenLayer(),
    nn.Linear(16*4*4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10)
)
print(net)
# batch_size = 256
# Use a smaller batch size when training on CPU
batch_size = 16
def load_data_fashion_mnist(batch_size, resize=None, root='/home/kesci/input/FashionMNIST2065'):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_iter, test_iter
train_iter, test_iter = load_data_fashion_mnist(batch_size)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
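For comparison, a sketch of the same network written with PyTorch's built-in nn.BatchNorm2d and nn.BatchNorm1d layers, which maintain their own running statistics and learnable scale/shift parameters (assuming the same d2l.FlattenLayer helper as above):

net = nn.Sequential(
    nn.Conv2d(1, 6, 5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2),
    d2l.FlattenLayer(),
    nn.Linear(16*4*4, 120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84, 10)
)

It can be trained with the same d2l.train_ch5 call as above.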
Copyright notice
This article was written by [When can I be as powerful as a big man]. Please include the original link when reposting. Thank you.
https://yzsam.com/2022/04/202204230619103837.html