当前位置:网站首页>基于VGG对五种类别图片的迁移学习
基于VGG对五种类别图片的迁移学习
2022-04-23 05:50:00 【Stephen_Tao】
数据集的介绍
分为训练集和测试集两个部分,每个部分都包含5个类别的数据,分别为汽车、恐龙、大象、花以及马。
代码实现
主要分为以下的五个步骤:
- 读取本地的图片数据及类别
- VGG模型结构的修改(添加自定义的分类层)
- freeze掉原始的VGG模型
- 编译、训练并保存模型
相关包的导入
import numpy as np
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input
from tensorflow.python.keras import layers
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.callbacks import ModelCheckpoint
读取本地的图片数据及类别
class TransferModel(object):
def __init__(self):
self.train_dir = './data/train'
self.test_dir = './data/test'
self.model_size = (224,224)
self.batch_size = 32
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
self.base_model = VGG16(include_top=False)
def get_local_data(self):
""" 读取本地的图片数据以及类别 :return:训练数据和测试数据的迭代器 """
train_gen = self.train_generator.flow_from_directory(directory=self.train_dir,
target_size=self.model_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(directory=self.test_dir,
target_size=self.model_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen,test_gen
对train_gen以及test_gen进行打印,可以得到以下的结果:
Found 400 images belonging to 5 classes.
Found 100 images belonging to 5 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65E80>
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x000001AB1BD65520>
VGG模型结构的修改
def refine_vgg_model(self):
x = self.base_model.outputs[0]
# 采用GlobalAveragePooling2D减少模型的参数
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(1024,activation=tf.nn.relu)(x)
y_predict = layers.Dense(5,activation=tf.nn.softmax)(x)
model = keras.Model(inputs=self.base_model.inputs,outputs=y_predict)
return model
先是对VGG_nontop模型的输出进行GlobalAveragePooling2D减少全连接的参数,然后自定义构建两个全连接层,获得新的模型。
freeze掉原始的VGG模型参数
def freeze_vgg_model(self):
for layer in self.base_model.layers:
layer.trainable = False
编译、训练并保存模型
def compile(self,model):
model.compile(optimizer=adam_v2.Adam(),
loss=sparse_categorical_crossentropy,
metrics=['accuracy'])
def fit(self,model,train_gen,test_gen):
check = ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_accuracy:.2f}.h5',
monitor='val_accuracy',
save_best_only=True,
save_weights_only=True,
mode='auto',
period=1)
model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[check])
主函数
if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_local_data()
model = tm.refine_vgg_model()
# print(tm.refine_vgg_model().summary())
tm.freeze_vgg_model()
tm.compile(model)
tm.fit(model,train_gen,test_gen)
训练结束后将得到以下文件:
模型预测
def predict(self,model):
model.load_weights('./ckpt/transfer_02-0.93.h5')
image = load_img('./data/test/bus/300.jpg',target_size=(224,224))
# print(image)
image = img_to_array(image)
# print("图片的形状:", image.shape)
# 形状从3维度修改成4维
img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
# print("改变形状结果:", img.shape)
# 3、处理图像内容,归一化处理等,进行预测
img = preprocess_input(img)
print(img.shape)
y_predict = model.predict(img)
index = np.argmax(y_predict, axis=1)
#
print(self.label_dict[str(index[0])])
预测结果如下:
版权声明
本文为[Stephen_Tao]所创,转载请带上原文链接,感谢
https://blog.csdn.net/professor_tao/article/details/119976841
边栏推荐
- [UDS unified diagnostic service] IV. typical diagnostic service (2) - data transmission function unit
- Protection of shared data
- 搭建openstack平台
- Robocode教程8——AdvancedRobot
- Easy to use data set and open source network comparison website
- C#中?的这种形式
- 爬西瓜视频url
- LaTeX配置与使用
- Tabbar implementation of dynamic bottom navigation bar in uniapp, authority management
- sqlite3加密版
猜你喜欢
随机推荐
函数的调用过程
在MFC中使用printf
类和对象
[learn] HF net training
Call procedure of function
Installation of GCC, G + +, GDB
猜數字遊戲
【UDS统一诊断服务】(补充)五、ECU bootloader开发要点详解 (2)
【UDS统一诊断服务】一、诊断概述(1)— 诊断概述
CUDA环境安装
for()循环参数调用顺序
搭建openstack平台
vs中能编译通过,但是会有红色下划线提示未定义标示符问题
类和对象的初始化(构造函数与析构函数)
Flask操作多个数据库
grub boot. S code analysis
Camera calibration: key point method vs direct method
爬取蝉妈妈数据平台商品数据
爬取小米有品app商品数据
C语言进阶要点笔记5