Loading and using the image classification dataset Fashion-MNIST in PyTorch
2022-04-23 13:11:00 【Cloud fff】
- Reference: Dive into Deep Learning (《动手学深度学习》), PyTorch edition, Section 3.5
- Note: this article was converted from a Jupyter notebook, so some of the code may not run as-is when copied directly!
- The most commonly used image classification dataset is the handwritten digit recognition dataset MNIST, but most models already exceed 95% classification accuracy on it. To make the differences between algorithms easier to observe, this article uses Fashion-MNIST, a dataset with more complex image content. It is harder than MNIST but still small (only a few tens of MB), so a computer without a GPU can handle it.
- The dataset can be downloaded and processed with the `torchvision` package, which contains the following core modules:
  - `torchvision.datasets`: functions for loading data and interfaces to common datasets;
  - `torchvision.models`: common model architectures (including pretrained models), such as AlexNet, VGG and ResNet;
  - `torchvision.transforms`: common image transformations, such as cropping and rotation;
  - `torchvision.utils`: some other useful utilities.
- Before we start, import the required packages:

      import torch
      import torchvision
      import torchvision.transforms as transforms
      import matplotlib.pyplot as plt
      import time
      import numpy as np
      from IPython import display
1. Getting the dataset
- Use the `torchvision.datasets.FashionMNIST` class to get the dataset:

      # download=True fetches the data only if it is not already under root
      mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
      mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

  Parameter description:
  - `root` specifies the path where the dataset is saved.
  - `train` specifies whether to get the training set or the test set.
  - `download`: if set to `True`, the dataset is downloaded from the Internet automatically when it is not found under `root`; if the dataset already exists, nothing is done.
  - `transform=transforms.ToTensor()` converts all data to `Tensor`; without it, PIL images are returned. `transforms.ToTensor()` converts a PIL image of size $H \times W \times C$ with values in $[0, 255]$, or a NumPy array with dtype `np.uint8`, into a `Tensor` of size $C \times H \times W$ with dtype `torch.float32` and values in `[0.0, 1.0]`.

  Note: some image-related functions, including `transforms.ToTensor()`, assume `uint8` input by default; with other input types you may get unwanted results. So if you represent image data with pixel values in $[0, 255]$, set its dtype to `uint8` to avoid unnecessary bugs.
- The loaded `mnist_train` and `mnist_test` are both instances of a subclass of `torch.utils.data.Dataset`. Some common operations are shown below:

      print(type(mnist_train))
      print(len(mnist_train), len(mnist_test))    # use len() to get the size of the dataset
      feature, label = mnist_train[0]             # access any sample by index
      print(feature.shape, label)                 # [channel, height, width], label; the images are grayscale, so there is 1 channel
      '''
      <class 'torchvision.datasets.mnist.FashionMNIST'>
      60000 10000
      torch.Size([1, 28, 28]) 9
      '''
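`len()` and indexing work on these objects because a `Dataset` subclass defines `__len__` and `__getitem__`. A minimal sketch of that interface, using a hypothetical `ToyImageDataset` filled with random pixels (it is not Fashion-MNIST and not how torchvision implements its datasets):

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical stand-in dataset: random 1x28x28 images with random labels."""

    def __init__(self, num_samples=100, num_classes=10):
        self.images = torch.rand(num_samples, 1, 28, 28)             # C x H x W per sample
        self.labels = torch.randint(0, num_classes, (num_samples,))  # integer class ids

    def __len__(self):
        # Called by len(dataset)
        return len(self.images)

    def __getitem__(self, idx):
        # Called by dataset[idx]; returns a (feature, label) pair
        return self.images[idx], int(self.labels[idx])


toy = ToyImageDataset()
feature, label = toy[0]
print(len(toy), feature.shape, label)
```

Any object with these two methods can be indexed and measured the same way as `mnist_train`.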
- Fashion-MNIST includes 10 categories, listed here in label order (0 to 9):
  - t-shirt
  - trouser
  - pullover
  - dress
  - coat
  - sandal
  - shirt
  - sneaker
  - bag
  - ankle boot
  Use the following function to convert a list of numeric labels into the corresponding list of text labels:

      def get_fashion_mnist_labels(labels):
          text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                         'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
          return [text_labels[int(i)] for i in labels]
- Use the following function to draw multiple images and their labels in one row:

      def show_fashion_mnist(images, labels):
          # Render plots as SVG in Jupyter. Newer IPython versions deprecate this call
          # in favor of matplotlib_inline.backend_inline.set_matplotlib_formats.
          display.set_matplotlib_formats('svg')
          _, figs = plt.subplots(1, len(images), figsize=(12, 12))
          for f, img, lbl in zip(figs, images, labels):
              f.imshow(img.view((28, 28)).numpy())   # drop the channel dimension before plotting
              f.set_title(lbl)
              f.axes.get_xaxis().set_visible(False)
              f.axes.get_yaxis().set_visible(False)
          plt.show()
- Display 10 random samples:

      X, y = [], []
      for i in np.random.randint(0, 60000, size=10).tolist():
          X.append(mnist_train[i][0])
          y.append(mnist_train[i][1])
      show_fashion_mnist(X, get_fashion_mnist_labels(y))

  I ran into an error here (see 'OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program'). Deleting `libiomp5md.dll` from the virtual environment solved the problem for me.

2. Reading mini-batches
- In practice, data reading is often the performance bottleneck of training. The `DataLoader` class in the `torch.utils.data` module lets us easily use multiple processes to speed up data reading.
- `mnist_train` is an instance of a `torch.utils.data.Dataset` subclass, so we can pass it to `torch.utils.data.DataLoader` to create a `DataLoader` instance that reads mini-batches of samples. When creating it:
  - the `num_workers` parameter specifies the number of processes used to read data;
  - the `shuffle` parameter specifies whether to shuffle the reading order.

      import sys  # needed for sys.platform

      batch_size = 256
      if sys.platform.startswith('win'):  # the operating system is Windows
          num_workers = 4                 # use 4 processes to read data in parallel
      else:
          num_workers = 0                 # 0 means no extra processes are used to speed up reading
      train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
      test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
- Time one full pass over the training data:

      start = time.time()
      for X, y in train_iter:
          continue
      print('%.2f sec' % (time.time() - start))

  In my tests, my laptop took 5.88 s without multi-process acceleration and 3.18 s with it.
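To see what a `DataLoader` yields without downloading anything, here is a small sketch on made-up data. The `TensorDataset` of 1000 random images and the names `toy_dataset` and `toy_iter` are assumptions for illustration only; the batch size matches the one used above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 1000 random 1x28x28 "images" with labels in [0, 10)
features = torch.rand(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
toy_dataset = TensorDataset(features, labels)

# num_workers=0 keeps loading in the main process, so the sketch runs on any platform
toy_iter = DataLoader(toy_dataset, batch_size=256, shuffle=True, num_workers=0)

# Each iteration yields one mini-batch (X, y); the last batch holds the leftover samples
batch_sizes = [X.shape[0] for X, y in toy_iter]
print(len(toy_iter), batch_sizes)
```

With 1000 samples and a batch size of 256, the loader produces four batches and the last one is smaller than the rest; passing `drop_last=True` would discard that short batch instead.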
Copyright notice
This article was written by [Cloud fff]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204231306550368.html