Transfer learning for five-class image classification based on VGG
2022-04-23 17:54:00 【Stephen_ Tao】
Introduction to the dataset
The dataset is divided into a training set and a test set. Each part contains five categories of images: cars, dinosaurs, elephants, flowers, and horses.
Code implementation
The implementation is divided into the following five steps:
- Read local image data and categories
- Modify the VGG model structure (add a custom classification head)
- Freeze the original VGG model
- Compile, train, and save the model
- Make predictions with the trained model
Importing the required packages
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint
Reading local image data and categories
class TransferModel(object):

    def __init__(self):
        # Directories containing one sub-directory per class
        self.train_dir = './data/train'
        self.test_dir = './data/test'
        # Target image size and batch size
        self.model_size = (224, 224)
        self.batch_size = 32
        # Generators that rescale pixel values to [0, 1]
        self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
        # VGG16 convolutional base without the fully connected top
        self.base_model = VGG16(include_top=False)

    def get_local_data(self):
        """
        Read local image data and categories.
        :return: iterators for the training data and the test data
        """
        # class_mode='sparse' yields integer class indices, which is what
        # sparse_categorical_crossentropy expects for a five-class problem
        train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
                                                             target_size=self.model_size,
                                                             batch_size=self.batch_size,
                                                             class_mode='sparse',
                                                             shuffle=True)
        test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
                                                           target_size=self.model_size,
                                                           batch_size=self.batch_size,
                                                           class_mode='sparse',
                                                           shuffle=True)
        return train_gen, test_gen
Calling get_local_data and printing train_gen and test_gen gives the following output:
Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>
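The "5 classes" in this output come from the sub-directory names under each data directory: flow_from_directory treats every sub-directory as one class and numbers the classes in sorted name order. The following standard-library sketch mimics that numbering on a temporary directory tree. Only bus appears in the article (in the prediction path); the other four folder names are hypothetical placeholders, and the helper name is illustrative.

```python
import os
import tempfile

def class_indices_from_directory(directory):
    """Map each class sub-directory to an integer index, in sorted name order."""
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(class_names)}

# Hypothetical folder names; only 'bus' is taken from the article.
EXAMPLE_CLASSES = ['horse', 'bus', 'flowers', 'dinosaurs', 'elephants']

with tempfile.TemporaryDirectory() as root:
    # Build an empty directory tree with one folder per class.
    for name in EXAMPLE_CLASSES:
        os.makedirs(os.path.join(root, name))
    class_indices = class_indices_from_directory(root)

print(class_indices)
```

On the real generators, the same mapping is available as train_gen.class_indices.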
Modifying the VGG model structure
    def refine_vgg_model(self):
        # Output of the VGG16 convolutional base
        x = self.base_model.outputs[0]
        # Use GlobalAveragePooling2D to reduce the number of model parameters
        x = layers.GlobalAveragePooling2D()(x)
        # Custom classification head: one hidden layer plus a 5-way softmax
        x = layers.Dense(1024, activation=tf.nn.relu)(x)
        y_predict = layers.Dense(5, activation=tf.nn.softmax)(x)
        model = keras.Model(inputs=self.base_model.inputs, outputs=y_predict)
        return model
First, GlobalAveragePooling2D is applied to the output of the VGG model without its top, which reduces the number of parameters in the fully connected layers that follow. Two custom fully connected layers are then added on top to obtain the new model.
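To make the parameter saving concrete, the sketch below compares the parameter count of the first Dense(1024) layer when it is fed by a flattened feature map and when it is fed by a globally average-pooled one. It assumes the 7 x 7 x 512 feature map that VGG16 without its top produces for a 224 x 224 input; the helper names are illustrative, not part of the article's code.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units

def global_average_pool(feature_map):
    """Average an H x W x C nested list over H and W, leaving C values."""
    height, width = len(feature_map), len(feature_map[0])
    channels = len(feature_map[0][0])
    return [
        sum(feature_map[i][j][c] for i in range(height) for j in range(width))
        / (height * width)
        for c in range(channels)
    ]

# Assumed VGG16 (include_top=False) output shape for a 224 x 224 input.
H, W, C = 7, 7, 512

flatten_head = dense_params(H * W * C, 1024)   # Dense(1024) after Flatten
gap_head = dense_params(C, 1024)               # Dense(1024) after pooling
print(flatten_head, gap_head)                  # 25691136 525312

# Tiny 2 x 2 x 2 example: each channel collapses to its spatial mean.
tiny = [[[1.0, 10.0], [2.0, 20.0]],
        [[3.0, 30.0], [4.0, 40.0]]]
print(global_average_pool(tiny))               # [2.5, 25.0]
```

Under these assumptions, pooling shrinks the first dense layer by a factor of roughly 49, the number of spatial positions that are averaged away.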
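Freezing the original VGG model parameters
    def freeze_vgg_model(self):
        # Mark every layer of the VGG base as non-trainable so that
        # only the newly added dense layers are updated during training
        for layer in self.base_model.layers:
            layer.trainable = False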
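As a rough check of what freezing buys, the sketch below counts parameters by hand. It assumes the standard VGG16 convolutional configuration (thirteen 3 x 3 convolutions) and the two Dense layers defined in refine_vgg_model; the variable and helper names are illustrative.

```python
# (input channels, output channels) of the 13 conv layers in standard VGG16.
VGG16_CONVS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]

def conv3x3_params(c_in, c_out):
    """3 x 3 kernel weights plus one bias per output channel."""
    return 3 * 3 * c_in * c_out + c_out

# Parameters in the frozen convolutional base.
frozen = sum(conv3x3_params(c_in, c_out) for c_in, c_out in VGG16_CONVS)

# Head from refine_vgg_model: pooling (no parameters) -> Dense(1024) -> Dense(5).
trainable = (512 * 1024 + 1024) + (1024 * 5 + 5)

print(frozen)      # 14714688
print(trainable)   # 530437
print(round(100 * trainable / (frozen + trainable), 1))  # 3.5 (percent trainable)
```

The trainable flag takes effect when the model is compiled, which is why the main function calls freeze_vgg_model before compile.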
Compiling, training, and saving the model
    def compile(self, model):
        model.compile(optimizer=adam_v2.Adam(),
                      loss=sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self, model, train_gen, test_gen):
        # Save weights only, and only when val_accuracy improves
        check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
                                monitor='val_accuracy',
                                save_best_only=True,
                                save_weights_only=True,
                                mode='auto',
                                save_freq='epoch')  # replaces the deprecated period=1
        # model.fit accepts generators directly; fit_generator is deprecated
        model.fit(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])
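The checkpoint path is a format template: at the end of an epoch, ModelCheckpoint fills in the 1-based epoch number and the logged metrics. The sketch below reproduces that formatting with plain str.format; the helper name and the metric values are made up for illustration.

```python
CKPT_TEMPLATE = './ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5'

def checkpoint_path(template, epoch, logs):
    """Fill the template with a 1-based epoch number and the logged metrics."""
    return template.format(epoch=epoch, **logs)

# Hypothetical logs for the second epoch; keys the template does not use are ignored.
example_logs = {'loss': 0.21, 'accuracy': 0.95, 'val_loss': 0.30, 'val_accuracy': 0.93}

path = checkpoint_path(CKPT_TEMPLATE, 2, example_logs)
print(path)  # ./ckpt/transfer_02-0.93.h5
```

With save_best_only=True, a file is written only for epochs in which val_accuracy improves, so not every epoch number necessarily appears in ./ckpt.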
The main function
if __name__ == '__main__':
    tm = TransferModel()
    train_gen, test_gen = tm.get_local_data()
    model = tm.refine_vgg_model()
    # print(tm.refine_vgg_model().summary())
    tm.freeze_vgg_model()
    tm.compile(model)
    tm.fit(model, train_gen, test_gen)
After training, the checkpoint weight files (for example transfer_02-0.93.h5, which the prediction step loads) are saved in the ./ckpt directory.
Making predictions with the model
    def predict(self, model):
        # Load the weights saved during training
        model.load_weights('./ckpt/transfer_02-0.93.h5')
        # 1. Load the image at the size used for training
        image = load_img('./data/test/bus/300.jpg', target_size=(224, 224))
        image = img_to_array(image)
        # 2. Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
        img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
        # 3. Apply the same preprocessing as the training generator
        #    (rescale to [0, 1]); VGG16's preprocess_input would not match it
        img = img / 255.0
        print(img.shape)
        y_predict = model.predict(img)
        index = np.argmax(y_predict, axis=1)
        # flow_from_directory numbers classes by sorted sub-directory name,
        # so the label names can be recovered from the training directory
        import os
        class_names = sorted(d for d in os.listdir(self.train_dir)
                             if os.path.isdir(os.path.join(self.train_dir, d)))
        print(class_names[index[0]])
Running the prediction prints the shape of the input batch, (1, 224, 224, 3), followed by the predicted class label.
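The last two steps of predict turn the softmax output into a label: take the index of the largest probability, then look up the class name. A dependency-free sketch follows, in which the probability values and four of the five class names are hypothetical (only bus appears in the article):

```python
def argmax(values):
    """Index of the largest element (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

# Hypothetical index -> class-name mapping, in sorted folder order.
LABELS = ['bus', 'dinosaurs', 'elephants', 'flowers', 'horse']

# Hypothetical softmax output for a batch containing one image.
y_predict = [[0.91, 0.02, 0.03, 0.01, 0.03]]

index = argmax(y_predict[0])
print(index, LABELS[index])  # 0 bus
```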
Copyright notice
This article was written by [Stephen_ Tao]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230548468317.html