Implementing gradient accumulation with delayed updates in TensorFlow
2022-04-23 20:48:00 【NuerNuer】
My machine has only a single 30-series card with 12 GB of video memory, so the batch_size is inevitably limited when running code. Gradient accumulation can be used to optimize around this and enlarge the effective batch_size. This is easy to do in PyTorch, but somewhat more involved in TensorFlow.
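Before the TensorFlow version, the idea itself can be sketched in plain NumPy (a toy one-parameter least-squares problem with made-up data): compute a gradient per micro-batch, add it into an accumulator, and only apply (and then zero) the accumulator every accum_steps iterations.

```python
import numpy as np

# Toy problem (hypothetical data): fit scalar w in y = w * x by gradient descent.
def grad(w, x, y):
    # Gradient of mean((w*x - y)^2) with respect to w
    return np.mean(2 * (w * x - y) * x)

x = np.array([1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5, -1.5])
y = 3.0 * x                      # the true w is 3
w = 0.0
lr = 0.02
accum = 0.0                      # plays the role of accum_vars
accum_steps = 2                  # apply an update every 2 micro-batches

for step, (xb, yb) in enumerate(zip(x.reshape(4, 2), y.reshape(4, 2)), 1):
    accum += grad(w, xb, yb)              # plays the role of accum_ops
    if step % accum_steps == 0:
        w -= lr * accum / accum_steps     # plays the role of train_step
        accum = 0.0                       # plays the role of zero_ops

print(w)  # moves from 0 toward the true value 3
```

Here the accumulated gradient is divided by accum_steps, so one update equals a single step on the combined batch; applying the raw sum instead scales the effective gradient by the number of accumulated steps.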
Here is the code, with explanations:
def train():
    ...
    # All trainable parameters
    trainable_vars = tf.trainable_variables()
    # Select the parameters to actually train
    vit_trainable_vars = [var for var in trainable_vars if 'VGG' not in var.name]  # both generate and vision_transformer #291
    print("************ vars to train:", len(vit_trainable_vars))
    # For each parameter to train, create a non-trainable, all-zero variable
    # of the same shape to hold the accumulated gradient
    accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
                  for tv in vit_trainable_vars]
    # Op that resets the accumulators to zero
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]
    global_step = tf.Variable(1, trainable=False)
    # Define the optimization ops
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            # Optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999)
            # Compute (gradient, variable) pairs
            grads = optimizer.compute_gradients(loss, vit_trainable_vars)
            # Add this step's gradients into the accumulators
            accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(grads)]
            # Apply the accumulated gradients to the corresponding variables
            train_step = optimizer.apply_gradients(
                [(accum_vars[i], gv[1]) for i, gv in enumerate(grads)],
                global_step=global_step)
    ...
    iter = 0
    while True:
        ...
        iter += 1
        sess.run(accum_ops, feed_dict={...})  # accumulate this micro-batch's gradients
        if iter % 2 == 0:
            ...
            sess.run(train_step)  # apply the accumulated gradients once
            ...
            sess.run(zero_ops)    # reset the accumulators to zero
    ...
This computes and accumulates the gradient twice before applying a single update, which is equivalent to doubling batch_size.
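One caveat: accum_ops sums the two micro-batch gradients without averaging. For a loss that is a mean over the batch, the sum of two micro-batch gradients is twice the gradient of the corresponding double-size batch, as this small NumPy check (toy linear-regression gradient, made-up data) illustrates:

```python
import numpy as np

def mse_grad(w, x, y):
    # Gradient of mean((x @ w - y)^2) with respect to w
    return 2 * x.T @ (x @ w - y) / len(x)

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = rng.normal(size=3)

g1 = mse_grad(w, x[:4], y[:4])   # first micro-batch
g2 = mse_grad(w, x[4:], y[4:])   # second micro-batch
g_full = mse_grad(w, x, y)       # one batch of double size

# The summed accumulator equals twice the full-batch gradient:
print(np.allclose(g1 + g2, 2 * g_full))  # True
```

If exact large-batch equivalence matters, divide the accumulated gradient by the number of accumulation steps or scale the learning rate accordingly; Adam's per-parameter normalization hides much of this constant factor, but not exactly.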
It is worth noting that if we do not specify which variables to save, the newly created accumulator Variables will also be saved, making the checkpoint larger. Therefore, save only the parameters of the original model. For example, in practice I use:
# 'var' here is the full variable list (e.g. from tf.global_variables());
# drop Adam's slot variables and the gradient accumulators ('Variable_*')
var_to_save = [val for val in var if 'Adam' not in val.name and 'Variable_' not in val.name]
saver = tf.train.Saver(var_to_save, max_to_keep=None)
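The filter works purely on variable names: Adam's slot variables contain 'Adam', and unnamed tf.Variables (such as the accumulators above) get default names like 'Variable_1'. A pure-Python illustration with hypothetical names:

```python
# Hypothetical variable names, standing in for [v.name for v in var]
names = [
    'VGG/conv1/weights:0',
    'vit/encoder/kernel:0',
    'vit/encoder/kernel/Adam:0',    # Adam slot variable
    'vit/encoder/kernel/Adam_1:0',  # Adam slot variable
    'Variable_1:0',                 # gradient accumulator
    'Variable_2:0',                 # gradient accumulator
]

to_save = [n for n in names if 'Adam' not in n and 'Variable_' not in n]
print(to_save)  # ['VGG/conv1/weights:0', 'vit/encoder/kernel:0']
```

Note that the first unnamed variable in a graph is named 'Variable:0' with no underscore, so the 'Variable_' substring test would not catch it; giving the accumulators an explicit name (or a dedicated name scope) makes such filters more robust.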
Copyright notice: this article was written by [NuerNuer]; please include a link to the original when reposting.
https://yzsam.com/2022/04/202204210545522575.html