PyTorch neural network trainer
2022-04-23 11:17:00 【Herbie TZJ】
Writing training code for a neural network is a tedious and highly repetitive process, so I wrote a trainer that provides the following features:
- Saves the network parameters that achieve the lowest loss on the test set
- Automatically loads the saved parameter file when the code is rerun; if no file is found, initializes the network parameters with Kaiming (He) initialization
- Can serve as the parent class of various neural-network trainers; only the _forward method needs to be overridden to reuse it
- Prints a progress bar for the training / validation process
Environment and variables
import os
import numpy as np
import torch
import torch.nn.functional as F
SGD = torch.optim.SGD
Adam = torch.optim.Adam
Basic functions
def param_init(neural_net):
    '''Network parameter initialization'''
    parameters = neural_net.state_dict()
    for key in parameters:
        # Kaiming initialization only applies to tensors with ndim >= 2 (weights, not biases)
        if len(parameters[key].shape) >= 2:
            # kaiming_normal_ modifies the tensor in place, so the network's weights are updated directly
            parameters[key] = torch.nn.init.kaiming_normal_(
                parameters[key], a=0, mode='fan_in', nonlinearity='leaky_relu')
def get_progress(current, target, bar_len=30):
    '''current: Number of completed tasks
    target: Total number of tasks
    bar_len: Progress bar length
    return: Progress bar string'''
    assert current <= target
    percent = round(current / target * 100, 1)
    unit = 100 / bar_len
    solid = round(percent / unit)
    hollow = bar_len - solid
    return '■' * solid + '□' * hollow + f' {current}/{target}({percent}%)'
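For example, a quick check of the output format (the values here are arbitrary):

print(get_progress(30, 100, bar_len=10))
# prints: ■■■□□□□□□□ 30/100(30.0%)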
Trainer base class
Do not use it directly; the _forward method must be overridden (see below).
class Trainer:
    '''Trainer
    net: Network model
    net_file: Save path of the network model (.pt)
    lr: Learning rate
    adam: Use the Adam optimizer
    bar_len: Progress bar length'''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool, bar_len: int):
        # Set system parameters
        self.net = net.cuda()
        self._net_file = net_file
        self._min_loss = np.inf
        # Load network parameters, or initialize them if no file exists
        if os.path.isfile(self._net_file):
            state_dict = torch.load(self._net_file)
            self.net.load_state_dict(state_dict)
        else:
            param_init(self.net)
        # Instantiate the optimizer
        parameters = self.net.parameters()
        if adam:
            self._optimizer = Adam(parameters, lr=lr)
        else:
            self._optimizer = SGD(parameters, lr=lr)
        self._bar_len = bar_len

    def train(self, train_set, profix='train'):
        assert self._min_loss != np.inf, 'Please run eval first'
        self.net.train()
        avg_loss = self._forward(train_set, train=True, profix=profix)
        return avg_loss

    def eval(self, eval_set, profix='eval'):
        self.net.eval()
        avg_loss = self._forward(eval_set, train=False, profix=profix)
        # Save the network that performs best on the test set
        if avg_loss <= self._min_loss:
            self._min_loss = avg_loss
            torch.save(self.net.state_dict(), self._net_file)
        return avg_loss

    def _forward(self, data_set, train: bool, profix):
        '''data_set: Data set
        train: Training mode (bool)
        profix: Progress bar prefix
        return: Average loss'''
        raise NotImplementedError
Classification network trainer
Taking image classification as an example and using cross-entropy loss, the _forward method must return the average loss.
class Classfier(Trainer):
    '''Classifier
    net: Network model
    net_file: Save path of the network model (.pt)
    lr: Learning rate
    adam: Use the Adam optimizer
    bar_len: Progress bar length'''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool = True, bar_len: int = 20):
        super(Classfier, self).__init__(net, net_file, lr, adam, bar_len)

    def _forward(self, data_set, train: bool, profix):
        '''data_set: Data set
        train: Training mode (bool)
        profix: Progress bar prefix
        return: Average loss'''
        # Batch information and total number of samples
        batch_num = len(data_set)
        batch_size = data_set.batch_size
        task_num = batch_num * batch_size
        # Initialize the accuracy counter
        counter = Acc_Counter(batch_size, task_num, profix, self._bar_len)
        for idx, (image, target) in enumerate(data_set):
            image, target = image.cuda(), target.cuda()
            logits = self.net(image)
            del image
            # Standard classification loss
            loss = F.cross_entropy(logits, target)
            if train:
                # Back-propagate the loss and take an optimization step
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()
            # Update the accuracy counter
            avg_loss = counter.update(idx, logits, target, loss)
        print()
        return avg_loss
In this procedure, the Acc_Counter class computes the classification accuracy and renders the progress bar. Its code is as follows:
class Acc_Counter:
    '''Accuracy counter
    batch_size: Batch size
    task_num: Total number of samples
    profix: Progress bar prefix
    bar_len: Progress bar length'''

    def __init__(self, batch_size, task_num, profix, bar_len):
        self._batch_size = batch_size
        self._task_num = task_num
        self._correct = 0
        self._figured = 0
        self._loss_sum = 0
        self._profix = profix
        self._bar_len = bar_len

    def update(self, idx, logits, target, loss):
        '''idx: Batch index'''
        # Take the index of the maximum logit as the predicted class
        result = logits.argmax(1)
        # Compute the accuracy so far
        self._correct += (result.eq(target).sum()).item()
        self._figured += self._batch_size
        acc = self._correct / self._figured * 100
        # Compute the running average loss
        loss = loss.item()
        self._loss_sum += loss
        avg_loss = self._loss_sum / (idx + 1)
        # Print the train / eval progress line
        progress = get_progress(self._figured, self._task_num, bar_len=self._bar_len)
        print(
            f'\r{self._profix}: {progress}\tacc: {self._correct}/{self._figured}({acc:.2f}%)\tavg_loss: {avg_loss:.8f}',
            end='')
        return avg_loss
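A minimal usage sketch of the whole workflow. Net, train_loader and test_loader are hypothetical placeholders (a model and two DataLoaders) that are not defined in this article, and a CUDA device is assumed:

# Hypothetical example: Net, train_loader and test_loader are placeholders
net = Net()
classifier = Classfier(net, net_file='net.pt', lr=1e-3)
classifier.eval(test_loader)        # run eval once first, so train() has a loss baseline
for epoch in range(10):
    classifier.train(train_loader)  # prints the train progress bar, acc and avg_loss
    classifier.eval(test_loader)    # saves net.pt whenever the test loss reaches a new minimum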
Copyright notice
This article was written by [Herbie TZJ]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204231117105931.html