当前位置:网站首页>PyTorch 神经网络训练器
PyTorch 神经网络训练器
2022-04-23 11:17:00 【荷碧·TZJ】
写神经网络的训练代码是个枯燥又重复性极强的过程,为此我写了个训练器,可实现以下功能:
- 保存在测试集上损失最小的网络参数
- 重新运行代码时,自动加载神经网络参数文件;没有检测到文件时,使用何凯明大佬的方法初始化神经网络参数
- 可作为各种神经网络训练器的父类,只需要重写 _forward 方法即可复用
- 输出训练/验证过程的进度条
环境及变量
import os
import numpy as np
import torch
import torch.nn.functional as F
SGD = torch.optim.SGD
Adam = torch.optim.Adam
基础函数
def param_init(neural_net):
''' 网络参数初始化'''
parameters = neural_net.state_dict()
for key in parameters:
if len(parameters[key].shape) >= 2:
parameters[key] = torch.nn.init.kaiming_normal_(parameters[key], a=0, mode='fan_in',
nonlinearity='leaky_relu')
def get_progress(current, target, bar_len=30):
''' current: 当前完成任务数
target: 任务总数
bar_len: 进度条长度
return: 进度条字符串'''
assert current <= target
percent = round(current / target * 100, 1)
unit = 100 / bar_len
solid = round(percent / unit)
hollow = bar_len - solid
return '■' * solid + '□' * hollow + f' {current}/{target}({percent}%)'
训练器基类
不可直接使用,其中的 _forward 函数需要重写 (见下文)
class Trainer:
''' 训练器
net: 网络模型
net_file: 网络模型保存路径 (.pt)
adam: 使用 Adam 优化器
bar_len: 进度条长度'''
def __init__(self, net, net_file: str, lr: float,
adam: bool, bar_len: int):
# 设置系统参数
self.net = net.cuda()
self._net_file = net_file
self._min_loss = np.inf
# 载入网络参数
if os.path.isfile(self._net_file):
state_dict = torch.load(self._net_file)
self.net.load_state_dict(state_dict)
else:
param_init(self.net)
# 实例化优化器
parameters = self.net.parameters()
if adam:
self._optimizer = Adam(parameters, lr=lr)
else:
self._optimizer = SGD(parameters, lr=lr)
self._bar_len = bar_len
def train(self, train_set, profix='train'):
assert self._min_loss != np.inf, '请先运行 eval 函数'
self.net.train()
avg_loss = self._forward(train_set, train=True, profix=profix)
return avg_loss
def eval(self, eval_set, profix='eval'):
self.net.eval()
avg_loss = self._forward(eval_set, train=False, profix=profix)
# 保存在测试集上表现最好的网络
if avg_loss <= self._min_loss:
self._min_loss = avg_loss
torch.save(self.net.state_dict(), self._net_file)
return avg_loss
def _forward(self, data_set, train: bool, profix):
''' data_set: 数据集
train: 训练 (bool)
profix: 进度条前缀
return: 平均损失'''
pass
分类网络训练器
以图像分类为例,使用交叉熵损失,_forward 函数的返回值需要是平均损失
class Classfier(Trainer):
''' 分类器
net: 网络模型
net_file: 网络模型保存路径 (.pt)
adam: 使用 Adam 优化器
bar_len: 进度条长度'''
def __init__(self, net, net_file: str, lr: float,
adam: bool = True, bar_len: int = 20):
super(Classfier, self).__init__(net, net_file, lr, adam, bar_len)
def _forward(self, data_set, train: bool, profix):
''' data_set: 数据集
train: 训练 (bool)
profix: 进度条前缀
return: 平均损失'''
# 批信息、数据数
batch_num = len(data_set)
batch_size = data_set.batch_size
task_num = batch_num * batch_size
# 初始化 acc 计算器
counter = Acc_Counter(batch_size, task_num, profix, self._bar_len)
for idx, (image, target) in enumerate(data_set):
image, target = image.cuda(), target.cuda()
logits = self.net(image)
del image
# 常规分类
loss = F.cross_entropy(logits, target)
if train:
# loss 反向传播梯度,并迭代
self._optimizer.zero_grad()
loss.backward()
self._optimizer.step()
# 更新 acc 计算器
avg_loss = counter.update(idx, logits, target, loss)
print()
return avg_loss
在这个过程中使用了 Acc_Counter 这个类计算分类准确率,同时输出进度条,其代码如下:
class Acc_Counter:
''' acc 计算器
batch_size: 批大小
task_num: 数据量
profix: 进度条前缀
bar_len: 进度条长度'''
def __init__(self, batch_size, task_num, profix, bar_len):
self._batch_size = batch_size
self._task_num = task_num
self._correct = 0
self._figured = 0
self._loss_sum = 0
self._profix = profix
self._bar_len = bar_len
def update(self, idx, logits, target, loss):
''' idx: logits索引'''
# 取 logits 的最大值索引做为 result
result = logits.argmax(1)
# 计算得出 acc
self._correct += (result.eq(target).sum()).item()
self._figured += self._batch_size
acc = self._correct / self._figured * 100
# 计算得出 avg_loss
loss = loss.item()
self._loss_sum += loss
avg_loss = self._loss_sum / (idx + 1)
# 输出 train / eval 数据
progress = get_progress(self._figured, self._task_num, bar_len=self._bar_len)
print(
f'\r{self._profix}: {progress}\tacc: {self._correct}/{self._figured}({acc:.2f}%)\tavg_loss: {avg_loss:.8f}',
end='')
return avg_loss
版权声明
本文为[荷碧·TZJ]所创,转载请带上原文链接,感谢
https://hebitzj.blog.csdn.net/article/details/124332723
边栏推荐
- SVN的使用:
- 学习 Go 语言 0x06:《Go 语言之旅》中 斐波纳契闭包 练习题代码
- How to quickly query 10 million pieces of data in MySQL
- mysql分表之后如何平滑上线详解
- Cumcm 2021 - B: préparation d'oléfines C4 par couplage éthanol (2)
- CUMCM 2021-B:乙醇偶合制備C4烯烴(2)
- Detailed explanation of typora Grammar (I)
- remote: Support for password authentication was removed on August 13, 2021.
- redis优化系列(二)Redis主从原理、主从常用配置
- After the MySQL router is reinstalled, it reconnects to the cluster for boot - a problem that has been configured in this host before
猜你喜欢
进程间通信 -- 消息队列
2022爱分析· 工业互联网厂商全景报告
Promise详解
MIT: label every pixel in the world with unsupervised! Humans: no more 800 hours for an hour of video
Canvas详解
学习 Go 语言 0x04:《Go 语言之旅》中切片的练习题代码
Database management software sqlpro for SQLite for Mac 2022.30
qt5.8 64 位静态库中想使用sqlite但静态库没有编译支持库的方法
Constraintlayout layout
Visual common drawing (III) area map
随机推荐
学习 Go 语言 0x03:理解变量之间的依赖以及初始化顺序
面向全球市场,PlatoFarm今日登录HUOBI等全球四大平台
oh-my-lotto
Learning go language 0x02: understanding slice
PDMS soft lithography process
使用 PHP PDO ODBC 示例的 Microsoft Access 数据库
Database management software sqlpro for SQLite for Mac 2022.30
mysql中整数数据类型tinyint详解
R-drop: a more powerful dropout regularization method
活动进行时! 点击链接加入直播间参与“AI真的能节能吗?”的讨论吧!
Google Earth Engine(GEE)——将原始影像进行升尺度计算(以海南市为例)
Usage Summary of datetime and timestamp in MySQL
Mysql8.0安装指南
SWAT - Introduction to Samba web management tool
Excel·VBA数组冒泡排序函数
Go interface usage
MySQL8. 0 upgraded stepping on the pit Adventure
Promise详解
On lambda powertools typescript
妊娠箱和分娩箱的区别