Gradient accumulation in TensorFlow: accumulate gradients, then apply the update
2022-04-23 20:48:00 【NuerNuer】
My machine has a single 30-series graphics card with only 12 GB of video memory, so when running code I inevitably hit the awkward situation where batch_size cannot be made very large. Gradient accumulation works around this by enlarging batch_size in effect, without needing more memory. This is easy to implement in PyTorch, but a little more involved in TensorFlow.
Here is the code, with explanations in the comments:
import tensorflow as tf  # TF 1.x graph-mode API

def train():
    ...
    ...
    # All trainable variables
    trainable_vars = tf.trainable_variables()
    # Select the variables to train (everything except VGG)
    vit_trainable_vars = [var for var in trainable_vars if 'VGG' not in var.name]  # both generate and vision_transformer #291
    print("************vars to train:", len(vit_trainable_vars))
    # For each variable to train, create a non-trainable all-zero variable of the same shape
    # to hold its accumulated gradient
    accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in vit_trainable_vars]
    # Ops that reset the accumulated gradients to zero
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]
    global_step = tf.Variable(1, trainable=False)
    # Define the optimization ops
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            # Without accumulation this would simply be:
            # train_step = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999).minimize(loss, global_step=global_step)
            # Optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999)
            # Compute the gradients as a list of (gradient, variable) pairs
            grads = optimizer.compute_gradients(loss, vit_trainable_vars)
            # Add the current gradients to the accumulators
            accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(grads)]
            # Update the parameters using the accumulated gradients
            train_step = optimizer.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(grads)], global_step=global_step)
    ...
    iter = 0
    while True:
        ...
        ...
        iter += 1
        # accum_ops is what evaluates loss, so if loss reads placeholders they must be fed here
        sess.run(accum_ops, feed_dict={...})  # add this mini-batch's gradients to the accumulators
        if iter % 2 == 0:
            ...
            sess.run(train_step, feed_dict={...})  # update the parameters once
            ...
            sess.run(zero_ops)  # reset the accumulated gradients to 0
        ...
This computes the gradient twice, accumulates the results, and only then applies a single update, which is roughly equivalent to doubling batch_size. Note that the code above sums the two gradients rather than averaging them.
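To make the equivalence concrete, here is a minimal pure-Python sketch (no TensorFlow) of the same accumulate-then-apply idea. The toy model `y = w * x`, the squared-error loss, and the helper names `grad_sum` and `accumulate` are all assumptions made for illustration; they are not part of the training code above. It checks that summing the gradients of two mini-batches gives the gradient of the combined batch, and that dividing by the number of mini-batches recovers the gradient of a mean loss when the mini-batches have equal size.

```python
def grad_sum(w, batch):
    """Gradient w.r.t. w of sum_i (w * x_i - y_i)**2 for a toy model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch)


def accumulate(w, micro_batches):
    """Start from zero and add each mini-batch gradient, like zero_ops + accum_ops."""
    accum = 0.0                   # state right after the zeroing step
    for mb in micro_batches:
        accum += grad_sum(w, mb)  # one accumulation step per mini-batch
    return accum


w = 0.5
batch_a = [(1.0, 2.0), (2.0, 4.0)]
batch_b = [(3.0, 6.0), (4.0, 8.0)]

# Sum-reduced loss: accumulated gradient equals the gradient of the big batch.
accumulated = accumulate(w, [batch_a, batch_b])
full_batch = grad_sum(w, batch_a + batch_b)

# Mean-reduced loss: average the per-mini-batch mean gradients instead of summing.
mean_a = grad_sum(w, batch_a) / len(batch_a)
mean_b = grad_sum(w, batch_b) / len(batch_b)
accumulated_mean = (mean_a + mean_b) / 2
full_batch_mean = full_batch / len(batch_a + batch_b)

print(accumulated, full_batch)
print(accumulated_mean, full_batch_mean)
```

If the loss is a per-batch mean, summing two mini-batch gradients yields twice the gradient of the combined batch. Adam's update is largely insensitive to a constant gradient scale, but with plain SGD one would probably want to divide the accumulated gradient by the number of accumulation steps.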
One thing worth noting: if we do not tell the Saver which variables to save, the newly created accumulator Variables are saved as well, which makes the checkpoint larger. Only the parameters of the original model should be saved. For example, this is what I use in practice:
# var is assumed here to be the list of all variables in the graph, e.g. tf.global_variables()
var_to_save = [val for val in var if 'Adam' not in val.name and 'Variable_' not in val.name]
saver = tf.train.Saver(var_to_save, max_to_keep=None)
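Filtering on `'Variable_'` relies on TensorFlow's default naming. In TF 1.x the first unnamed `tf.Variable` is named `Variable:0`, with no underscore, so depending on creation order one accumulator may slip through the filter. A possibly more robust variant, sketched below, assumes the accumulators are created under a hypothetical name scope `accum` (for example inside `with tf.name_scope('accum'):`) and filters on that prefix. The `FakeVar` stand-in and the helper `select_vars_to_save` are illustrative names only; in real code the list would come from something like `tf.global_variables()`.

```python
from collections import namedtuple

# Stand-in for tf.Variable: only the .name attribute matters for filtering.
FakeVar = namedtuple("FakeVar", ["name"])


def select_vars_to_save(variables, exclude=("Adam", "accum/")):
    """Keep the variables whose name contains none of the excluded substrings."""
    return [v for v in variables if not any(s in v.name for s in exclude)]


all_vars = [
    FakeVar("vit/dense/kernel:0"),         # model parameter, should be saved
    FakeVar("vit/dense/kernel/Adam:0"),    # Adam slot variable
    FakeVar("vit/dense/kernel/Adam_1:0"),  # Adam slot variable
    FakeVar("accum/Variable:0"),           # accumulator (hypothetical 'accum' scope)
    FakeVar("accum/Variable_1:0"),         # accumulator (hypothetical 'accum' scope)
]

kept = [v.name for v in select_vars_to_save(all_vars)]
print(kept)
```

The resulting list can be passed to `tf.train.Saver` in the same way as `var_to_save` above.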
Copyright notice
This article was written by [NuerNuer]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204210545522575.html