当前位置:网站首页>用tensorflow.keras模块化搭建神经网络模型
用tensorflow.keras模块化搭建神经网络模型
2022-08-09 06:56:00 【Anakin6174】
资料来源:北京大学 曹建教授的课程 人工智能实践:TensorFlow笔记
使用八股搭建神经网络
其中第三步使用Sequential只能搭建简易的全连接模型,如果是有跳转的卷积网络或者其他复杂设计的网络需要自己创建一个类来设计;
利用鸢尾花数据集来搭建网络举例:
# 用sequential或自己搭建model类
import tensorflow as tf
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
# ******************
# 可以自己搭建模型类,效果一样
# class IrisModel(Model):
# def __init__(self):
# super(IrisModel, self).__init__()
# self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
#
# def call(self, x):
# y = self.d1(x)
# return y
#
# model = IrisModel()
# ******************
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
使用mnist数据集搭建神经网络
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class MnistModel(Model):
def __init__(self):
super(MnistModel, self).__init__()
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.flatten(x)
x = self.d1(x)
y = self.d2(x)
return y
model = MnistModel()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
边栏推荐
猜你喜欢
随机推荐
当酷雷曼VR直播遇上视频号,会摩擦出怎样的火花?
SIGINT, SIGKILL, SIGTERM signal difference, summary of various signals
e-learning summary
长沙学院2022暑假训练赛(一)六级阅读
集合内之部原理总结
imageio读取.exr报错 ValueError: Could not find a backend to open `xxx.exr‘ with iomode `r`
【sqlite3】sqlite3.OperationalError: table addresses has 7 columns but 6 values were supplied
虚拟机网卡报错:Bringing up interface eth0: Error: No suitable device found: no device found for connection
The Integer thread safe
Introduction and use of BeautifulSoup4
shardingsphere data sharding configuration item description and example
stm32定时器之简单封装
【ROS2原理8】节点到参与者的重映射
默默重新开始,第一页也是新的一页
Use of PlantUML plugin in idea
RK3568商显版开源鸿蒙板卡产品解决方案
分布式事务的应用场景
Altium designer software commonly used the most complete package library, including schematic library, PCB library and 3D model library
【修电脑】系统重装但IP不变后VScode Remote SSH连接失败解决
Import the pycharm environment package into another environment