PyTorch Notes: Inspecting DataLoader & Complete Code for LeNet on CIFAR-10 with torch
2022-04-23 05:44:00 【umbrellalalalala】
Reference: 《深度学习框架PyTorch:入门与实践》 (Deep Learning Framework PyTorch: Introduction and Practice)
1. Observing DataLoader with a simple NumPy example
Create the data and print its shape:
import numpy as np

data = np.array([[1, 1, 1, 1],
                 [2, 2, 2, 2],
                 [3, 3, 3, 3],
                 [4, 4, 4, 4],
                 [5, 5, 5, 5],
                 [6, 6, 6, 6]])
print(data)
print(data.shape)
Output:
[[1 1 1 1]
[2 2 2 2]
[3 3 3 3]
[4 4 4 4]
[5 5 5 5]
[6 6 6 6]]
(6, 4)
Next, pass this data to DataLoader. Note that the data contains 6 samples (one per row):
import torch as t
from tqdm import tqdm

# Each batch holds 2 samples; shuffle=False keeps the original order
dataloader = t.utils.data.DataLoader(data,
                                     batch_size=2,
                                     shuffle=False,
                                     num_workers=1)
for i, data_ in enumerate(tqdm(dataloader)):
    print(i)
    print(data_)
Printing the data batch by batch gives:
100%|██████████| 3/3 [00:00<00:00, 12.95it/s]
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
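The three batches above follow from simple arithmetic: 6 samples at batch_size=2 give 6 / 2 = 3 batches. Below is a minimal plain-Python sketch of sequential batching without shuffling. The helper `make_batches` is hypothetical and only mimics the behavior observed above; it is not how DataLoader is implemented. Its `drop_last` flag mirrors the DataLoader parameter of the same name.

```python
import math


def make_batches(samples, batch_size, drop_last=False):
    """Split samples into consecutive batches, in order (no shuffling)."""
    batches = [samples[i:i + batch_size]
               for i in range(0, len(samples), batch_size)]
    # Optionally discard a final batch that is smaller than batch_size
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


rows = [[n] * 4 for n in range(1, 7)]  # same 6 rows as the NumPy array above

print(len(make_batches(rows, 2)))        # 6 / 2 = 3 batches
print(math.ceil(len(rows) / 4))          # batch_size=4 -> ceil(6 / 4) = 2 batches
print([len(b) for b in make_batches(rows, 4)])                  # [4, 2]
print([len(b) for b in make_batches(rows, 4, drop_last=True)])  # [4]
```

When the sample count is not a multiple of batch_size, the last batch is smaller unless it is dropped.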
2. Two ways to load the CIFAR-10 data
Method 1: let torchvision download CIFAR-10 automatically
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

show = ToPILImage()

# On the first run torchvision downloads CIFAR-10 automatically
# (the archive is roughly 160 MB).
# If you already have the data, point the root argument at it.

# Preprocessing applied to the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])

# Training set
# download=True is all that is needed here
trainset = tv.datasets.CIFAR10(
    root='path/to/your/data/folder',
    train=True,
    download=True,
    transform=transform)
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2)

# Test set
testset = tv.datasets.CIFAR10(
    root='path/to/your/data/folder',
    train=False,
    download=True,
    transform=transform)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
If the code above fails with:
URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>
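then add the following at the top of the script. Note that this disables HTTPS certificate verification for the whole process, so treat it as a temporary workaround: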
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
However, downloading this way was very slow for me, so I used method 2.
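A side note on the transform used in method 1: ToTensor scales pixel values to [0, 1], and Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) then computes (x - mean) / std for each channel, which maps [0, 1] to [-1, 1]. The following is a plain-Python sketch of that arithmetic on single values. `normalize` and `denormalize` are illustrative helpers, not torchvision functions.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-value version of what Normalize computes: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, useful before displaying an image."""
    return y * std + mean


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

Applying `denormalize` to a normalized value recovers the original [0, 1] value, which is what you would want before converting a tensor back to an image.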
Method 2: download CIFAR-10 yourself
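The code above downloads the data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz. I pasted that link into Neat Download Manager and the download finished very quickly. After downloading, extract the archive, which yields a folder named cifar-10-batches-py.
Then use the same code as in method 1, but set the root argument to the folder that contains cifar-10-batches-py (not the path of cifar-10-batches-py itself).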
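The extraction step can also be done from Python. This is a minimal sketch using the standard-library tarfile module. The function name `extract_cifar10` and the example paths are hypothetical; the only thing taken from the text above is that the archive unpacks into a cifar-10-batches-py folder and that root must be its parent.

```python
import os
import tarfile


def extract_cifar10(archive_path, root):
    """Extract the archive into root and return the extracted folder path."""
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(root)
    extracted = os.path.join(root, "cifar-10-batches-py")
    if not os.path.isdir(extracted):
        raise FileNotFoundError(f"{extracted} not found after extraction")
    return extracted


# Hypothetical paths; replace them with your own locations.
# extract_cifar10("downloads/cifar-10-python.tar.gz", "data")
# trainset = tv.datasets.CIFAR10(root="data", train=True,
#                                download=False, transform=transform)
```

Here root="data" is the parent folder, so the dataset files end up in data/cifar-10-batches-py.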
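3. Inspecting the size of the CIFAR-10 data
The earlier code contains this statement:
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2)
So batch_size is 4 and shuffling is enabled. Now let's look at what trainloader actually is (the same applies to testloader).
In Jupyter, enter: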
trainloader
Output:
<torch.utils.data.dataloader.DataLoader at 0x12b3c00b8>
That clearly isn't enough to understand trainloader, so how can we inspect it?
The answer is simple: iterate over it with enumerate, which assigns each batch in trainloader to data:
for i, data in enumerate(trainloader):
    # print(np.array(data).shape)  # output: (2,)
    print(data[0].shape)
    print(data[1].shape)
    if i == 2:
        break
Output:
torch.Size([4, 3, 32, 32]) # batch, channel, height, width
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])
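This prints the shapes for three batches. Evidently data is organized as [images, labels]: data[0] holds the image data and data[1] holds the corresponding labels. By inspecting trainloader this way we now understand the layout of data: each image has 3 channels with height × width of 32 × 32, and each label is a single integer.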
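Each integer label is an index into the classes tuple defined earlier, since that tuple lists the classes in CIFAR-10's label order. A small sketch of the lookup, using made-up label values for one batch of 4 images:

```python
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Hypothetical labels for one batch of 4 images (values in 0..9);
# with a real batch you could use data[1].tolist() instead.
labels = [3, 8, 8, 0]

names = [classes[label] for label in labels]
print(names)  # ['cat', 'ship', 'ship', 'plane']
```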
4. Complete code for LeNet on CIFAR-10
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torch import optim
import ssl

show = ToPILImage()
########################
###### Load data #######
########################
# If you load the data by downloading it and hit
# "URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>",
# the following line works around the error
ssl._create_default_https_context = ssl._create_unverified_context

# On the first run torchvision downloads CIFAR-10 automatically
# (the archive is roughly 160 MB).
# If you already have the data, point the root argument at it.

# Preprocessing applied to the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])

# Training set
# Set download=True if you want torchvision to download the data
trainset = tv.datasets.CIFAR10(
    root='path/to/folder/containing/cifar-10-batches-py',
    train=True,
    download=False,
    transform=transform)
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2)

# Test set
# Set download=True if you want torchvision to download the data
testset = tv.datasets.CIFAR10(
    root='path/to/folder/containing/cifar-10-batches-py',
    train=False,
    download=False,
    transform=transform)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
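##################
# Define network #
##################
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 3 input channels -> 6 feature maps, 5x5 kernel
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 6 -> 16 feature maps, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # After two conv + pool stages a 32x32 image becomes 16 maps of 5x5
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 32x32 -> conv1 -> 28x28 -> pool -> 14x14
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 14x14 -> conv2 -> 10x10 -> pool -> 5x5
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # Flatten to (batch, 16*5*5)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()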
######################################
# Define loss function and optimizer #
######################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#########################
## Training #############
#########################
# tqdm shows a progress bar and can be combined with enumerate
from tqdm import tqdm

for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(tqdm(trainloader)):
        # As explained earlier, data[0] is the images and data[1] the labels
        inputs, labels = data
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float, so no tensor is accumulated
        running_loss += loss.item()
        # Print the training status every 2000 batches
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
            running_loss = 0.0
print("finish training")
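Since the main goal is to demonstrate that the code runs, I trained for only two epochs. Each epoch has 12500 batches (50000 training images / batch_size 4). The output is: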
16%|█▋ | 2044/12500 [00:04<00:23, 437.93it/s]
[1, 2000] loss: 2.185
33%|███▎ | 4072/12500 [00:10<00:23, 352.94it/s]
[1, 4000] loss: 1.808
48%|████▊ | 6062/12500 [00:15<00:17, 367.84it/s]
[1, 6000] loss: 1.637
65%|██████▍ | 8075/12500 [00:20<00:10, 428.74it/s]
[1, 8000] loss: 1.553
80%|████████ | 10053/12500 [00:25<00:06, 381.21it/s]
[1, 10000] loss: 1.483
96%|█████████▋| 12059/12500 [00:29<00:01, 412.74it/s]
[1, 12000] loss: 1.438
100%|██████████| 12500/12500 [00:31<00:00, 401.07it/s]
16%|█▋ | 2035/12500 [00:04<00:28, 365.83it/s]
[2, 2000] loss: 1.369
33%|███▎ | 4091/12500 [00:09<00:18, 451.14it/s]
[2, 4000] loss: 1.356
48%|████▊ | 6062/12500 [00:14<00:15, 409.82it/s]
[2, 6000] loss: 1.330
64%|██████▍ | 8057/12500 [00:19<00:10, 422.07it/s]
[2, 8000] loss: 1.279
81%|████████ | 10072/12500 [00:24<00:05, 412.58it/s]
[2, 10000] loss: 1.273
96%|█████████▌| 12023/12500 [00:29<00:01, 365.16it/s]
[2, 12000] loss: 1.243
100%|██████████| 12500/12500 [00:30<00:00, 406.19it/s]
finish training
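The 16*5*5 input size of fc1 in the network above follows from how each layer shrinks the 32 × 32 input. A short sketch of that calculation, assuming the standard output-size formula for convolution and pooling with no padding; the helper `conv_out` is illustrative and not part of PyTorch:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR-10 images are 32x32
size = conv_out(size, kernel=5)            # conv1, 5x5 kernel -> 28
size = conv_out(size, kernel=2, stride=2)  # 2x2 max pool      -> 14
size = conv_out(size, kernel=5)            # conv2, 5x5 kernel -> 10
size = conv_out(size, kernel=2, stride=2)  # 2x2 max pool      -> 5

channels = 16                              # output channels of conv2
print(size, channels * size * size)        # 5 400
```

If the kernel sizes or the input resolution change, the first argument of fc1 has to be recomputed the same way.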
Copyright notice
This article was written by [umbrellalalalala]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/umbrellalalalala/article/details/119916492