Transfer learning for five-class image classification based on VGG
2022-04-23 17:54:00 【Stephen_ Tao】
Introduction to the dataset
The dataset is divided into two parts, a training set and a test set. Each part contains five categories of images: cars, dinosaurs, elephants, flowers, and horses.
Code implementation
The implementation is divided into the following five steps:
- Read the local image data and categories
- Modify the VGG model structure (add a custom classification layer)
- Freeze the original VGG model
- Compile, train, and save the model
- Predict with the trained model
Importing the required packages
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint
Reading local image data and categories
class TransferModel(object):
    def __init__(self):
        self.train_dir = './data/train'
        self.test_dir = './data/test'
        self.model_size = (224, 224)
        self.batch_size = 32
        # Scale pixel values to the range [0, 1]
        self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
        # VGG16 convolutional base without its original classification layers
        self.base_model = VGG16(include_top=False)

    def get_local_data(self):
        """
        Read local image data and categories.
        :return: iterators for the training data and the test data
        """
        # class_mode='sparse' yields integer class indices, which is what the
        # sparse_categorical_crossentropy loss expects for five classes
        train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
                                                             target_size=self.model_size,
                                                             batch_size=self.batch_size,
                                                             class_mode='sparse',
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
                                                           target_size=self.model_size,
                                                           batch_size=self.batch_size,
                                                           class_mode='sparse',
                                                           shuffle=True)
        return train_gen, test_gen
Printing train_gen and test_gen gives the following output:
Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>
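The labels these iterators return are integer class indices. Keras assigns them by sorting the class subfolder names alphanumerically and exposes the mapping as the `class_indices` attribute of each iterator. The sketch below reproduces that mapping with the standard library only; the helper name and the folder names in the demonstration are placeholders, not the actual folder names of this dataset.

```python
import os
import tempfile


def class_indices_from_directory(directory):
    """Map each class subfolder name to an integer index, in sorted order."""
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(class_names)}


# Demonstration on a temporary directory with placeholder class folders.
with tempfile.TemporaryDirectory() as root:
    for name in ("class_c", "class_a", "class_b"):
        os.makedirs(os.path.join(root, name))
    demo_indices = class_indices_from_directory(root)

print(demo_indices)
```

Because the mapping depends only on the folder names, the training and test directories produce the same indices as long as they contain the same class subfolders.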
Modifying the VGG model structure
    def refine_vgg_model(self):
        x = self.base_model.outputs[0]
        # Use GlobalAveragePooling2D to reduce the number of model parameters
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024, activation=tf.nn.relu)(x)
        y_predict = layers.Dense(5, activation=tf.nn.softmax)(x)
        model = keras.Model(inputs=self.base_model.inputs, outputs=y_predict)
        return model
First, GlobalAveragePooling2D is applied to the output of the VGG model without its top layers, which reduces the number of parameters in the fully connected part. Two custom fully connected layers are then added on top to obtain the new model.
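To see why pooling reduces the parameter count: for a 224×224 input, the VGG16 convolutional base produces a 7×7×512 feature map. Flattening it would feed 25,088 values into the first fully connected layer, while global average pooling feeds only 512 (one average per channel). The following plain-Python sketch illustrates both the pooling operation and the resulting parameter counts; the function names are illustrative and are not part of Keras.

```python
def global_average_pool(feature_map):
    """Average a height x width x channels nested list over height and width."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    pooled = []
    for c in range(channels):
        total = sum(feature_map[i][j][c] for i in range(height) for j in range(width))
        pooled.append(total / (height * width))
    return pooled


def dense_param_count(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


# A tiny 2 x 2 feature map with 3 channels.
tiny_map = [
    [[1.0, 10.0, 0.0], [3.0, 20.0, 0.0]],
    [[5.0, 30.0, 0.0], [7.0, 40.0, 4.0]],
]
pooled = global_average_pool(tiny_map)  # one value per channel

# First dense layer (1024 units) on a 7 x 7 x 512 feature map.
params_with_flatten = dense_param_count(7 * 7 * 512, 1024)
params_with_pooling = dense_param_count(512, 1024)
print(pooled, params_with_flatten, params_with_pooling)
```

With pooling, the first dense layer needs about 0.5 million parameters instead of about 25.7 million.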
Freezing the original VGG model parameters
    def freeze_vgg_model(self):
        # Exclude every layer of the pretrained base from training
        for layer in self.base_model.layers:
            layer.trainable = False
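As a rough worked example of what freezing means here, the parameter counts can be computed by hand. The sketch assumes the standard VGG16 configuration of 13 convolutional layers with 3×3 kernels, plus the two dense layers added on top of the pooled 512-dimensional output; the helper names are illustrative only.

```python
# (in_channels, out_channels) of the 13 convolutional layers of VGG16
VGG16_CONV_CHANNELS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]


def conv3x3_params(in_channels, out_channels):
    """Weights plus biases of a 3 x 3 convolution."""
    return 3 * 3 * in_channels * out_channels + out_channels


def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


# Frozen: the whole convolutional base. Trainable: Dense(1024) and Dense(5).
frozen_params = sum(conv3x3_params(i, o) for i, o in VGG16_CONV_CHANNELS)
trainable_params = dense_params(512, 1024) + dense_params(1024, 5)
print(frozen_params, trainable_params)
```

Under these assumptions, only about 3.5% of the model's parameters are updated during training, which is why three epochs on 400 images are feasible.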
Compiling, training, and saving the model
    def compile(self, model):
        model.compile(optimizer=adam_v2.Adam(),
                      loss=sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self, model, train_gen, test_gen):
        # Save the weights at the end of each epoch in which val_accuracy improves
        check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
                                monitor='val_accuracy',
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto',
                                save_freq='epoch')
        # Model.fit accepts generators directly in TensorFlow 2.x; fit_generator is deprecated
        model.fit(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])
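The checkpoint path is a format string. ModelCheckpoint fills it with the 1-based epoch number and the metrics logged for that epoch. The sketch below mimics that substitution with plain `str.format`; the helper name and the metric values are hypothetical and chosen only for illustration.

```python
CHECKPOINT_PATTERN = './ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5'


def checkpoint_name(pattern, epoch, logs):
    """Fill the pattern with a 1-based epoch number and the logged metrics."""
    # Keys in logs that the pattern does not mention are simply ignored.
    return pattern.format(epoch=epoch, **logs)


# Hypothetical values: second epoch, validation accuracy of 0.9312.
example_name = checkpoint_name(
    CHECKPOINT_PATTERN, 2, {'val_accuracy': 0.9312, 'loss': 0.21}
)
print(example_name)
```

Since `save_best_only=True`, a file is written only for epochs in which the monitored `val_accuracy` improves, so not every epoch number necessarily appears among the saved files.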
The main function
if __name__ == '__main__':
    tm = TransferModel()
    train_gen, test_gen = tm.get_local_data()
    model = tm.refine_vgg_model()
    # print(tm.refine_vgg_model().summary())
    tm.freeze_vgg_model()
    tm.compile(model)
    tm.fit(model, train_gen, test_gen)
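After training, the weight files are saved under ./ckpt, named according to the checkpoint pattern (for example, transfer_02-0.93.h5, which is loaded in the prediction step).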
Model prediction
    def predict(self, model):
        import os  # used only to list the class subfolders below
        model.load_weights('./ckpt/transfer_02-0.93.h5')
        image = load_img('./data/test/bus/300.jpg', target_size=(224, 224))
        image = img_to_array(image)
        # Add a batch dimension: the shape goes from 3-D to 4-D
        img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
        # Rescale to [0, 1], matching ImageDataGenerator(rescale=1.0/255.0) used in training
        img = img / 255.0
        print(img.shape)
        y_predict = model.predict(img)
        index = np.argmax(y_predict, axis=1)
        # Map the predicted index to a class name; flow_from_directory assigns
        # indices to the class subfolders in sorted order
        class_names = sorted(d for d in os.listdir(self.train_dir)
                             if os.path.isdir(os.path.join(self.train_dir, d)))
        label_dict = dict(enumerate(class_names))
        print(label_dict[int(index[0])])
Running the prediction prints the class label predicted for the image.
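The last step of the prediction, turning the softmax output into a class name, can be sketched without TensorFlow. The class names and probabilities below are placeholders for illustration; in practice the name-to-index mapping would come from the `class_indices` attribute of the training iterator.

```python
def argmax(values):
    """Index of the largest value (the first one in case of ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def predicted_label(probabilities, class_indices):
    """Invert the name-to-index mapping and look up the most probable class."""
    index_to_label = {index: name for name, index in class_indices.items()}
    return index_to_label[argmax(probabilities)]


# Placeholder mapping and a made-up softmax output for a single image.
demo_class_indices = {
    "class_a": 0, "class_b": 1, "class_c": 2, "class_d": 3, "class_e": 4,
}
demo_probabilities = [0.05, 0.80, 0.05, 0.05, 0.05]
demo_label = predicted_label(demo_probabilities, demo_class_indices)
print(demo_label)
```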
Copyright notice
This article was written by [Stephen_ Tao]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230548468317.html