PyTorch 17. GPU concurrency
2022-04-23 07:29:00 【DCGJ666】
GPU Concurrency
Multi-GPU distributed parallelism
torch.nn.DataParallel
Function: wraps a model to implement single-process data parallelism across GPUs
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
Main parameters:
module: the model to be wrapped and distributed
device_ids: the GPUs to distribute to; defaults to all visible and available GPUs
output_device: the device on which the results are gathered
For a concrete example that combines multi-GPU training with single-GPU testing, see the earlier blog post.
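Before looking at the shortcomings, here is a minimal sketch of what wrapping a model with nn.DataParallel looks like; the model and tensor shapes are placeholders, not code from the original post:
import torch
import torch.nn as nn

model = nn.Linear(128, 10)                       # placeholder model
if torch.cuda.device_count() > 1:
    # Replicate the model onto all visible GPUs; each forward call splits
    # the input batch along dim 0 and scatters it to the replicas.
    model = nn.DataParallel(model)
model = model.cuda()

x = torch.randn(32, 128).cuda()                  # the batch is split across GPUs automatically
out = model(x)                                   # outputs are gathered on output_device (GPU 0 by default)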
Shortcomings:
- In each training batch, the model's weights are first computed in a single process and then scattered to every GPU, so network communication becomes a bottleneck and GPU utilization is usually low.
- In addition, nn.DataParallel requires all GPUs to be on one node and does not support Apex mixed-precision training.
Using torch.distributed to speed up parallel training:
DataParallel: a single process controls multiple GPUs
DistributedDataParallel: multiple processes each control one GPU and train the model together
Unlike single-process training, multi-process training needs attention to the following:
- When feeding data, one batch is split across processes; each process must be guaranteed to fetch different data (DistributedSampler)
- Each process must be told which process it is and therefore which GPU to use (args.local_rank)
- When using BN, the statistics need to be synchronized across processes
Usage
For multi-process startup we do not have to write multiprocessing code ourselves to carry out the complicated CPU/GPU assignment; PyTorch provides a very convenient launcher, torch.distributed.launch, to start the script, so the way we run the training code becomes:
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
The --nproc_per_node argument specifies how many processes to create on the current host. Since this is a single machine with multiple GPUs, the number of nodes is 1, and we only need to set it to the number of GPUs used.
Initialization
After the launcher starts the Python script, it passes the current process index to the script during execution. We can obtain that index through the local_rank argument, which tells each process which GPU it should use, so that a different device can be specified in each process:
import argparse

import torch


def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0,
                        help='node rank for distributed training')
    args = parser.parse_args()
    return args


def main():
    args = parse()
    # Bind this process to its own GPU before initializing the process group.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(
        'nccl',
        init_method='env://'
    )
    device = torch.device(f'cuda:{args.local_rank}')
Here torch.distributed.init_process_group initializes the GPU communication backend (NCCL) and the way parameters are obtained (env means through environment variables). init_process_group sets up the backend and port used for communication between GPUs, and NCCL implements the GPU-to-GPU communication.
DataLoader
After initialization, the data_loader needs to use torch.utils.data.distributed.DistributedSampler:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
This gives each process a different sampler and tells each process which data to fetch.
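One practical detail worth adding here (not in the original post, but standard DistributedSampler usage): the sampler's shuffling is seeded by the epoch, so set_epoch should be called at the start of every epoch, otherwise each epoch reuses the same ordering (num_epochs is a placeholder):
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)   # re-seed the shuffle for this epoch
    for batch in train_loader:
        ...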
Model initialization
Similar to nn.DataParallel:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
Wrapping the model with DistributedDataParallel performs an all-reduce over the gradients obtained on the different GPUs (i.e. it aggregates the gradients computed on each GPU and synchronizes the result). After the all-reduce, the gradient on every GPU is the mean of the gradients the GPUs held before the all-reduce.
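A small sketch of the usual wrapping order, which the post does not spell out: the model is moved to this process's GPU first, and output_device typically matches device_ids:
model = model.to(device)   # device = torch.device(f'cuda:{args.local_rank}')
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank,
)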
Sync BN
Standard BN can only normalize batches within a single card; cross-card synchronized BN normalizes over the global samples, and apex can be used for this.
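Besides apex, recent PyTorch releases also ship a built-in converter; a minimal sketch, assuming the conversion is done before wrapping with DDP:
# Swap every BatchNorm layer in the model for a cross-process SyncBatchNorm.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])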
Multi-machine, multi-GPU DDP (DistributedDataParallel)
Process-group concepts
Group: the process group; in most cases all DDP processes belong to the same process group
world_size: the total number of processes (in principle, one process per GPU works best)
rank: the index of the current process, used for inter-process communication; the host with rank=0 is the master node
local_rank: the GPU index corresponding to the current process
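A small sketch of how these values are usually read at runtime, once init_process_group has been called:
import torch.distributed as dist

world_size = dist.get_world_size()   # total number of processes in the default group
rank = dist.get_rank()               # global index of this process; rank 0 is the master
if rank == 0:
    print(f'training with {world_size} processes')   # e.g. only the master logs / saves checkpoints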
Basic DDP usage (coding workflow)
- Use torch.distributed.init_process_group to initialize the process group
- Use torch.nn.parallel.DistributedDataParallel to create the distributed model
- Use torch.utils.data.distributed.DistributedSampler to create the DataLoader's sampler
- Adjust other places as necessary
- Launch training with torch.distributed.launch / torch.multiprocessing, or slurm (see the end-to-end sketch after this list)
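The following is a minimal single-machine end-to-end sketch that assembles the steps above; the model, dataset, and file name train_ddp.py are placeholders, not code from the original post:
# Launch with:
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> train_ddp.py
import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # 1. Initialize the process group and bind this process to its GPU.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    device = torch.device(f'cuda:{args.local_rank}')

    # 2. Build the model and wrap it with DistributedDataParallel.
    model = nn.Linear(32, 4).to(device)          # placeholder model
    model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # 3. Give each process its own shard of the data.
    dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 4, (1024,)))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()

    # 4. Ordinary training loop; gradients are all-reduced inside backward().
    for epoch in range(2):
        sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        if dist.get_rank() == 0:
            print(f'epoch {epoch} done, loss {loss.item():.4f}')


if __name__ == '__main__':
    main()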
Using apex for acceleration (mixed-precision training, parallel training, sync BN):
APEX is a deep-learning acceleration library from NVIDIA. Open-sourced by NVIDIA, it fully supports the PyTorch framework and changes data formats to reduce the model's GPU memory footprint. Its amp component (Automatic Mixed Precision) runs most of the model's operations in the Float16 data type while keeping certain special operations in Float32, and users can migrate their training code by changing only three lines.
Using apex
Amp: Automatic Mixed Precision
apex.amp is a tool that enables mixed-precision training by changing only 3 lines of the script. By passing different flags to amp.initialize, users can easily experiment with different pure-precision and mixed-precision training modes.
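Concretely, the three lines look roughly like this (a sketch assuming model, optimizer, and loss already exist; the same calls appear again in the checkpointing example below):
from apex import amp

# 1. Wrap the model and optimizer once, choosing a precision mode.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# 2-3. Scale the loss before backward so FP16 gradients do not underflow.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()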
Distributed Training
apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It supports convenient multi-process distributed training and is optimized for NVIDIA's NCCL communication library.
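A minimal sketch of the apex wrapper; unlike the torch version it takes no device_ids and assumes torch.cuda.set_device has already been called for this process:
from apex.parallel import DistributedDataParallel as ApexDDP

# Wrap after amp.initialize; apex's DDP uses the current CUDA device.
model = ApexDDP(model)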
Synchronized Batch Normalization
apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It all-reduces the BN statistics across processes during multi-process training. Sync BN has been used in cases where each GPU can only hold a small local minibatch.
import apex
sync_bn_model = apex.parallel.convert_syncbn_model(model)
Checkpointing
To save and load amp-trained models correctly, amp introduces amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, as well as amp.load_state_dict() to restore these attributes.
from apex import amp

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
Notes
- When saving and loading model parameters, remember to also save and load amp.state_dict()
- Both the model and the optimizer are wrapped by amp.initialize()
User-specified data format
amp.initialize(net, opt, opt_level="O1")
The opt_level parameter specifies which data format to use for training:
- O0: pure FP32 training
- O1: mixed-precision training (recommended); automatically uses FP16 or FP32 for each op according to a black/white list
- O2: mixed-precision training without a black/white list; except for BN, almost everything is computed in FP16
- O3: pure FP16 training
Overflow problems
Because Float16 has fewer bits, the absolute values of the upper and lower limits it can represent are smaller. Operations involving summation, such as sigmoid and softmax, can overflow and produce wrong results. For these operations we want to use float32 as the data format; we only need to add the following to the model's constructor __init__():
from apex import amp
import torch
from torch.nn import Module

class xxxNet(Module):
    def __init__(self, using_amp=False):
        super().__init__()
        ...
        if using_amp:
            # Force these ops to run in FP32 even under mixed precision.
            amp.register_float_function(torch, 'sigmoid')
            amp.register_float_function(torch, 'softmax')
Registration functions similar to register_float_function include:
- amp.register_half_function(module, function_name)
- amp.register_float_function(module, function_name)
- amp.register_promote_function(module, function_name)
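A hedged usage sketch (the ops chosen here, torch.matmul and torch.cat, are only illustrative, model and optimizer are assumed to exist, and registration is generally done before amp.initialize):
from apex import amp
import torch

amp.register_half_function(torch, 'matmul')      # always run in FP16
amp.register_float_function(torch, 'softmax')    # always run in FP32
amp.register_promote_function(torch, 'cat')      # run in the widest input dtype

# Register before amp.initialize so the patched functions take effect.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')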
Copyright notice
This article was created by [DCGJ666]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230611343745.html