当前位置:网站首页>TensorFlow—计算梯度与控制梯度 : tf.gradients和compute_gradients和apply_gradients和clip_by_global_norm控制梯度
TensorFlow—计算梯度与控制梯度 : tf.gradients和compute_gradients和apply_gradients和clip_by_global_norm控制梯度
2022-08-09 10:42:00 【模糊包】
TensorFlow的梯度
我们知道训练神经网络有一个很重要的就是反向传播更新参数,如果没有经历过2015-2017年的神经网络的研究生,这一步听陌生的,但是不重要,我们知道TensorFlow给我们API怎么用就行了。
对于反向传播这一步,我们常见的代码是如下:
# 损失计算,也就是优化对象
loss = tf.nn..............
# 反向传播
# 定义优化器,学习率定义1.0
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
# 定义优化对象,对象loss,计算当前step,并且计算一次就+1
optimizer.minimize(loss ,global_step=global_step)
1. 计算与更新(梯度)
其中minimize包含了两个操作:
- 获得变量的梯度。
- 用梯度更新变量
其中主要计算过程是:
def minimize(self, loss, global_step=None, var_list=None, name=None):
grads_and_vars = self.compute_gradients(loss, var_list=var_list)
..............................
.......中间这些不重要...........
..............................
return self.apply_gradients(grads_and_vars,global_step=global_step, name=name)
这里有两个重要的代码段落,一定要弄清楚。
1. grads_and_vars = self.compute_gradients(loss, var_list=var_list)
这一段和名字意思一样 : 将梯度和变量打包
可以知道做了两件事:1.计算梯度。2.打包梯度和变量。
参数解释:loss就是我们上面说到的损失值,其实就是我们计算得到的loss;var_list是我们要计算梯度的变量,其实就是计算图中所有的参数。
该函数等价于下面的代码下面常用于个人写代码:
# 等同于上面的var_list,其实就是所有的变量
trainable_variables = tf.trainable_variables()
# 获取损失值对各个变量的偏导数之和,即梯度,大小len(input_data)
grads = tf.gradients(cost/tf.to_float(batch_size), trainable_variables)
# 这里一般约束梯度,避免梯度爆炸,也可以不写。
grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)
# 打包梯度和变量的,结构是元组(grads, 参数值)
'''这个可以看下面例子'''
grads_and_vars = zip(grads, trainable_variables)
而上述的代码,其实就是我们常用写法,因为这样拆解出来更灵活。关于tf.gradients()可以看这个文章关于tf.gradients
看不明白compute_gradients(),可以举个例子如下:
x1 = tf.Variable(initial_value=[2.,3.], dtype='float32')
w = tf.Variable(initial_value=[[3.,4.],[1.,2.],[2.5,4.1]], dtype='float32')
b = tf.Variable(initial_value=1., dtype='float32')
# 这里简化了wx+b,不然定义w会好麻烦
y = x1*w+b
opt = tf.train.GradientDescentOptimizer(0.1)
grad = opt.compute_gradients(y, [w])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(grad))
# 输出
[(array([[2., 3.],
[2., 3.],
[2., 3.]], dtype=float32), array([[3. , 4. ],
[1. , 2. ],
[2.5, 4.1]], dtype=float32))]
''' grads ---> array([[2., 3.], [2., 3.], [2., 3.]], dtype=float32) 大小len(w) varias---> array([[3. , 4. ], [1. , 2. ], [2.5, 4.1]], dtype=float32) 变量w '''
这个例子,大家可以试试对x1求梯度大小,结果是不一样的
1.2 apply_gradients(grads_and_vars,global_step, name)
这个函数主要是应用梯度下降到变量上,用于更新变量
将compute_gradients()返回的值作为输入参数对variable进行更新
参数解释:grads_and_vars这就是compute_gradients(loss, var_list)计算得到梯度,代表了各个变量的偏导数 ;global_step就是从全局中获得的step ; name一般是minimize优化时候定义的名字,默认是None
2.约束梯度大小防止爆炸
tf.clip_by_value和tf.clip_by_norm和clip_by_global_norm都可以。不做累述。
边栏推荐
- unix环境编程 第十四章 14.8 存储映射I/O
- 小程序员的发展计划
- unix环境编程 第十四章 14.4 I/O多路转接
- [Original] Usage of @PrePersist and @PreUpdate in JPA
- UNIX Environment Programming Chapter 15 15.5FIFO
- 机器学习-逻辑回归(logistics regression)
- 编程技术提升
- Solve the ali cloud oss - the original 】 【 exe double-click response can't open, to provide a solution
- 机器学习--朴素贝叶斯(Naive Bayes)
- 10000以内素数表(代码块)
猜你喜欢
随机推荐
基于STM32设计的环境检测设备
相伴成长,彼此成就 用友U9 cloud做好制造业数智化升级的同路人
数据存储:对dataframe类,使用to_csv()将中文数据写入csv文件
小程序员的发展计划
阿里神作!吃透这份资料入厂率高达99%
工作--今天的学习
1003 我要通过! (20 分)
实现下拉加载更多
numpy库中的函数 bincount() where() diag() all()
使用cpolar远程连接群晖NAS(创建临时链接)
MySQL索引的B+树到底有多高?
关于页面初始化
虚拟列表key复用问题
Dialogue with the DPO of a multinational consumer brand: How to start with data security compliance?See you on 8.11 Live!
Electron application development best practices
在犹豫中度过了老多天,今天的工作时记录
绝了,这套RESTful API接口设计总结
使用.NET简单实现一个Redis的高性能克隆版(四、五)
10000以内素数表(代码块)
Win32控件------------显示系统使用的控件版本









