Loading and using image classification dataset fashion MNIST in pytorch
2022-04-23 13:11:00 【Cloud fff】
- Reference: Dive into Deep Learning (《动手学深度学习》), PyTorch edition, Section 3.5
- Note: this article was converted from a Jupyter notebook, so some of the code may not run if it is copied directly.
- The most commonly used image classification dataset is the handwritten digit dataset MNIST, but most models already reach over 95% classification accuracy on it. To make the differences between algorithms easier to observe, this article uses Fashion-MNIST, a dataset with more complex image content. It is harder than MNIST but still small, only a few tens of MB, so a computer without a GPU can handle it.
- The dataset can be downloaded and processed with the `torchvision` package, which contains the following core modules:
  - `torchvision.datasets`: functions for loading data and interfaces to common datasets;
  - `torchvision.models`: common model architectures (including pretrained models) such as AlexNet, VGG, and ResNet;
  - `torchvision.transforms`: common image transformations such as cropping and rotation;
  - `torchvision.utils`: other useful utilities.
- Before starting, import the required packages:

      import sys  # used in Section 2 to check the operating system
      import time

      import numpy as np
      import matplotlib.pyplot as plt
      import torch
      import torchvision
      import torchvision.transforms as transforms
      from IPython import display
1. Getting the dataset
- Use `torchvision.datasets.FashionMNIST` to get the dataset:

      # download=True fetches the files only when they are missing under root
      mnist_train = torchvision.datasets.FashionMNIST(
          root='./Datasets/FashionMNIST', train=True,
          download=True, transform=transforms.ToTensor())
      mnist_test = torchvision.datasets.FashionMNIST(
          root='./Datasets/FashionMNIST', train=False,
          download=True, transform=transforms.ToTensor())
  Parameters:
  - `root`: the path where the dataset is saved.
  - `train`: whether to load the training set or the test set.
  - `download`: if set to `True`, the dataset is downloaded automatically from the internet when it is not found under `root`; if it already exists, nothing is done.
  - `transform = transforms.ToTensor()`: converts all data to `Tensor`; without it, PIL images are returned. `transforms.ToTensor()` converts either a PIL image of size $H \times W \times C$ with values in $[0, 255]$, or a NumPy array of dtype `np.uint8`, into a `Tensor` of size $C \times H \times W$ with dtype `torch.float32` and values in `[0.0, 1.0]`.
  - Note: some image-related functions, including `transforms.ToTensor()`, assume `uint8` input by default, and other types may give unintended results. If you represent image data with pixel values in $[0, 255]$, set its dtype to `uint8` to avoid unnecessary bugs.
- The loaded `mnist_train` and `mnist_test` are both subclasses of `torch.utils.data.Dataset`. Some common operations:

      print(type(mnist_train))
      print(len(mnist_train), len(mnist_test))  # len() gives the dataset size
      feature, label = mnist_train[0]           # index to access any sample
      # feature has shape [channel, height, width]; the images are grayscale,
      # so the number of channels is 1
      print(feature.shape, label)

      '''
      <class 'torchvision.datasets.mnist.FashionMNIST'>
      60000 10000
      torch.Size([1, 28, 28]) 9
      '''
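`len()` and indexing work because a map-style `Dataset` implements `__len__` and `__getitem__`. The sketch below shows that protocol on a toy dataset; the class name `ToyImageDataset` and its random contents are hypothetical and only stand in for real image data.

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical in-memory dataset of random 1 x 28 x 28 'images'."""

    def __init__(self, num_samples=100, num_classes=10):
        generator = torch.Generator().manual_seed(0)
        self.features = torch.rand(num_samples, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, num_classes, (num_samples,),
                                    generator=generator)

    def __len__(self):
        # Called by len(dataset)
        return len(self.labels)

    def __getitem__(self, idx):
        # Called by dataset[idx]; returns a (feature, label) pair
        return self.features[idx], int(self.labels[idx])


toy = ToyImageDataset()
feature, label = toy[0]
print(len(toy), feature.shape, type(label))
```

Any object with these two methods can be used wherever `mnist_train` is used in this article, including the `DataLoader` in Section 2.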
- Fashion-MNIST contains 10 classes:
  - t-shirt
  - trouser
  - pullover
  - dress
  - coat
  - sandal
  - shirt
  - sneaker
  - bag
  - ankle boot
  The following function converts a list of numeric labels into the corresponding list of text labels:

      def get_fashion_mnist_labels(labels):
          text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                         'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
          return [text_labels[int(i)] for i in labels]
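As a quick usage check, the function can be called on a few hand-picked label indices. The list `example_ids` is made up for illustration; the function body is repeated so the snippet runs on its own.

```python
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


# Hypothetical numeric labels; 9 is the label of mnist_train[0] shown earlier
example_ids = [9, 0, 7]
print(get_fashion_mnist_labels(example_ids))
# prints ['ankle boot', 't-shirt', 'sneaker']
```

Because each element is passed through `int()`, the same call should also accept labels that arrive as integer tensors rather than plain Python integers.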
- The following function draws several images and their labels in a single row:

      def show_fashion_mnist(images, labels):
          # Render figures as SVG in Jupyter (newer IPython versions provide this
          # function as matplotlib_inline.backend_inline.set_matplotlib_formats)
          display.set_matplotlib_formats('svg')
          _, figs = plt.subplots(1, len(images), figsize=(12, 12))
          for f, img, lbl in zip(figs, images, labels):
              f.imshow(img.view((28, 28)).numpy())
              f.set_title(lbl)
              f.axes.get_xaxis().set_visible(False)
              f.axes.get_yaxis().set_visible(False)
          plt.show()
- Display 10 random samples:

      X, y = [], []
      for i in np.random.randint(0, len(mnist_train), size=10).tolist():
          X.append(mnist_train[i][0])
          y.append(mnist_train[i][1])
      show_fashion_mnist(X, get_fashion_mnist_labels(y))
  I ran into an error at this step; see 'OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program'. I solved it by deleting `libiomp5md.dll` from the virtual environment.
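The full OMP error message itself names an alternative to deleting the DLL: setting the environment variable `KMP_DUPLICATE_LIB_OK=TRUE`. The message describes it as unsafe and unsupported, and warns that it may cause crashes or silently incorrect results, so treat the sketch below as a stopgap rather than a fix. It has not been checked against the author's environment, and it presumably has to run before `torch`, `numpy`, or `matplotlib` are imported.

```python
import os

# Unsafe workaround named in the OMP error message: allow duplicate OpenMP
# runtimes to coexist. Set it before importing torch, numpy, or matplotlib.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

print(os.environ["KMP_DUPLICATE_LIB_OK"])
```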
2. Reading mini-batches
- In practice, data loading is often the performance bottleneck of training. The `DataLoader` class in the `torch.utils.data` module makes it easy to use multiple processes to speed up data loading.
- `mnist_train` is a subclass of `torch.utils.data.Dataset`, so it can be passed to `torch.utils.data.DataLoader` to create a `DataLoader` instance that reads mini-batches of samples. When creating it:
  - the `num_workers` parameter specifies the number of processes used to read data;
  - the `shuffle` parameter specifies whether to shuffle the samples when reading.

      import sys  # needed for the platform check below

      batch_size = 256
      if sys.platform.startswith('win'):  # the operating system is Windows
          num_workers = 4  # read data with 4 worker processes
      else:
          num_workers = 0  # 0 means no extra processes are used to read data
      train_iter = torch.utils.data.DataLoader(
          mnist_train, batch_size=batch_size, shuffle=True,
          num_workers=num_workers)
      test_iter = torch.utils.data.DataLoader(
          mnist_test, batch_size=batch_size, shuffle=False,
          num_workers=num_workers)
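With `batch_size = 256`, neither dataset divides evenly into batches. The arithmetic below sketches how many batches each iterator should yield and how large the last one is, assuming the `DataLoader` default of keeping the final partial batch (`drop_last=False`). The helper `batch_layout` is a hypothetical name introduced only for this example.

```python
import math

num_train, num_test = 60000, 10000  # dataset sizes printed in Section 1
batch_size = 256


def batch_layout(num_samples, batch_size):
    """Return (number of batches, size of the last batch), keeping the final partial batch."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


print(batch_layout(num_train, batch_size))  # (235, 96)
print(batch_layout(num_test, batch_size))   # (40, 16)
```

So one pass over `train_iter` should produce 235 batches, the last holding only 96 samples, which matters for any code that assumes every batch has exactly 256 rows.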
- See how long it takes to read the training data once:

      start = time.time()
      for X, y in train_iter:
          continue
      print('%.2f sec' % (time.time() - start))

  In my tests on a laptop, one pass took 5.88 s without multi-process loading and 3.18 s with it.
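To compare several `num_workers` settings, the timing loop can be wrapped in a small helper. This is a sketch: `time_one_pass` is a hypothetical name, and a plain `range` stands in for the data iterator so the snippet runs without the dataset.

```python
import time


def time_one_pass(data_iter):
    """Iterate over data_iter once; return (elapsed seconds, number of batches)."""
    start = time.perf_counter()
    num_batches = 0
    for _ in data_iter:
        num_batches += 1
    return time.perf_counter() - start, num_batches


# Stand-in iterable; in the notebook this would be train_iter
elapsed, n = time_one_pass(range(235))
print('%.4f sec for %d batches' % (elapsed, n))
```

`time.perf_counter()` is used here instead of `time.time()` because it is intended for measuring short intervals; the measured numbers will still vary between runs and machines.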
Copyright notice
This article was written by [Cloud fff]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204231306550368.html