Implementing a ResNet-34 CNN with Keras
2022-04-23 05:12:00 【看到我你要笑一下】
Let's start implementing ResNet-34.
1. Create a ResidualUnit layer.
from functools import partial

import tensorflow as tf
from tensorflow import keras

# Convolution used throughout ResNet: 3x3 kernel, stride 1, "same" padding,
# and no bias (the bias is redundant right before batch normalization).
DefaultConv2D = partial(keras.layers.Conv2D, kernel_size=3, strides=1,
                        padding="SAME", use_bias=False)

class ResidualUnit(keras.layers.Layer):
    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = keras.activations.get(activation)
        # Main path: Conv -> BN -> activation -> Conv -> BN
        self.main_layers = [
            DefaultConv2D(filters, strides=strides),
            keras.layers.BatchNormalization(),
            self.activation,
            DefaultConv2D(filters),
            keras.layers.BatchNormalization()]
        # Skip path: identity, unless the shape changes (strides > 1)
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [
                DefaultConv2D(filters, kernel_size=1, strides=strides),
                keras.layers.BatchNormalization()]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)
As you can see, in the constructor we create all the layers we need:
the main layers and the skip layers (only needed when the stride is greater than 1).
In the call() method, we pass the inputs through the main layers and through the skip layers (if any), then add both outputs and apply the activation function.
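As a quick sanity check (not part of the original post), a ResidualUnit with strides=1 keeps the input shape, while strides=2 halves the spatial dimensions and switches to the new filter count. A minimal sketch, assuming the ResidualUnit class defined above:

# Shape check for the ResidualUnit layer defined above (illustrative only)
x = tf.random.normal([1, 56, 56, 64])      # dummy feature map
ru_same = ResidualUnit(64, strides=1)      # identity skip connection
ru_down = ResidualUnit(128, strides=2)     # skip path with a 1x1 conv
print(ru_same(x).shape)                    # (1, 56, 56, 64)
print(ru_down(x).shape)                    # (1, 28, 28, 128)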
2. Build ResNet-34 with a Sequential model.
The network is really just a very long sequence of layers.
Now that we have the ResidualUnit class above, we can treat each residual unit as a single layer.
model = keras.models.Sequential()
# Stem: 7x7 conv with stride 2, then BN, ReLU, and 3x3 max pooling (stride 2)
model.add(DefaultConv2D(64, kernel_size=7, strides=2,
                        input_shape=[224, 224, 3]))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Activation("relu"))
model.add(keras.layers.MaxPool2D(pool_size=3, strides=2, padding="SAME"))
# 16 residual units: 3 with 64 filters, 4 with 128, 6 with 256, 3 with 512
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
    strides = 1 if filters == prev_filters else 2
    model.add(ResidualUnit(filters, strides=strides))
    prev_filters = filters
model.add(keras.layers.GlobalAvgPool2D())
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10, activation="softmax"))
The loop is what adds the ResidualUnit layers to the model:
the first 3 RUs (residual units) have 64 filters, the next 4 have 128, and so on.
When the number of filters is the same as in the previous RU, the stride is set to 1; otherwise it is set to 2.
Then we add the ResidualUnit, and finally we update prev_filters.
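To make the stride rule concrete, the following standalone snippet (added here for illustration, not from the original post) prints the (filters, strides) pair chosen for each of the 16 residual units; only the units where the filter count doubles get a stride of 2:

# Print the (filters, strides) schedule produced by the loop above
prev_filters = 64
for i, filters in enumerate([64] * 3 + [128] * 4 + [256] * 6 + [512] * 3):
    strides = 1 if filters == prev_filters else 2
    print(f"RU {i:2d}: filters={filters:3d}, strides={strides}")
    prev_filters = filters
# Output: RUs 3, 7, and 13 get strides=2; all the others get strides=1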
3. Don't forget to check the model summary.
model.summary()
Output:
Model: "sequential"
_________________________________________________________________
 Layer (type)                                Output Shape           Param #
=================================================================
 conv2d (Conv2D)                             (None, 112, 112, 64)   9408
 batch_normalization (BatchNormalization)    (None, 112, 112, 64)   256
 activation (Activation)                     (None, 112, 112, 64)   0
 max_pooling2d (MaxPooling2D)                (None, 56, 56, 64)     0
 residual_unit (ResidualUnit)                (None, 56, 56, 64)     74240
 residual_unit_1 (ResidualUnit)              (None, 56, 56, 64)     74240
 residual_unit_2 (ResidualUnit)              (None, 56, 56, 64)     74240
 residual_unit_3 (ResidualUnit)              (None, 28, 28, 128)    230912
 residual_unit_4 (ResidualUnit)              (None, 28, 28, 128)    295936
 residual_unit_5 (ResidualUnit)              (None, 28, 28, 128)    295936
 residual_unit_6 (ResidualUnit)              (None, 28, 28, 128)    295936
 residual_unit_7 (ResidualUnit)              (None, 14, 14, 256)    920576
 residual_unit_8 (ResidualUnit)              (None, 14, 14, 256)    1181696
 residual_unit_9 (ResidualUnit)              (None, 14, 14, 256)    1181696
 residual_unit_10 (ResidualUnit)             (None, 14, 14, 256)    1181696
 residual_unit_11 (ResidualUnit)             (None, 14, 14, 256)    1181696
 residual_unit_12 (ResidualUnit)             (None, 14, 14, 256)    1181696
 residual_unit_13 (ResidualUnit)             (None, 7, 7, 512)      3676160
 residual_unit_14 (ResidualUnit)             (None, 7, 7, 512)      4722688
 residual_unit_15 (ResidualUnit)             (None, 7, 7, 512)      4722688
 global_average_pooling2d (GlobalAveragePooling2D)  (None, 512)     0
 flatten (Flatten)                           (None, 512)            0
 dense (Dense)                               (None, 10)             5130
=================================================================
Total params: 21,306,826
Trainable params: 21,289,802
Non-trainable params: 17,024
_________________________________________________________________
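The parameter counts in this summary can be checked by hand; for example (a verification added here for illustration), the first convolution and its batch normalization layer:

# Hand-check two rows of the summary above (illustrative only)
conv1_params = 7 * 7 * 3 * 64   # 7x7 kernel, 3 input channels, 64 filters, no bias
print(conv1_params)             # 9408 -> matches the conv2d row
bn1_params = 4 * 64             # gamma, beta, moving mean, moving variance per channel
print(bn1_params)               # 256 -> matches; only gamma and beta (128) are trainable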
Reference:
Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition, by Aurélien Géron, published by O'Reilly, ISBN 978-1-492-03264-9.
Copyright notice
This article was written by [看到我你要笑一下]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/qq_51153436/article/details/124227878