Transfer learning for five-category image classification based on VGG
2022-04-23 17:54:00 【Stephen_ Tao】
Introduction to the dataset

The dataset is divided into a training set and a test set. Each part contains five categories of images: cars, dinosaurs, elephants, flowers, and horses.
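The flow_from_directory call used later treats each sub-directory as one class, so the data is assumed to be laid out as ./data/train/<class>/<image> (the path ./data/test/bus/300.jpg that appears later in this post follows this pattern). A minimal sketch for checking such a layout before training; the helper name count_images_per_class and the demo folder names are hypothetical and not part of the original code:

```python
import os
import tempfile

def count_images_per_class(root):
    """Map each class sub-directory of `root` to its number of image files."""
    exts = ('.jpg', '.jpeg', '.png', '.bmp')
    counts = {}
    for name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, name)
        if os.path.isdir(class_dir):
            counts[name] = sum(f.lower().endswith(exts) for f in os.listdir(class_dir))
    return counts

# Demo on a throwaway tree with made-up class names and empty files.
demo_root = tempfile.mkdtemp()
for cls, n in [('class_a', 3), ('class_b', 2)]:
    os.makedirs(os.path.join(demo_root, cls))
    for i in range(n):
        open(os.path.join(demo_root, cls, f'{i}.jpg'), 'w').close()

demo_counts = count_images_per_class(demo_root)
print(demo_counts)
```

Pointed at ./data/train, a helper like this should report five classes, which is a quick way to confirm the folder structure before building the generators.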
Code implementation
The implementation consists of the following five steps:
- Read the local image data and categories
- Modify the VGG model structure (add a custom classification layer)
- Freeze the original VGG model
- Compile, train, and save the model
- Make predictions with the model
Importing the required packages
# These imports use TensorFlow's internal tensorflow.python.keras path.
# The public tensorflow.keras namespace is generally the preferred way to reach the same classes.
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint
Reading the local image data and categories
class TransferModel(object):
    def __init__(self):
        self.train_dir = './data/train'
        self.test_dir = './data/test'
        self.model_size = (224, 224)
        self.batch_size = 32
        # Scale pixel values from [0, 255] to [0, 1]
        self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
        # VGG16 without its fully connected classification layers
        self.base_model = VGG16(include_top=False)

    def get_local_data(self):
        """
        Read the local image data and categories.
        :return: iterators for the training data and the test data
        """
        # class_mode='sparse' yields one integer class index per image,
        # the label format expected by sparse_categorical_crossentropy
        train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
                                                             target_size=self.model_size,
                                                             batch_size=self.batch_size,
                                                             class_mode='sparse',
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
                                                           target_size=self.model_size,
                                                           batch_size=self.batch_size,
                                                           class_mode='sparse',
                                                           shuffle=True)
        return train_gen, test_gen
Printing train_gen and test_gen gives the following output:
Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>
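Each iterator yields batches of images together with one class index per image, which is the label format expected by sparse_categorical_crossentropy, the loss used later in this post. As a rough illustration with made-up numbers (NumPy only, not the Keras implementation), the loss for each image is the negative log of the probability the model assigns to its true class:

```python
import numpy as np

def sparse_categorical_crossentropy_np(labels, probs):
    """Per-sample loss: -log(probability assigned to the true class)."""
    labels = np.asarray(labels, dtype=int)
    probs = np.asarray(probs, dtype=float)
    return -np.log(probs[np.arange(len(labels)), labels])

# Two hypothetical images, five classes; each row sums to 1.
probs = np.array([[0.70, 0.10, 0.10, 0.05, 0.05],
                  [0.20, 0.20, 0.20, 0.20, 0.20]])
labels = [0, 3]   # integer class indices, one per image

losses = sparse_categorical_crossentropy_np(labels, probs)
print(losses)

# The same labels in one-hot form give the same values via the dense formula.
one_hot = np.eye(5)[labels]
dense_losses = -(one_hot * np.log(probs)).sum(axis=1)
print(dense_losses)
```

The sparse form avoids building one-hot vectors, which is why integer labels and this loss are a natural pairing.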
Modifying the VGG model structure
    def refine_vgg_model(self):
        # Output of the last convolutional block of VGG16 (no top layers)
        x = self.base_model.outputs[0]
        # Use GlobalAveragePooling2D to reduce the number of model parameters
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024, activation=tf.nn.relu)(x)
        # Five output units, one per image category
        y_predict = layers.Dense(5, activation=tf.nn.softmax)(x)
        model = keras.Model(inputs=self.base_model.inputs, outputs=y_predict)
        return model
First, GlobalAveragePooling2D is applied to the output of the VGG model without its top layers, which reduces the number of parameters in the fully connected part. Two custom fully connected layers are then added on top to obtain the new model.
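To see why the pooling step reduces parameters, the sketch below uses NumPy with random stand-in values. It assumes a 224x224 input, for which the VGG16 convolutional base produces a 7x7x512 feature map, and compares the size of the first 1024-unit Dense layer when it follows global average pooling versus a plain flatten. The helper dense_params is hypothetical:

```python
import numpy as np

# Stand-in for the VGG16 (include_top=False) output of one 224x224 image:
# a 7x7 grid of 512-channel features, filled with random values here.
rng = np.random.default_rng(0)
features = rng.random((1, 7, 7, 512))

# Global average pooling: average each channel over the 7x7 spatial grid.
pooled = features.mean(axis=(1, 2))      # shape (1, 512)
flattened = features.reshape(1, -1)      # shape (1, 25088)

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

params_after_gap = dense_params(pooled.shape[1], 1024)
params_after_flatten = dense_params(flattened.shape[1], 1024)
print(pooled.shape, params_after_gap)
print(flattened.shape, params_after_flatten)
```

Under these assumptions the Dense layer has 525,312 parameters after pooling, against 25,691,136 after flattening, roughly a 49-fold reduction.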
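Freezing the original VGG model parameters
    def freeze_vgg_model(self):
        # Mark every layer of the pretrained base as non-trainable so that
        # only the newly added Dense layers are updated during training
        for layer in self.base_model.layers:
            layer.trainable = False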
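Freezing matters because of where the parameters sit. The following back-of-the-envelope count assumes the standard VGG16 layout of thirteen 3x3 convolutions and the two Dense layers (1024 and 5 units) added on top of a 512-channel pooled output. The figures are computed here rather than taken from the original post, so model.summary() remains the authoritative check:

```python
# (in_channels, out_channels) of the 13 convolutions in the standard VGG16 layout
VGG16_CONVS = [(3, 64), (64, 64),
               (64, 128), (128, 128),
               (128, 256), (256, 256), (256, 256),
               (256, 512), (512, 512), (512, 512),
               (512, 512), (512, 512), (512, 512)]

def conv_params(c_in, c_out, k=3):
    """Kernel weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

frozen = sum(conv_params(c_in, c_out) for c_in, c_out in VGG16_CONVS)
trainable = dense_params(512, 1024) + dense_params(1024, 5)
print(frozen, trainable, frozen + trainable)
```

With the base frozen, only about 0.53 million of roughly 15.2 million parameters are updated, which suits a training set of a few hundred images.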
Compiling, training, and saving the model
    def compile(self, model):
        model.compile(optimizer=adam_v2.Adam(),
                      loss=sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self, model, train_gen, test_gen):
        # Save the weights after an epoch only when val_accuracy improves
        # (checkpoints are written once per epoch by default)
        check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
                                monitor='val_accuracy',
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto')
        # model.fit accepts generators directly; fit_generator is deprecated in TensorFlow 2
        model.fit(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])
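The checkpoint path contains named placeholders that ModelCheckpoint fills in with the epoch number and the logged metrics using ordinary Python string formatting. The sketch below shows how the pattern expands and gives a simplified picture of save_best_only. The accuracy values are made up, and the loop only imitates the callback's behavior:

```python
pattern = './ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5'

# Hypothetical validation accuracy at the end of the second epoch.
filename = pattern.format(epoch=2, val_accuracy=0.9312)
print(filename)

# Simplified imitation of save_best_only=True: write a file only on improvement.
val_accuracies = [0.88, 0.93, 0.91]   # made-up values for three epochs
best = float('-inf')
saved = []
for epoch, acc in enumerate(val_accuracies, start=1):
    if acc > best:
        best = acc
        saved.append(pattern.format(epoch=epoch, val_accuracy=acc))
print(saved)
```

In this made-up run the third epoch does not improve on the second, so no file is written for it.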
The main function
if __name__ == '__main__':
    tm = TransferModel()
    train_gen, test_gen = tm.get_local_data()
    model = tm.refine_vgg_model()
    # print(tm.refine_vgg_model().summary())
    tm.freeze_vgg_model()
    tm.compile(model)
    tm.fit(model, train_gen, test_gen)
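After training, the saved weight files appear in ./ckpt, named according to the ModelCheckpoint pattern (for example transfer_02-0.93.h5, the file loaded in the prediction step).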
Making predictions with the model
    def predict(self, model):
        model.load_weights('./ckpt/transfer_02-0.93.h5')
        image = load_img('./data/test/bus/300.jpg', target_size=(224, 224))
        image = img_to_array(image)
        # Add a batch dimension: shape goes from (224, 224, 3) to (1, 224, 224, 3)
        img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
        # Scale pixels to [0, 1], the same preprocessing the training generators apply (rescale=1.0/255.0)
        img = img / 255.0
        print(img.shape)
        y_predict = model.predict(img)
        # Index of the most probable class for each image in the batch
        index = np.argmax(y_predict, axis=1)
        # Index-to-name mapping rebuilt from the training sub-directories
        # (an assumption: the post does not show how label_dict was defined)
        class_indices = self.train_generator.flow_from_directory(self.train_dir).class_indices
        label_dict = {str(v): k for k, v in class_indices.items()}
        print(label_dict[str(index[0])])
Running the prediction prints the input shape, (1, 224, 224, 3), and the predicted class name.
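Independently of any actual run, the post-processing in predict can be illustrated with made-up numbers. The sketch assumes a softmax output of shape (1, 5) and a hypothetical index-to-name mapping, since the real names depend on the data folders. It also shows that np.expand_dims is an equivalent way to add the batch dimension:

```python
import numpy as np

# Hypothetical softmax output for one image (batch of size 1, five classes).
y_predict = np.array([[0.02, 0.01, 0.90, 0.04, 0.03]])

# Hypothetical index-to-name mapping; the real one depends on the folder names.
label_dict = {'0': 'class_a', '1': 'class_b', '2': 'class_c',
              '3': 'class_d', '4': 'class_e'}

index = np.argmax(y_predict, axis=1)    # most probable class per row
predicted = label_dict[str(index[0])]
print(predicted)

# Adding the batch dimension: reshape and expand_dims give the same shape.
image = np.zeros((224, 224, 3))
batch_a = image.reshape((1,) + image.shape)
batch_b = np.expand_dims(image, axis=0)
print(batch_a.shape, batch_b.shape)
```

Keeping the mapping keyed by string indices mirrors the label_dict[str(index[0])] lookup in the method above.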
Copyright notice
This article was written by [Stephen_ Tao]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230548468317.html