Pytorch neural network trainer
2022-04-23 11:17:00 【Herbie TZJ】
Writing the training code for a neural network is a tedious and highly repetitive process, so I wrote a trainer that provides the following features:
- Saves the network parameters with the lowest loss on the test set
- When the code is rerun, automatically loads the saved parameter file; if no file is found, initializes the network parameters with Kaiming He's method
- Can serve as the parent class for various neural network trainers: only the _forward method needs to be overridden to reuse it
- Prints a progress bar for the training / validation process
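The first two features come down to a "save when the validation loss improves, load if a file exists" pattern. Below is a minimal, CPU-only sketch of that pattern under stated assumptions: the helper names `save_if_better` and `load_or_init`, the throwaway `torch.nn.Linear(4, 2)` model, the temporary file path and the list of validation losses are all hypothetical, and `map_location="cpu"` is an addition the trainer below does not use.

```python
import os
import tempfile

import torch


def save_if_better(net, path, loss, best_loss):
    """Hypothetical helper: save net's parameters when loss does not exceed best_loss; return the new best."""
    if loss <= best_loss:
        torch.save(net.state_dict(), path)
        return loss
    return best_loss


def load_or_init(net, path):
    """Hypothetical helper: load parameters if the file exists, otherwise Kaiming-initialize weight tensors."""
    if os.path.isfile(path):
        # map_location lets a checkpoint written on a GPU be read on a CPU-only machine
        net.load_state_dict(torch.load(path, map_location="cpu"))
        return True
    with torch.no_grad():
        for p in net.parameters():
            if p.dim() >= 2:
                torch.nn.init.kaiming_normal_(p, a=0, mode="fan_in", nonlinearity="leaky_relu")
    return False


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.pt")

    # First "run": no checkpoint yet, so the weights are initialized
    net = torch.nn.Linear(4, 2)
    first_run_loaded = load_or_init(net, path)

    best = float("inf")
    for epoch_loss in [0.9, 0.7, 0.8]:  # made-up validation losses
        with torch.no_grad():
            net.weight.fill_(epoch_loss)  # stand-in for a training step; marks which epoch the weights came from
        best = save_if_better(net, path, epoch_loss, best)

    # Second "run": the checkpoint exists and is loaded instead of re-initializing
    restored = torch.nn.Linear(4, 2)
    second_run_loaded = load_or_init(restored, path)

print(first_run_loaded, second_run_loaded, best)  # False True 0.7
```

The restored model carries the weights from the epoch with loss 0.7, not from the last epoch, which is the behaviour the first feature describes.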
Environment and variables
import os
import numpy as np
import torch
import torch.nn.functional as F
SGD = torch.optim.SGD
Adam = torch.optim.Adam
Basic functions
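def param_init(neural_net):
    ''' Network parameter initialization: Kaiming (He) normal init for every weight with >= 2 dims '''
    # Modify the parameters in place; no_grad keeps autograd from recording the initialization
    with torch.no_grad():
        for param in neural_net.parameters():
            # Only weight matrices / conv kernels; 1-D biases and BatchNorm parameters keep their defaults
            if param.dim() >= 2:
                torch.nn.init.kaiming_normal_(param, a=0, mode='fan_in',
                                              nonlinearity='leaky_relu')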
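With `mode='fan_in'`, `kaiming_normal_` samples from a normal distribution with mean 0 and standard deviation gain / sqrt(fan_in), where for `nonlinearity='leaky_relu'` the gain is sqrt(2 / (1 + a²)); with `a=0` this is sqrt(2), the same as for ReLU. The sketch below checks that formula on a standalone tensor; the 512 x 1000 shape and the seed are arbitrary choices for illustration.

```python
import math

import torch

torch.manual_seed(0)
fan_in, fan_out = 1000, 512
# Linear layers store weights as (out_features, in_features), so fan_in is dimension 1
w = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(w, a=0, mode="fan_in", nonlinearity="leaky_relu")

gain = torch.nn.init.calculate_gain("leaky_relu", 0)  # sqrt(2 / (1 + 0**2)) = sqrt(2)
expected_std = gain / math.sqrt(fan_in)               # about 0.0447
print(f"expected std {expected_std:.4f}, sampled std {w.std().item():.4f}")
```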
def get_progress(current, target, bar_len=30):
    ''' current: Current number of completed tasks
        target: Total tasks
        bar_len: Progress bar length
        return: Progress bar string '''
    assert current <= target
    percent = round(current / target * 100, 1)
    unit = 100 / bar_len
    solid = round(percent / unit)
    hollow = bar_len - solid
    return '■' * solid + '□' * hollow + f' {current}/{target}({percent}%)'
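get_progress only builds the string; the counter further down redraws it on a single line by starting each print with '\r' and passing end=''. Because stdout is often buffered until a newline, a partial line may not appear right away, and flush=True is a common remedy. The sketch below repeats get_progress so it runs on its own; `render_demo` and the `time.sleep` stand-in for work are hypothetical.

```python
import time


def get_progress(current, target, bar_len=30):
    """Same logic as get_progress above: filled/empty squares plus a count and a percentage."""
    assert current <= target
    percent = round(current / target * 100, 1)
    solid = round(percent / (100 / bar_len))
    return '■' * solid + '□' * (bar_len - solid) + f' {current}/{target}({percent}%)'


def render_demo(total=5, bar_len=10):
    for done in range(1, total + 1):
        time.sleep(0.05)  # stand-in for processing one batch
        # '\r' returns the cursor to the start of the line; flush=True pushes the partial line out now
        print('\r' + get_progress(done, total, bar_len), end='', flush=True)
    print()  # end the line once the loop is finished


render_demo()
print(get_progress(3, 4, bar_len=8))  # ■■■■■■□□ 3/4(75.0%)
```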
Trainer base class
Don't use it directly: its _forward method must be overridden in a subclass (see below).
class Trainer:
    ''' Trainer
        net: A network model
        net_file: Network model save path (.pt)
        lr: Learning rate
        adam: Use the Adam optimizer (otherwise SGD)
        bar_len: Progress bar length '''
    def __init__(self, net, net_file: str, lr: float,
                 adam: bool, bar_len: int):
        # Set system parameters
        self.net = net.cuda()
        self._net_file = net_file
        self._min_loss = np.inf
        # Load the saved network parameters if the file exists, otherwise initialize them
        if os.path.isfile(self._net_file):
            state_dict = torch.load(self._net_file)
            self.net.load_state_dict(state_dict)
        else:
            param_init(self.net)
        # Instantiate the optimizer
        parameters = self.net.parameters()
        if adam:
            self._optimizer = Adam(parameters, lr=lr)
        else:
            self._optimizer = SGD(parameters, lr=lr)
        self._bar_len = bar_len
    def train(self, train_set, profix='train'):
        # eval must run first so that _min_loss holds a baseline validation loss
        assert self._min_loss != np.inf, 'Please run the eval function first'
        self.net.train()
        avg_loss = self._forward(train_set, train=True, profix=profix)
        return avg_loss
    def eval(self, eval_set, profix='eval'):
        self.net.eval()
        avg_loss = self._forward(eval_set, train=False, profix=profix)
        # Save the network that performs best on the test set
        if avg_loss <= self._min_loss:
            self._min_loss = avg_loss
            torch.save(self.net.state_dict(), self._net_file)
        return avg_loss
    def _forward(self, data_set, train: bool, profix):
        ''' data_set: Data set
            train: Training (bool)
            profix: Progress bar prefix
            return: Average loss '''
        raise NotImplementedError('Subclasses must override _forward')
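Trainer itself can still be instantiated; the mistake only surfaces once _forward runs. If you want "don't use it directly" enforced at construction time, Python's `abc` module can do that. The sketch below is a torch-free stand-in, not the author's class: `BaseTrainer`, `MeanTrainer` and the `saved_losses` list (which replaces `torch.save`) are hypothetical, and only the eval-before-train rule and the "save when avg_loss <= best" rule are carried over.

```python
from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """Hypothetical torch-free stand-in for Trainer, keeping only its train/eval protocol."""

    def __init__(self):
        self._min_loss = float('inf')
        self.saved_losses = []  # records "checkpoint saves" instead of calling torch.save

    def train(self, data):
        assert self._min_loss != float('inf'), 'Please run the eval function first'
        return self._forward(data, train=True)

    def eval(self, data):
        avg_loss = self._forward(data, train=False)
        if avg_loss <= self._min_loss:  # same "save on improvement" rule as Trainer.eval
            self._min_loss = avg_loss
            self.saved_losses.append(avg_loss)
        return avg_loss

    @abstractmethod
    def _forward(self, data, train):
        """Subclasses must return the average loss over data."""


class MeanTrainer(BaseTrainer):
    def _forward(self, data, train):
        # Toy "loss": the mean of the numbers in the batch; train is unused here
        return sum(data) / len(data)


try:
    BaseTrainer()
    instantiation_failed = False
except TypeError as err:
    instantiation_failed = True
    print('cannot instantiate BaseTrainer:', err)

t = MeanTrainer()
for batch in ([3.0, 5.0], [1.0, 2.0], [4.0, 4.0]):
    t.eval(batch)
print(t.saved_losses)  # [4.0, 1.5]
```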
Classification network trainer
Image classification is used as the example, with cross-entropy loss. The _forward method must return the average loss.
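`F.cross_entropy` expects raw (unnormalized) logits of shape (batch, classes) and integer class indices of shape (batch,); it applies log-softmax internally and, with the default `reduction='mean'`, averages over the batch. The sketch below, using made-up logits, recomputes the same value by hand as the negated mean of each sample's log-probability for its true class.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 3 samples and 4 classes
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 3.0, -0.5],
                       [1.0, 1.0, 1.0, 1.0]])
target = torch.tensor([0, 2, 3])  # class indices (int64)

loss = F.cross_entropy(logits, target)  # reduction='mean' by default

# Same value by hand: pick each row's log-probability at its target class, negate, average
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(target)), target].mean()
print(loss.item(), manual.item())
```

The third row is uniform, so its contribution is log(4): the model spreads probability 1/4 over every class.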
class Classfier(Trainer):
    ''' Classifier
        net: A network model
        net_file: Network model save path (.pt)
        lr: Learning rate
        adam: Use the Adam optimizer (otherwise SGD)
        bar_len: Progress bar length '''
    def __init__(self, net, net_file: str, lr: float,
                 adam: bool = True, bar_len: int = 20):
        super(Classfier, self).__init__(net, net_file, lr, adam, bar_len)
    def _forward(self, data_set, train: bool, profix):
        ''' data_set: Data set (a DataLoader)
            train: Training (bool)
            profix: Progress bar prefix
            return: Average loss '''
        # Batch information, number of samples
        batch_num = len(data_set)
        batch_size = data_set.batch_size
        task_num = batch_num * batch_size
        # Initialize the acc calculator
        counter = Acc_Counter(batch_size, task_num, profix, self._bar_len)
        # Only build the autograd graph when training; evaluation needs no gradients
        with torch.set_grad_enabled(train):
            for idx, (image, target) in enumerate(data_set):
                image, target = image.cuda(), target.cuda()
                logits = self.net(image)
                del image
                # Standard classification loss
                loss = F.cross_entropy(logits, target)
                if train:
                    # Backpropagate the loss and take an optimizer step
                    self._optimizer.zero_grad()
                    loss.backward()
                    self._optimizer.step()
                # Update the acc calculator
                avg_loss = counter.update(idx, logits, target, loss)
        print()
        return avg_loss
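_forward estimates the total number of samples as len(data_set) * data_set.batch_size. For a PyTorch DataLoader with the default `drop_last=False`, `len(loader)` is the number of batches, rounded up, so this product overestimates the sample count whenever the last batch is partial; `len(loader.dataset)` gives the exact count. The sketch below shows this with a made-up 10-sample `TensorDataset` and a batch size of 4.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up data set: 10 samples with 3 features each, labels in {0, 1}
data = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
loader = DataLoader(data, batch_size=4)  # drop_last=False by default

batch_num = len(loader)                    # 3 batches of sizes 4, 4, 2
estimate = batch_num * loader.batch_size   # 12, the product used in _forward
exact = len(loader.dataset)                # 10 samples actually present
sizes = [target.size(0) for _, target in loader]
print(batch_num, estimate, exact, sizes)   # 3 12 10 [4, 4, 2]
```

With `drop_last=True` the loader yields only full batches, and the product matches the number of samples actually seen.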
Here the Acc_Counter class computes the classification accuracy and prints the progress bar. Its code is as follows:
class Acc_Counter:
    ''' acc calculator
        batch_size: Batch size
        task_num: Data volume
        profix: Progress bar prefix
        bar_len: Progress bar length '''
    def __init__(self, batch_size, task_num, profix, bar_len):
        self._batch_size = batch_size
        self._task_num = task_num
        self._correct = 0
        self._figured = 0
        self._loss_sum = 0
        self._profix = profix
        self._bar_len = bar_len
    def update(self, idx, logits, target, loss):
        ''' idx: Index of the current batch '''
        # The index of the largest logit is the predicted class
        result = logits.argmax(1)
        # Compute acc; count the real batch size, since the last batch may be smaller than batch_size
        self._correct += result.eq(target).sum().item()
        self._figured += target.size(0)
        acc = self._correct / self._figured * 100
        # Compute avg_loss (the mean of the per-batch losses)
        self._loss_sum += loss.item()
        avg_loss = self._loss_sum / (idx + 1)
        # Print train / eval statistics; the bar advances one nominal batch per step
        done = min((idx + 1) * self._batch_size, self._task_num)
        progress = get_progress(done, self._task_num, bar_len=self._bar_len)
        print(
            f'\r{self._profix}: {progress}\tacc: {self._correct}/{self._figured}({acc:.2f}%)\tavg_loss: {avg_loss:.8f}',
            end='', flush=True)
        return avg_loss
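The accuracy update takes the argmax over the class dimension (dim 1) as the prediction and counts matches with `eq(...).sum()`. The worked example below uses made-up logits for two batches, the second one short, and compares dividing by the nominal batch size per step against dividing by the number of samples actually seen; the variable names are illustrative only.

```python
import torch

batch_size = 4
# Made-up logits and labels; the last batch holds only 2 samples, as a DataLoader can produce
batches = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]), torch.tensor([0, 1, 1, 1])),
    (torch.tensor([[0.1, 0.9], [0.7, 0.3]]), torch.tensor([1, 0])),
]

correct = nominal = actual = 0
for logits, target in batches:
    result = logits.argmax(1)           # predicted class = index of the largest logit in each row
    correct += result.eq(target).sum().item()
    nominal += batch_size               # assumes every batch is full
    actual += target.size(0)            # samples really seen in this batch

print(correct, nominal, actual)                               # 5 8 6
print(f'{correct / nominal:.2%} vs {correct / actual:.2%}')   # 62.50% vs 83.33%
```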
Copyright notice
This article was written by [Herbie TZJ]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204231117105931.html