当前位置:网站首页>用tensorflow.keras模块化搭建神经网络模型
用tensorflow.keras模块化搭建神经网络模型
2022-08-09 06:56:00 【Anakin6174】
资料来源:北京大学 曹建教授的课程 人工智能实践:TensorFlow笔记
使用八股搭建神经网络
其中第三步使用Sequential只能搭建简易的全连接模型,如果是有跳转的卷积网络或者其他复杂设计的网络需要自己创建一个类来设计;


利用鸢尾花数据集来搭建网络举例:
# 用sequential或自己搭建model类
import tensorflow as tf
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
# ******************
# 可以自己搭建模型类,效果一样
# class IrisModel(Model):
# def __init__(self):
# super(IrisModel, self).__init__()
# self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
#
# def call(self, x):
# y = self.d1(x)
# return y
#
# model = IrisModel()
# ******************
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
使用mnist数据集搭建神经网络
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class MnistModel(Model):
def __init__(self):
super(MnistModel, self).__init__()
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.flatten(x)
x = self.d1(x)
y = self.d2(x)
return y
model = MnistModel()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
边栏推荐
- P7阿里面试题2020.07 之滑动窗算法(阿里云面试)
- 【报错】Root Cause com.mysql.jdbc.exceptions.jdbc4.CommunicationsException: Communications link failure
- Built-in macros in C language (define log macros)
- String.toLowerCase(Locale.ROOT)
- 细谈VR全景:数字营销时代的宠儿
- 【Docker】Docker安装MySQL
- db.sqlite3没有“as Data Source“解决方法
- 排序第一节——插入排序(直接插入排序+希尔排序)(视频讲解26分钟)
- 找不到和chrome浏览器版本不同的chromedriver的解决方法
- C语言的内置宏(定义日志宏)
猜你喜欢
随机推荐
e-learning summary
【烂笔头】各厂商手机手动抓log
P6阿里机试题之2020 斐波那契数
Explain the wait() function and waitpid() function in C language in detail
物理层课后作业
简单工厂模式
shardingsphere数据分片配置项说明和示例
当酷雷曼VR直播遇上视频号,会摩擦出怎样的火花?
移远EC20 4G模块拨号相关
The JVM thread state
longest substring without repeating characters
报错jinja2.exceptions.UndefinedError: ‘form‘ is undefined
常用测试用例设计方法之正交实验法详解
C语言的内置宏(定义日志宏)
多米诺骨牌
AD picture PCB tutorial 20 minutes clear label shop operation process, copper network
The AD in the library of library file suffix. Intlib. Schlib. Pcblib difference
01 自然语言处理NLP介绍
idea中PlantUML插件使用
买口罩(0-1背包)







