当前位置:网站首页>TensorFlow—计算梯度与控制梯度 : tf.gradients和compute_gradients和apply_gradients和clip_by_global_norm控制梯度
TensorFlow—计算梯度与控制梯度 : tf.gradients和compute_gradients和apply_gradients和clip_by_global_norm控制梯度
2022-08-09 10:42:00 【模糊包】
TensorFlow的梯度
我们知道训练神经网络有一个很重要的就是反向传播更新参数,如果没有经历过2015-2017年的神经网络的研究生,这一步听陌生的,但是不重要,我们知道TensorFlow
给我们API
怎么用就行了。
对于反向传播这一步,我们常见的代码是如下:
# 损失计算,也就是优化对象
loss = tf.nn..............
# 反向传播
# 定义优化器,学习率定义1.0
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
# 定义优化对象,对象loss,计算当前step,并且计算一次就+1
optimizer.minimize(loss ,global_step=global_step)
1. 计算与更新(梯度)
其中minimize
包含了两个操作:
- 获得变量的梯度。
- 用梯度更新变量
其中主要计算过程是:
def minimize(self, loss, global_step=None, var_list=None, name=None):
grads_and_vars = self.compute_gradients(loss, var_list=var_list)
..............................
.......中间这些不重要...........
..............................
return self.apply_gradients(grads_and_vars,global_step=global_step, name=name)
这里有两个重要的代码段落,一定要弄清楚。
1. grads_and_vars = self.compute_gradients(loss, var_list=var_list)
这一段和名字意思一样 : 将梯度和变量打包
可以知道做了两件事:1.计算梯度。2.打包梯度和变量。
参数解释:loss
就是我们上面说到的损失值,其实就是我们计算得到的loss;var_list
是我们要计算梯度的变量,其实就是计算图中所有的参数。
该函数等价于下面的代码下面常用于个人写代码:
# 等同于上面的var_list,其实就是所有的变量
trainable_variables = tf.trainable_variables()
# 获取损失值对各个变量的偏导数之和,即梯度,大小len(input_data)
grads = tf.gradients(cost/tf.to_float(batch_size), trainable_variables)
# 这里一般约束梯度,避免梯度爆炸,也可以不写。
grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)
# 打包梯度和变量的,结构是元组(grads, 参数值)
'''这个可以看下面例子'''
grads_and_vars = zip(grads, trainable_variables)
而上述的代码,其实就是我们常用写法,因为这样拆解出来更灵活。关于tf.gradients()
可以看这个文章关于tf.gradients
看不明白compute_gradients()
,可以举个例子如下:
x1 = tf.Variable(initial_value=[2.,3.], dtype='float32')
w = tf.Variable(initial_value=[[3.,4.],[1.,2.],[2.5,4.1]], dtype='float32')
b = tf.Variable(initial_value=1., dtype='float32')
# 这里简化了wx+b,不然定义w会好麻烦
y = x1*w+b
opt = tf.train.GradientDescentOptimizer(0.1)
grad = opt.compute_gradients(y, [w])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(grad))
# 输出
[(array([[2., 3.],
[2., 3.],
[2., 3.]], dtype=float32), array([[3. , 4. ],
[1. , 2. ],
[2.5, 4.1]], dtype=float32))]
''' grads ---> array([[2., 3.], [2., 3.], [2., 3.]], dtype=float32) 大小len(w) varias---> array([[3. , 4. ], [1. , 2. ], [2.5, 4.1]], dtype=float32) 变量w '''
这个例子,大家可以试试对x1求梯度大小,结果是不一样的
1.2 apply_gradients(grads_and_vars,global_step, name)
这个函数主要是应用梯度下降到变量上,用于更新变量
将compute_gradients()
返回的值作为输入参数对variable
进行更新
参数解释:grads_and_vars
这就是compute_gradients(loss, var_list)
计算得到梯度,代表了各个变量的偏导数 ;global_step就是从全局中获得的step
; name
一般是minimize
优化时候定义的名字,默认是None
2.约束梯度大小防止爆炸
tf.clip_by_value和tf.clip_by_norm和clip_by_global_norm都可以。不做累述。
边栏推荐
- BERT预训练模型(Bidirectional Encoder Representations from Transformers)-原理详解
- PoseNet: A Convolutional Network for Real-Time 6-DOF Camera Relocalization论文阅读
- ESIM(Enhanced Sequential Inference Model)- 模型详解
- 2022年台湾省矢量数据(点线面)及数字高程数据下载
- UNIX Environment Programming Chapter 15 15.5FIFO
- numpy的ndarray取数操作
- 研发需求的验收标准应该怎么写? | 敏捷实践
- OpenGL ES2.0编程三部曲(转载自MyArrow)
- antd表单
- 基于STM32设计的环境检测设备
猜你喜欢
程序员的专属浪漫——用3D Engine 5分钟实现烟花绽放效果
StratoVirt 中的虚拟网卡是如何实现的?
Attentional Feature Fusion
ESIM(Enhanced Sequential Inference Model)- 模型详解
Cpolar内网穿透的面板功能介绍
【 original 】 VMware Workstation implementation Openwrt soft routing, the ESXI, content is very detailed!
分类预测 | MATLAB实现CNN-GRU(卷积门控循环单元)多特征分类预测
Dialogue with the DPO of a multinational consumer brand: How to start with data security compliance?See you on 8.11 Live!
上传张最近做的E2用的xmms的界面的截图
在webgis中显示矢量化后的风险防控信息
随机推荐
How tall is the B+ tree of the MySQL index?
snmp++编译错误问题解决方法
Unix Environment Programming Chapter 15 15.9 Shared Storage
1005 继续(3n+1)猜想 (25 分)
可能95%的人还在犯的PyTorch错误
Win7 远程桌面限制IP
壁纸
力扣(LeetCode)220. 存在重复元素 III(2022.08.08)
Probably 95% of the people are still making PyTorch mistakes
在犹豫中度过了老多天,今天的工作时记录
The GNU Privacy Guard
好久没更新博客了
985毕业,工作3年,分享从阿里辞职到了国企的一路辛酸和经验
强化学习 (Reinforcement Learning)
TELNET协议相关RFC
BERT预训练模型(Bidirectional Encoder Representations from Transformers)-原理详解
通过Doc在MySQL数据库中建表
xmms已经发布到v1.3了,好久没写博客了
百度云大文件网页直接下载
antd表单