PyTorch neural network trainer
2022-04-23 11:17:00 【Herbie TZJ】
Writing neural-network training code is tedious and highly repetitive, so I wrote a trainer that provides the following features:
- Saves the network parameters that achieve the lowest loss on the test set
- When the code is rerun, automatically loads the saved parameter file; if no file is found, initializes the network parameters with Kaiming He's initialization method
- Serves as the parent class for various neural network trainers: override the _forward method and everything else is reused
- Displays a progress bar during training and validation
Environment and variables
```python
import os
import numpy as np
import torch
import torch.nn.functional as F

SGD = torch.optim.SGD
Adam = torch.optim.Adam
```
Basic functions
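```python
def param_init(neural_net):
    ''' Network parameter initialization '''
    parameters = neural_net.state_dict()
    for key in parameters:
        # Kaiming init needs a fan_in, so only tensors with >= 2 dims (weights) are touched;
        # 1-D tensors such as biases keep their default initialization
        if len(parameters[key].shape) >= 2:
            # state_dict() tensors share storage with the model, so this in-place call updates the network
            torch.nn.init.kaiming_normal_(parameters[key], a=0, mode='fan_in',
                                          nonlinearity='leaky_relu')
```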
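With mode='fan_in', kaiming_normal_ samples from a zero-mean normal distribution whose standard deviation is gain / sqrt(fan_in), where gain = sqrt(2 / (1 + a²)) for leaky_relu; with a=0 this is the ReLU gain sqrt(2). For a weight of shape (out, in, k1, k2, ...), fan_in is in × k1 × k2 × ..., which is undefined for 1-D tensors (PyTorch raises an error for them, hence the `>= 2` check above). The stdlib-only helpers below (`kaiming_normal_std`, `sample_weights`) are hypothetical names that sketch this formula for checking what spread a given shape receives. They are not PyTorch's implementation.

```python
import math
import random


def kaiming_normal_std(shape, a=0.0):
    """Std of a fan_in Kaiming-normal init with the leaky_relu gain (formula sketch only)."""
    if len(shape) < 2:
        raise ValueError("fan_in is undefined for tensors with fewer than 2 dimensions")
    # fan_in = input channels * receptive field size (product of kernel dims)
    receptive_field = 1
    for dim in shape[2:]:
        receptive_field *= dim
    fan_in = shape[1] * receptive_field
    gain = math.sqrt(2.0 / (1 + a ** 2))  # a=0 reduces to the ReLU gain sqrt(2)
    return gain / math.sqrt(fan_in)


def sample_weights(shape, a=0.0, seed=0):
    """Draw a flat list of weights from N(0, std^2) for the given shape."""
    rng = random.Random(seed)
    std = kaiming_normal_std(shape, a)
    n = 1
    for dim in shape:
        n *= dim
    return [rng.gauss(0.0, std) for _ in range(n)]
```

For example, a 3×3 convolution with 3 input channels has fan_in = 27, so its weights get std sqrt(2/27) ≈ 0.27, while a linear layer with 50 inputs gets std 0.2.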
```python
def get_progress(current, target, bar_len=30):
    ''' current: Number of completed tasks
    target: Total number of tasks
    bar_len: Progress bar length
    return: Progress bar string '''
    assert current <= target
    percent = round(current / target * 100, 1)
    unit = 100 / bar_len
    solid = round(percent / unit)
    hollow = bar_len - solid
    return '■' * solid + '□' * hollow + f' {current}/{target}({percent}%)'
```
Trainer base class
Do not use this class directly: its _forward method must be overridden in a subclass (see below).
```python
class Trainer:
    ''' Trainer
    net: Network model
    net_file: Path where the network parameters are saved (.pt)
    lr: Learning rate
    adam: Use the Adam optimizer (otherwise SGD)
    bar_len: Progress bar length '''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool, bar_len: int):
        # Set system parameters
        self.net = net.cuda()
        self._net_file = net_file
        self._min_loss = np.inf
        # Load saved network parameters if the file exists, otherwise initialize them
        if os.path.isfile(self._net_file):
            state_dict = torch.load(self._net_file)
            self.net.load_state_dict(state_dict)
        else:
            param_init(self.net)
        # Instantiate the optimizer
        parameters = self.net.parameters()
        if adam:
            self._optimizer = Adam(parameters, lr=lr)
        else:
            self._optimizer = SGD(parameters, lr=lr)
        self._bar_len = bar_len

    def train(self, train_set, profix='train'):
        # eval() must run first so that _min_loss holds a real baseline
        assert self._min_loss != np.inf, 'Please run eval() first'
        self.net.train()
        avg_loss = self._forward(train_set, train=True, profix=profix)
        return avg_loss

    def eval(self, eval_set, profix='eval'):
        self.net.eval()
        # No autograd graph is needed during evaluation
        with torch.no_grad():
            avg_loss = self._forward(eval_set, train=False, profix=profix)
        # Save the best-performing network on the test set
        if avg_loss <= self._min_loss:
            self._min_loss = avg_loss
            torch.save(self.net.state_dict(), self._net_file)
        return avg_loss

    def _forward(self, data_set, train: bool, profix):
        ''' data_set: Data set
        train: Whether this is a training pass (bool)
        profix: Progress bar prefix
        return: Average loss '''
        raise NotImplementedError('Subclasses must override _forward')
```
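The split between Trainer and its subclasses is the template-method pattern. The base class owns the control flow: evaluate before training, keep the checkpoint with the lowest evaluation loss, and reload that checkpoint on a rerun. The subclass supplies only _forward. The sketch below reproduces that control flow without PyTorch so it runs anywhere. `ToyTrainer`, `MeanFitter`, `demo` and the pickle checkpoint are hypothetical stand-ins for Trainer, a real subclass, and torch.save/torch.load.

```python
import math
import os
import pickle
import tempfile


class ToyTrainer:
    """Hypothetical torch-free stand-in for Trainer: same control flow, no network."""

    def __init__(self, ckpt_file: str):
        self._ckpt_file = ckpt_file
        self._min_loss = math.inf
        # Reload saved parameters if a checkpoint exists, otherwise start from an initial value
        if os.path.isfile(ckpt_file):
            with open(ckpt_file, "rb") as f:
                self.params = pickle.load(f)
        else:
            self.params = {"w": 0.0}

    def train(self, data):
        # Same guard as Trainer.train: eval() must establish a baseline first
        assert self._min_loss != math.inf, "Please run eval() first"
        return self._forward(data, train=True)

    def eval(self, data):
        avg_loss = self._forward(data, train=False)
        # Overwrite the checkpoint only when the evaluation loss has not got worse
        if avg_loss <= self._min_loss:
            self._min_loss = avg_loss
            with open(self._ckpt_file, "wb") as f:
                pickle.dump(self.params, f)
        return avg_loss

    def _forward(self, data, train: bool):
        raise NotImplementedError("Subclasses must override _forward")


class MeanFitter(ToyTrainer):
    """Hypothetical subclass: fits a scalar w to the data with per-sample SGD on squared error."""

    lr = 0.1

    def _forward(self, data, train: bool):
        total = 0.0
        for x in data:
            err = self.params["w"] - x
            total += err * err
            if train:
                # Gradient of (w - x)^2 with respect to w is 2 (w - x)
                self.params["w"] -= self.lr * 2 * err
        return total / len(data)  # _forward must return the average loss


def demo(data=(1.0, 2.0, 3.0), epochs=5):
    """Run eval, then (train, eval) * epochs, then reload the checkpoint in a fresh trainer."""
    data = list(data)
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "best.pkl")
        trainer = MeanFitter(path)
        first = trainer.eval(data)  # baseline; also writes the first checkpoint
        for _ in range(epochs):
            trainer.train(data)
            trainer.eval(data)
        # A "rerun" finds the checkpoint and loads the best parameters instead of initializing
        reloaded = MeanFitter(path).eval(data)
        return first, trainer._min_loss, reloaded
```

Because the checkpoint is written only when the evaluation loss does not increase, a fresh trainer that reloads it reproduces the best loss seen so far, not the loss from the last epoch.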
Classification network trainer
Taking image classification as an example, the subclass below uses the cross-entropy loss. Its _forward method must return the average loss.
```python
class Classfier(Trainer):
    ''' Classifier
    net: Network model
    net_file: Path where the network parameters are saved (.pt)
    lr: Learning rate
    adam: Use the Adam optimizer (otherwise SGD)
    bar_len: Progress bar length '''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool = True, bar_len: int = 20):
        super().__init__(net, net_file, lr, adam, bar_len)

    def _forward(self, data_set, train: bool, profix):
        ''' data_set: Data set (a DataLoader)
        train: Whether this is a training pass (bool)
        profix: Progress bar prefix
        return: Average loss '''
        # Batch information and number of samples
        batch_num = len(data_set)
        batch_size = data_set.batch_size
        # Upper bound: equals the sample count only if every batch is full
        task_num = batch_num * batch_size
        # Initialize the accuracy counter
        counter = Acc_Counter(batch_size, task_num, profix, self._bar_len)
        for idx, (image, target) in enumerate(data_set):
            image, target = image.cuda(), target.cuda()
            logits = self.net(image)
            del image
            # Standard classification loss
            loss = F.cross_entropy(logits, target)
            if train:
                # Backpropagate the loss and take an optimizer step
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()
            # Update the accuracy counter
            avg_loss = counter.update(idx, logits, target, loss)
        print()  # finish the progress-bar line
        return avg_loss
```
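With its default reduction='mean', F.cross_entropy applies log-softmax to each row of logits and averages the negative log-probability of the target class over the batch. As a result, `loss` above is already a per-batch mean. The stdlib sketch below computes the same quantity by hand for small lists. `cross_entropy` is a hypothetical helper for illustration only, not the PyTorch function.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over a batch: log-softmax followed by negative log-likelihood."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max before exponentiating for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[t]  # equals -log(softmax(row)[t])
    return total / len(targets)
```

For the logits [2, 1, 0] with target class 0, the loss is log(1 + e⁻¹ + e⁻²) ≈ 0.408. With all-equal logits over C classes, the loss is log(C), the value an untrained classifier typically starts near.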
The Acc_Counter class used above computes the classification accuracy and prints the progress bar at the same time. Its code is as follows:
```python
class Acc_Counter:
    ''' Accuracy counter
    batch_size: Batch size
    task_num: Number of samples
    profix: Progress bar prefix
    bar_len: Progress bar length '''

    def __init__(self, batch_size, task_num, profix, bar_len):
        self._batch_size = batch_size
        self._task_num = task_num
        self._correct = 0
        self._figured = 0
        self._loss_sum = 0
        self._profix = profix
        self._bar_len = bar_len

    def update(self, idx, logits, target, loss):
        ''' idx: Index of the current batch '''
        # The predicted class is the index of the largest logit
        result = logits.argmax(1)
        # Accumulate accuracy over the samples actually in this batch
        # (the last batch can be smaller than batch_size)
        self._correct += result.eq(target).sum().item()
        self._figured += target.size(0)
        acc = self._correct / self._figured * 100
        # Running average of the per-batch losses
        loss = loss.item()
        self._loss_sum += loss
        avg_loss = self._loss_sum / (idx + 1)
        # Redraw the train / eval status line in place
        progress = get_progress(self._figured, self._task_num, bar_len=self._bar_len)
        print(
            f'\r{self._profix}: {progress}\tacc: {self._correct}/{self._figured}({acc:.2f}%)\tavg_loss: {avg_loss:.8f}',
            end='')
        return avg_loss
```
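The bookkeeping in update can be checked without a GPU. The torch-free sketch below assumes predictions have already been reduced to class indices (the argmax step). It shows why the sample count should grow by the real batch length: when a DataLoader keeps a short final batch (drop_last=False), counting batch_size for it would inflate the denominator. `running_stats` and its tuple layout are hypothetical names for this illustration.

```python
def running_stats(batches):
    """Accumulate accuracy and average loss over (predictions, targets, batch_loss) tuples.

    Hypothetical counterpart of Acc_Counter.update; returns one
    (correct, figured, acc_percent, avg_loss) tuple per batch.
    """
    correct = figured = 0
    loss_sum = 0.0
    history = []
    for idx, (preds, targets, batch_loss) in enumerate(batches):
        correct += sum(p == t for p, t in zip(preds, targets))
        figured += len(targets)  # real batch size, so a short final batch is counted correctly
        loss_sum += batch_loss
        avg_loss = loss_sum / (idx + 1)  # mean of the per-batch mean losses, as in Acc_Counter
        history.append((correct, figured, 100 * correct / figured, avg_loss))
    return history
```

With batches of 4, 4 and 2 samples and 8 correct predictions, this reports 8/10 = 80%. Adding batch_size for every batch would report 8/12 ≈ 66.7% instead.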
Copyright notice
This article was written by [Herbie TZJ]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204231117105931.html