当前位置:网站首页>用tensorflow.keras模块化搭建神经网络模型
用tensorflow.keras模块化搭建神经网络模型
2022-08-09 06:56:00 【Anakin6174】
资料来源:北京大学 曹建教授的课程 人工智能实践:TensorFlow笔记
使用八股搭建神经网络
其中第三步使用Sequential只能搭建简易的全连接模型,如果是有跳转的卷积网络或者其他复杂设计的网络需要自己创建一个类来设计;
利用鸢尾花数据集来搭建网络举例:
# 用sequential或自己搭建model类
import tensorflow as tf
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
# ******************
# 可以自己搭建模型类,效果一样
# class IrisModel(Model):
# def __init__(self):
# super(IrisModel, self).__init__()
# self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
#
# def call(self, x):
# y = self.d1(x)
# return y
#
# model = IrisModel()
# ******************
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
使用mnist数据集搭建神经网络
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class MnistModel(Model):
def __init__(self):
super(MnistModel, self).__init__()
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.flatten(x)
x = self.d1(x)
y = self.d2(x)
return y
model = MnistModel()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
边栏推荐
- io.lettuce.core。RedisCommandTimeoutException命令超时
- 字节跳动面试题之镜像二叉树2020
- ByteDance Interview Questions: Mirror Binary Tree 2020
- The Integer thread safe
- 字节也开始缩招了...
- 排序第二节——选择排序(选择排序+堆排序)(两个视频讲解)
- RK3568商显版开源鸿蒙板卡产品解决方案
- P7 Alibaba Interview Questions 2020.07 Sliding Window Algorithm (Alibaba Cloud Interview)
- The working principle of the transformer (illustration, schematic explanation, understand at a glance)
- Use baidu EasyDL intelligent bin
猜你喜欢
C language implements sequential stack and chain queue
报错:FSADeprecationWarning: SQLALCHEMY_TRACK_MODIFICATIONS重大开销和将disab补充道
ByteDance Written Exam 2020 (Douyin E-commerce)
Use of PlantUML plugin in idea
高项 04 项目整体管理
db.sqlite3没有“as Data Source“解决方法
Error jinja2.exceptions.UndefinedError: 'form' is undefined
Distributed id generator implementation
makefile记录
分布式id 生成器实现
随机推荐
The solution that does not work and does not take effect after VScode installs ESlint
线程的6种状态
【ROS2原理8】节点到参与者的重映射
composer 内存不足够
分布式理论
Use of PlantUML plugin in idea
idea中PlantUML插件使用
排序第四节——归并排序(附有自己的视频讲解)
MVN 中配置flyway mysq
报错:FSADeprecationWarning: SQLALCHEMY_TRACK_MODIFICATIONS重大开销和将disab补充道
长沙学院2022暑假训练赛(一)六级阅读
顺序表删除所有值为e的元素
C语言的内置宏(定义日志宏)
failed (13: Permission denied) while connecting to upstream
高项 04 项目整体管理
6 states of a thread
Distributed id generator implementation
我入职阿里后,才知道原来简历这么写
The AD in the library of library file suffix. Intlib. Schlib. Pcblib difference
Mysql实操