Loading and using the image classification dataset Fashion-MNIST in PyTorch
2022-04-23 13:11:00 【Cloud fff】
- Reference: Dive into Deep Learning (PyTorch edition), Section 3.5
- Note: this article was converted from a Jupyter notebook, so some of the code may not run if copied directly!
- The most commonly used image classification dataset is the handwritten digit dataset MNIST, but most models already exceed 95% classification accuracy on it. To see the differences between algorithms more clearly, this article uses Fashion-MNIST, a dataset with more complex image content. It is harder than MNIST but still small (only a few tens of MB), so even a computer without a GPU can handle it.
- The dataset can be downloaded and processed with the `torchvision` package, whose core modules are:
  - `torchvision.datasets`: functions for loading data and interfaces to common datasets;
  - `torchvision.models`: common model architectures (including pretrained models), such as AlexNet, VGG and ResNet;
  - `torchvision.transforms`: common image transformations, such as cropping and rotation;
  - `torchvision.utils`: other useful utilities.
- First, import the required packages (`sys` is used later to check the operating system):
```python
import sys
import time

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
```
1. Getting the dataset
- Get the dataset with `torchvision.datasets.FashionMNIST`:
```python
# download=True fetches the files on the first run if they are not under root yet
mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True,
                                                download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False,
                                               download=True, transform=transforms.ToTensor())
```
Parameter description:
- `root`: the path where the dataset is saved
- `train`: whether to load the training set (`True`) or the test set (`False`)
- `download`: if set to `True`, the dataset is downloaded from the Internet when it is not found under `root`; if it already exists, nothing is done
- `transform=transforms.ToTensor()`: converts every sample to a `Tensor`; without a transform, PIL images are returned. `transforms.ToTensor()` converts a PIL image of size $H \times W \times C$ with values in $[0, 255]$, or a NumPy array of type `np.uint8`, into a `Tensor` of size $C \times H \times W$ with type `torch.float32` and values in $[0.0, 1.0]$.

Note: some image functions, including `transforms.ToTensor()`, assume `uint8` input by default. With any other type they may not raise an error but can silently give unexpected results. Therefore, whenever you represent image data with pixel values in $[0, 255]$, set its type to `uint8` to avoid unnecessary bugs.
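To make the `uint8` caveat concrete, here is a minimal NumPy sketch of how `transforms.ToTensor()` treats a NumPy array: a `uint8` array is rescaled to $[0.0, 1.0]$, while an array of any other dtype is only transposed to $C \times H \times W$ and returned without scaling. `to_tensor_like` is a hypothetical helper that returns a NumPy array instead of a `Tensor`; it illustrates the behavior and is not torchvision's code.

```python
import numpy as np

def to_tensor_like(pic):
    """Hypothetical NumPy sketch of transforms.ToTensor() on an ndarray input."""
    if pic.ndim == 2:                                    # grayscale H x W -> H x W x 1
        pic = pic[:, :, None]
    chw = np.ascontiguousarray(pic.transpose(2, 0, 1))   # H x W x C -> C x H x W
    if chw.dtype == np.uint8:
        # uint8 pixels in [0, 255] become float32 values in [0.0, 1.0]
        return chw.astype(np.float32) / 255.0
    # any other dtype is returned without scaling
    return chw

img_u8 = np.full((28, 28), 255, dtype=np.uint8)    # white image stored as uint8
img_i64 = np.full((28, 28), 255, dtype=np.int64)   # same pixels, "wrong" dtype
a, b = to_tensor_like(img_u8), to_tensor_like(img_i64)
print(a.shape, a.dtype, a.max())   # (1, 28, 28) float32 1.0
print(b.shape, b.dtype, b.max())   # (1, 28, 28) int64 255
```

The second image keeps values up to 255, which downstream code expecting $[0.0, 1.0]$ would misinterpret without any error.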
- The loaded `mnist_train` and `mnist_test` are instances of `FashionMNIST`, a subclass of `torch.utils.data.Dataset`. Some common operations:
```python
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))  # len() gives the size of the dataset
feature, label = mnist_train[0]           # access any sample by index
print(feature.shape, label)  # [channel, height, width] label; the images are grayscale, so there is 1 channel
'''
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
torch.Size([1, 28, 28]) 9
'''
```
- Fashion-MNIST contains 10 categories, with numeric labels 0 to 9 in this order:
  - t-shirt
  - trouser
  - pullover
  - dress
  - coat
  - sandal
  - shirt
  - sneaker
  - bag
  - ankle boot
The following function converts a list of numeric labels into the corresponding list of text labels:
```python
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
```
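Going the other way (looking up the numeric label of a class name) only needs the inverse of the same table. A minimal sketch; `label_to_index` and `get_fashion_mnist_indices` are hypothetical names, not from the book:

```python
# Index -> name table, in the same order as in get_fashion_mnist_labels
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

# Hypothetical inverse table: name -> index
label_to_index = {name: i for i, name in enumerate(text_labels)}

def get_fashion_mnist_indices(names):
    """Convert a list of text labels back to numeric labels."""
    return [label_to_index[name] for name in names]

indices = get_fashion_mnist_indices(['ankle boot', 't-shirt', 'dress'])
print(indices)                            # [9, 0, 3]
print([text_labels[i] for i in indices])  # ['ankle boot', 't-shirt', 'dress']
```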
- The following function draws several images and their labels in one row:
```python
from matplotlib_inline import backend_inline

def show_fashion_mnist(images, labels):
    # Display plots as SVG in Jupyter (IPython's display.set_matplotlib_formats is deprecated)
    backend_inline.set_matplotlib_formats('svg')
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
```
- Randomly display 10 samples:
```python
X, y = [], []
for i in np.random.randint(0, 60000, size=10).tolist():
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
```
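`np.random.randint(0, 60000, size=10)` samples with replacement (its upper bound is exclusive, so every index is valid), so occasionally the same image can appear twice in the row. If distinct samples are wanted, one option is sampling without replacement, sketched here with Python's standard `random.sample`; `sample_distinct_indices` is a hypothetical helper, and `np.random.choice(60000, size=10, replace=False)` would be a NumPy alternative.

```python
import random

def sample_distinct_indices(n_samples, k, seed=None):
    """Draw k distinct indices from range(n_samples), i.e. without replacement."""
    return random.Random(seed).sample(range(n_samples), k)

idx = sample_distinct_indices(60000, 10, seed=0)
print(len(idx), len(set(idx)))  # 10 10
```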
When displaying the samples, I ran into the error "OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program". I solved it by deleting `libiomp5md.dll` from the virtual environment.
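The OpenMP error message itself also mentions a workaround: setting the environment variable `KMP_DUPLICATE_LIB_OK=TRUE`. The same message calls this unsafe and unsupported because it may cause crashes or silently produce incorrect results, so removing the duplicate runtime as above is the cleaner fix. If you try it anyway, it is typically set at the very top of the notebook, before `torch` is imported:

```python
import os

# Unsafe workaround quoted in the OpenMP error message: it only disables the
# duplicate-runtime check and may cause crashes or wrong results.
# Run this before importing torch.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
```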
2. Reading mini-batches
- In practice, data loading is often the performance bottleneck of training. The `DataLoader` class in `torch.utils.data` makes it easy to speed up data loading with multiple processes.
- Since `mnist_train` is a `torch.utils.data.Dataset`, we can pass it to `torch.utils.data.DataLoader` to create a `DataLoader` instance that reads mini-batches of samples. When creating it:
  - the `num_workers` parameter specifies the number of processes used to read data;
  - the `shuffle` parameter specifies whether to read the samples in random order.
```python
batch_size = 256
if sys.platform.startswith('win'):  # check whether the OS is Windows
    num_workers = 4  # use 4 worker processes to read data in parallel
else:
    num_workers = 0  # 0 means no extra processes are used to speed up reading
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
```
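What a `DataLoader` does with `batch_size` and `shuffle` can be sketched in pure Python: visit every index once per epoch, optionally in shuffled order, and cut the index list into chunks of at most `batch_size`. With 60,000 training samples and `batch_size = 256`, that gives $\lceil 60000 / 256 \rceil = 235$ batches, the last holding $60000 - 234 \times 256 = 96$ samples (assuming the default `drop_last=False`). `iterate_minibatches` is a hypothetical helper; unlike `DataLoader`, it yields plain lists instead of stacked tensors and uses no worker processes.

```python
import math
import random

def iterate_minibatches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of at most batch_size samples, covering each index once."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new random order for this epoch
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

# Stand-in "dataset" with as many items as mnist_train
fake_train = range(60000)
batches = list(iterate_minibatches(fake_train, batch_size=256, seed=0))
print(len(batches), math.ceil(60000 / 256))  # 235 235
print(len(batches[-1]))                      # 96
```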
- Check how long it takes to read through the training data once:
```python
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))
```
On my laptop, one pass took 5.88 s without multi-process loading and 3.18 s with it.
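To compare several `num_workers` settings, the timing loop can be wrapped in a helper. This sketch uses `time.perf_counter()`, the standard library's high-resolution clock intended for measuring short durations, instead of `time.time()`. `time_one_pass` is a hypothetical name, and the commented usage assumes `torch` and `mnist_train` from the cells above.

```python
import time

def time_one_pass(iterable):
    """Iterate over `iterable` once; return (elapsed_seconds, number_of_items)."""
    start = time.perf_counter()
    n = 0
    for _ in iterable:
        n += 1
    return time.perf_counter() - start, n

# Hypothetical notebook usage:
#   for workers in (0, 4):
#       loader = torch.utils.data.DataLoader(mnist_train, batch_size=256,
#                                            shuffle=True, num_workers=workers)
#       elapsed, n_batches = time_one_pass(loader)
#       print(workers, n_batches, '%.2f sec' % elapsed)

elapsed, n = time_one_pass(range(1000))
print(n)  # 1000
```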
Copyright notice
This article was written by [Cloud fff]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204231306550368.html