PyTorch Neural Network Trainer
2022-04-23 11:17:00 【荷碧·TZJ】
Writing training code for neural networks is a tedious and highly repetitive job, so I wrote a trainer that provides the following features:
- Saves the network parameters that achieve the lowest loss on the test set
- Automatically loads the saved parameter file when the code is rerun; if no file is found, initializes the network parameters with Kaiming He's method (Kaiming initialization)
- Can serve as the parent class of all kinds of network trainers; only the _forward method needs to be overridden for reuse
- Prints a progress bar during training/validation
Environment and variables
import os
import numpy as np
import torch
import torch.nn.functional as F
SGD = torch.optim.SGD
Adam = torch.optim.Adam
Basic functions
def param_init(neural_net):
    '''Initialize the network parameters'''
    parameters = neural_net.state_dict()
    for key in parameters:
        if len(parameters[key].shape) >= 2:
            # kaiming_normal_ works in place, so the underlying weight tensor is updated directly
            parameters[key] = torch.nn.init.kaiming_normal_(parameters[key], a=0, mode='fan_in',
                                                            nonlinearity='leaky_relu')


def get_progress(current, target, bar_len=30):
    '''current: number of finished tasks
    target: total number of tasks
    bar_len: length of the progress bar
    return: progress bar string'''
    assert current <= target
    percent = round(current / target * 100, 1)
    unit = 100 / bar_len
    solid = round(percent / unit)
    hollow = bar_len - solid
    return '■' * solid + '□' * hollow + f' {current}/{target}({percent}%)'
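As a quick sanity check (my own illustration, not part of the original post), the two helpers can be tried on their own; the model and the numbers below are only placeholders:

# Illustrative usage of the helpers above (not part of the original post)
net = torch.nn.Linear(8, 4)
param_init(net)                  # the 2-D weight is re-initialized in place, the 1-D bias is skipped
print(get_progress(30, 100))     # roughly 30% of the bar is solid, followed by ' 30/100(30.0%)'
print(get_progress(100, 100))    # fully solid bar, followed by ' 100/100(100.0%)'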
Trainer base class
It cannot be used directly; its _forward method must be overridden (see below).
class Trainer:
    '''Trainer
    net: network model
    net_file: path for saving the model parameters (.pt)
    lr: learning rate
    adam: use the Adam optimizer
    bar_len: length of the progress bar'''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool, bar_len: int):
        # Basic settings
        self.net = net.cuda()
        self._net_file = net_file
        self._min_loss = np.inf
        # Load the saved network parameters if the file exists, otherwise initialize them
        if os.path.isfile(self._net_file):
            state_dict = torch.load(self._net_file)
            self.net.load_state_dict(state_dict)
        else:
            param_init(self.net)
        # Instantiate the optimizer
        parameters = self.net.parameters()
        if adam:
            self._optimizer = Adam(parameters, lr=lr)
        else:
            self._optimizer = SGD(parameters, lr=lr)
        self._bar_len = bar_len

    def train(self, train_set, profix='train'):
        assert self._min_loss != np.inf, 'run the eval method first'
        self.net.train()
        avg_loss = self._forward(train_set, train=True, profix=profix)
        return avg_loss

    def eval(self, eval_set, profix='eval'):
        self.net.eval()
        avg_loss = self._forward(eval_set, train=False, profix=profix)
        # Save the network that performs best on the test set
        if avg_loss <= self._min_loss:
            self._min_loss = avg_loss
            torch.save(self.net.state_dict(), self._net_file)
        return avg_loss

    def _forward(self, data_set, train: bool, profix):
        '''data_set: dataset
        train: training mode (bool)
        profix: progress bar prefix
        return: average loss'''
        pass
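Before the author's Classfier example below, here is a minimal hypothetical subclass (my own sketch, not from the original post) showing the intended reuse pattern: only _forward is overridden, in this case for a regression task with MSE loss, reusing get_progress directly and skipping the accuracy bookkeeping that the classifier adds:

class Regressor(Trainer):
    '''Hypothetical regression trainer: same constructor as Trainer,
    only _forward is overridden (illustrative sketch, not the author's code)'''

    def _forward(self, data_set, train: bool, profix):
        batch_num = len(data_set)
        loss_sum = 0
        for idx, (x, y) in enumerate(data_set):
            x, y = x.cuda(), y.cuda()
            loss = F.mse_loss(self.net(x), y)
            if train:
                # Backpropagate the loss and take an optimizer step
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()
            # Track the running average loss and print the progress bar
            loss_sum += loss.item()
            progress = get_progress(idx + 1, batch_num, bar_len=self._bar_len)
            print(f'\r{profix}: {progress}\tavg_loss: {loss_sum / (idx + 1):.8f}', end='')
        print()
        return loss_sum / batch_num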
Classification network trainer
Taking image classification as an example, it uses cross-entropy loss; the return value of _forward must be the average loss.
class Classfier(Trainer):
    '''Classifier
    net: network model
    net_file: path for saving the model parameters (.pt)
    lr: learning rate
    adam: use the Adam optimizer
    bar_len: length of the progress bar'''

    def __init__(self, net, net_file: str, lr: float,
                 adam: bool = True, bar_len: int = 20):
        super(Classfier, self).__init__(net, net_file, lr, adam, bar_len)

    def _forward(self, data_set, train: bool, profix):
        '''data_set: dataset
        train: training mode (bool)
        profix: progress bar prefix
        return: average loss'''
        # Number of batches and samples
        batch_num = len(data_set)
        batch_size = data_set.batch_size
        task_num = batch_num * batch_size
        # Initialize the accuracy counter
        counter = Acc_Counter(batch_size, task_num, profix, self._bar_len)
        for idx, (image, target) in enumerate(data_set):
            image, target = image.cuda(), target.cuda()
            logits = self.net(image)
            del image
            # Standard classification loss
            loss = F.cross_entropy(logits, target)
            if train:
                # Backpropagate the loss and take an optimizer step
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()
            # Update the accuracy counter
            avg_loss = counter.update(idx, logits, target, loss)
        print()
        return avg_loss
The Acc_Counter class is used in this process to compute the classification accuracy and print the progress bar at the same time; its code is as follows:
class Acc_Counter:
    '''Accuracy counter
    batch_size: batch size
    task_num: number of samples
    profix: progress bar prefix
    bar_len: length of the progress bar'''

    def __init__(self, batch_size, task_num, profix, bar_len):
        self._batch_size = batch_size
        self._task_num = task_num
        self._correct = 0
        self._figured = 0
        self._loss_sum = 0
        self._profix = profix
        self._bar_len = bar_len

    def update(self, idx, logits, target, loss):
        '''idx: batch index (used for the running average loss)'''
        # Take the argmax of the logits as the predicted class
        result = logits.argmax(1)
        # Accumulate the accuracy
        self._correct += (result.eq(target).sum()).item()
        self._figured += self._batch_size
        acc = self._correct / self._figured * 100
        # Accumulate the average loss
        loss = loss.item()
        self._loss_sum += loss
        avg_loss = self._loss_sum / (idx + 1)
        # Print the train / eval progress
        progress = get_progress(self._figured, self._task_num, bar_len=self._bar_len)
        print(
            f'\r{self._profix}: {progress}\tacc: {self._correct}/{self._figured}({acc:.2f}%)\tavg_loss: {avg_loss:.8f}',
            end='')
        return avg_loss
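Putting everything together, a usage sketch could look like the following (again my own illustration, not from the original post: the torchvision MNIST dataset, the tiny fully connected net, and all hyperparameters are placeholder choices). Note that eval must be called once before train, because train asserts that a baseline _min_loss has already been recorded; drop_last=True keeps Acc_Counter's sample count consistent with task_num:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Placeholder data pipeline and model (illustrative only)
to_tensor = transforms.ToTensor()
train_loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=to_tensor),
                          batch_size=64, shuffle=True, drop_last=True)
eval_loader = DataLoader(datasets.MNIST('data', train=False, download=True, transform=to_tensor),
                         batch_size=64, drop_last=True)

net = torch.nn.Sequential(torch.nn.Flatten(),
                          torch.nn.Linear(28 * 28, 128), torch.nn.ReLU(),
                          torch.nn.Linear(128, 10))

trainer = Classfier(net, net_file='mnist.pt', lr=1e-3)
trainer.eval(eval_loader)          # establish the initial best loss (required before train)
for epoch in range(10):
    trainer.train(train_loader)
    trainer.eval(eval_loader)      # saves mnist.pt whenever the eval loss improves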
Copyright notice
This article was written by [荷碧·TZJ]. Please include a link to the original article when reposting. Thanks.
https://hebitzj.blog.csdn.net/article/details/124332723