Implementing gradient accumulation in TensorFlow before applying the update
2022-04-23 20:48:00 【NuerNuer】
My machine has a single 30-series graphics card with only 12 GB of video memory, so when running code I inevitably hit the problem that batch_size cannot be made very large. Gradient accumulation is a way around this: it effectively enlarges batch_size without needing more memory. This is easy to implement in PyTorch, but a little more involved in TensorFlow.
Here is the code, with explanations in the comments:
def train():
    ...
    ...
    # All trainable variables
    trainable_vars = tf.trainable_variables()
    # Select the variables to train (everything whose name does not contain 'VGG')
    vit_trainable_vars = [var for var in trainable_vars if 'VGG' not in var.name]  # both generate and vision_transformer  #291
    print("************vars to train:", len(vit_trainable_vars))
    # For each variable to train, create a non-trainable all-zero variable of the same shape to hold its accumulated gradient
    accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in vit_trainable_vars]
    # Ops that reset the accumulated gradients to zero
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]
    global_step = tf.Variable(1, trainable=False)
    # Define the optimization ops
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            #train_step = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999).minimize(loss, global_step=global_step)
            # Optimizer
            optimizer = tf.train.AdamOptimizer(learning_rate, 0.9, 0.999)
            # Compute the gradients as a list of (gradient, variable) pairs
            grads = optimizer.compute_gradients(loss, vit_trainable_vars)
            # Add the current gradients to the accumulators
            accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(grads)]
            # Update the parameters using the accumulated (summed) gradients
            train_step = optimizer.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(grads)], global_step=global_step)
    ...
    iter = 0
    while True:
        ...
        ...
        iter += 1
        # Accumulate the gradient of the current batch
        # (if loss depends on placeholders, this call needs the batch's feed_dict as well)
        sess.run(accum_ops)
        if iter % 2 == 0:  # every two batches
            ...
            sess.run(train_step, feed_dict={...})  # update the parameters once
            ...
            sess.run(zero_ops)  # reset the accumulated gradients to 0
        ...
This computes the gradient on two batches, accumulates the results, and only then applies a single update, which is equivalent to doubling batch_size.
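To make the equivalence concrete, here is a minimal pure-Python sketch (no TensorFlow). The toy model y = w * x, the data, and the helper name mse_grad are assumptions made up for illustration; they are not part of the training code above. It shows that the mean of the gradients from two equally sized micro-batches equals the gradient of the combined batch, while the plain sum is larger by a factor equal to the number of accumulation steps.

```python
# Toy example: gradient of a mean-squared-error loss for the model y = w * x.
# All names and numbers here are hypothetical and only illustrate the idea.

def mse_grad(w, batch):
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
full_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
# Split the full batch into two equally sized micro-batches.
micro_batches = [full_batch[:2], full_batch[2:]]

# Accumulate (sum) the gradient over the micro-batches.
accum = 0.0
for mb in micro_batches:
    accum += mse_grad(w, mb)

mean_accum = accum / len(micro_batches)  # averaged accumulated gradient
full_grad = mse_grad(w, full_batch)      # gradient of the combined batch

print(accum, mean_accum, full_grad)
```

If the loss is a per-batch mean, dividing the accumulated gradient by the number of accumulation steps gives exactly the large-batch gradient. The code above applies the sum instead; whether that matters in practice depends on the optimizer and the learning rate.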
Note that if we do not specify which variables to save, the newly created Variables are saved as well, which makes the saved model larger. Only the parameters of the original model should be saved. For example, this is what I use in practice:
# var is assumed here to be the list of all variables in the graph, e.g. tf.global_variables()
var_to_save = [val for val in var if 'Adam' not in val.name and 'Variable_' not in val.name]
saver = tf.train.Saver(var_to_save, max_to_keep=None)
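The effect of this name filter can be checked without TensorFlow. The sketch below runs it on a hypothetical list of variable names; the names are made up, and they assume TF1's default naming, where unnamed variables are called Variable, Variable_1, Variable_2 and so on. The helper is_unnamed is also an illustrative assumption, not something from the original code.

```python
# Hypothetical variable names, mimicking TF1's default naming for unnamed variables.
names = [
    "vit/dense/kernel:0",
    "vit/dense/kernel/Adam:0",
    "vit/dense/kernel/Adam_1:0",
    "Variable:0",
    "Variable_1:0",
    "Variable_2:0",
]

# The filter used above: drop Adam slots and names containing 'Variable_'.
kept = [n for n in names if "Adam" not in n and "Variable_" not in n]

def is_unnamed(name):
    """True if the name looks like a default-named variable (assumed convention)."""
    base = name.split(":")[0]
    return base == "Variable" or base.startswith("Variable_")

# A stricter variant that also drops the first unnamed variable.
kept_strict = [n for n in names if "Adam" not in n and not is_unnamed(n)]

print(kept)
print(kept_strict)
```

Under these naming assumptions the original filter keeps the variable called Variable, since that name has no underscore suffix. Giving the accumulator variables an explicit name when creating them, and filtering on that name, would likely be more robust.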
Copyright notice
This article was written by [NuerNuer]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204210545522575.html