当前位置:网站首页>Pytorch 经典卷积神经网络 LeNet
Pytorch 经典卷积神经网络 LeNet
2022-04-23 13:58:00 【哇咔咔负负得正】
Pytorch 经典卷积神经网络 LeNet
0. 环境介绍
环境使用 Kaggle 里免费建立的 Notebook
小技巧:当遇到函数看不懂的时候可以按 Shift+Tab
查看函数详解。
1. LeNet
1.0 简介
LeNet 是最早发布的卷积神经网络之一,因其在计算机视觉任务中的高效性能而受到广泛关注。 这个模型是由 AT&T 贝尔实验室的研究员 Yann LeCun 在 1989 年提出的(并以其命名),目的是识别图像 [LeCun et al., 1998] 中的手写数字。 当时,Yann LeCun 发表了第一篇通过反向传播成功训练卷积神经网络的研究,这项工作代表了十多年来神经网络研究开发的成果。
当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet 被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。 时至今日,一些自动取款机仍在运行 Yann LeCun 和他的同事 Leon Bottou 在上世纪 90 年代写的代码。
论文地址:https://axon.cs.byu.edu/~martinez/classes/678/Papers/Convolution_nets.pdf
其中的手写数字 MNIST 数据集:
- 50 , 000 50,000 50,000 个训练数据
- 10 , 000 10,000 10,000 个测试数据
- 图像大小 28 × 28 28 \times 28 28×28
- 10 10 10 类 ( 0 → 9 ) (0 \to 9) (0→9)
1.2 LeNet 结构
每个卷积块中的基本单元是一个卷积层、一个 sigmoid 激活函数和平均池化层。
注:虽然 ReLU 激活函数和最大池化层更有效,但它们在20世纪90年代还没有出现。
每个卷积层使用 5 × 5 5\times 5 5×5 卷积核和一个 sigmoid 激活函数。这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有 6 6 6 个输出通道,而第二个卷积层有 16 16 16 个输出通道。使用 2 × 2 2\times 2 2×2 的平均池化窗口通过空间下采样将维数减少4倍。
先使用卷积层来学习图片空间信息,然后使用全连接层来转换到类别空间。
2. 代码实现
2.1 网络结构
对原始模型做了一点小改动,去掉了最后一层的高斯激活。除此之外,这个网络与最初的 LeNet-5 一致。
!pip install -U d2l
import torch
from torch import nn
from d2l import torch as d2l
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
nn.Linear(120, 84), nn.Sigmoid(),
nn.Linear(84, 10))
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
X = layer(X)
print(layer.__class__.__name__,'output shape: \t',X.shape)
在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用 2 2 2 个像素的填充,来补偿卷积核导致的特征减少。 第二个卷积层没有填充,因此高度和宽度都减少了 4 4 4 个像素。 随着层叠的上升,通道的数量从输入时的 1 1 1 个,增加到第一个卷积层之后的 6 6 6 个,再到第二个卷积层之后的 16 16 16 个。 同时,每个平均池化层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。
2.2 加载 Fashion-MNIST 数据集
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
2.3 评价函数
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
"""使用GPU计算模型在数据集上的精度"""
if isinstance(net, nn.Module):
net.eval() # 设置为评估模式
if not device:
device = next(iter(net.parameters())).device
# 正确预测的数量,总预测的数量
metric = d2l.Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
if isinstance(X, list):
# BERT微调所需的(之后将介绍)
X = [x.to(device) for x in X]
else:
X = X.to(device)
y = y.to(device)
metric.add(d2l.accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
2.4 训练函数
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
"""用GPU训练模型(在第六章定义)"""
def init_weights(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
# 使用 xavier 权重初始化
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
print('training on', device)
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
legend=['train loss', 'train acc', 'test acc'])
timer, num_batches = d2l.Timer(), len(train_iter)
for epoch in range(num_epochs):
# 训练损失之和,训练准确率之和,样本数
metric = d2l.Accumulator(3)
net.train()
for i, (X, y) in enumerate(train_iter):
timer.start()
optimizer.zero_grad()
X, y = X.to(device), y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
optimizer.step()
with torch.no_grad():
metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
timer.stop()
train_l = metric[0] / metric[2]
train_acc = metric[1] / metric[2]
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,
(train_l, train_acc, None))
test_acc = evaluate_accuracy_gpu(net, test_iter)
animator.add(epoch + 1, (None, None, test_acc))
print(f'loss {
train_l:.3f}, train acc {
train_acc:.3f}, '
f'test acc {
test_acc:.3f}')
print(f'{
metric[2] * num_epochs / timer.sum():.1f} examples/sec '
f'on {
str(device)}')
2.5 用 CPU 训练
在 kaggle 中 Accelerator 设置为 None
。
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
每秒遍历 5612.2 5612.2 5612.2 个样本。
2.6 用 GPU 训练
在 kaggle 中使用 GPU:
每秒遍历 33873.6 33873.6 33873.6 个样本,可以发现比 CPU 训练快了不少。
训练集精度 0.820 0.820 0.820,测试集精度 0.801 0.801 0.801。
2.7 尝试更换激活函数为 ReLU 以及池化层换成最大池化,调整学习率
net2 = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
nn.Linear(120, 84), nn.ReLU(),
nn.Linear(84, 10))
# 学习率 0.9 的时候会不收敛,所以调整为 0.1
lr, num_epochs = 0.1, 10
train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
训练集精度 0.879 0.879 0.879,测试集精度 0.857 0.857 0.857,相对于之前的模型确实有提高。
版权声明
本文为[哇咔咔负负得正]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_39906884/article/details/124360050
边栏推荐
- Dolphin scheduler integrates Flink task pit records
- Building MySQL environment under Ubuntu & getting to know SQL
- Force deduction brush question 101 Symmetric binary tree
- PG library checks the name
- 【项目】小帽外卖(八)
- MySQL index [data structure + index creation principle]
- 专题测试05·二重积分【李艳芳全程班】
- Jenkins construction and use
- 力扣刷题 101. 对称二叉树
- redis如何解决缓存雪崩、缓存击穿和缓存穿透问题
猜你喜欢
随机推荐
Analysis of redo log generated by select command
Solution of discarding evaluate function in surprise Library
MySQL and PgSQL time related operations
Oracle and MySQL batch query all table names and table name comments under users
记录一个奇怪的bug:缓存组件跳转之后出现组件复制
Leetcode | 38 appearance array
Jenkins construction and use
Jiannanchun understood the word game
初探 Lambda Powertools TypeScript
Kettle--控件解析
Oracle database recovery data
Android 面试主题集合整理
SQL learning | set operation
淘宝发布宝贝提示“您的消保保证金额度不足,已启动到期保障”
crontab定时任务输出产生大量邮件耗尽文件系统inode问题处理
redis如何解决缓存雪崩、缓存击穿和缓存穿透问题
Three characteristics of volatile keyword [data visibility, prohibition of instruction rearrangement and no guarantee of operation atomicity]
Use future and countdownlatch to realize multithreading to execute multiple asynchronous tasks, and return results after all tasks are completed
The art of automation
Dolphin scheduler configuring dataX pit records