当前位置:网站首页>pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
2022-08-10 05:29:00 【Time.Xu】
上一篇文章我们讲解了
1、数据的增强方法 和 导入方法(ToTensor、归一化)
2、使用torchvision模块中定义好的数据集格式来规范加载数据的方式
3、展示我们导入的数据
4、使用已有的模型以及权重 等等
这一节我们继续把整个网络训练做完
优化器设置
训练模块
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,filename=filename):
since = time.time()
best_acc = 0
""" checkpoint = torch.load(filename) best_acc = checkpoint['best_acc'] model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) model.class_to_idx = checkpoint['mapping'] """
model.to(device)
val_acc_history = []
train_acc_history = []
train_losses = []
valid_losses = []
LRs = [optimizer.param_groups[0]['lr']]
best_model_wts = copy.deepcopy(model.state_dict())
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
# 训练和验证
for phase in ['train', 'valid']:
if phase == 'train':
model.train() # 训练
else:
model.eval() # 验证
running_loss = 0.0
running_corrects = 0
# 把数据都取个遍
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
# 清零
optimizer.zero_grad()
# 只有训练的时候计算和更新梯度
with torch.set_grad_enabled(phase == 'train'):
if is_inception and phase == 'train':
outputs, aux_outputs = model(inputs)
loss1 = criterion(outputs, labels)
loss2 = criterion(aux_outputs, labels)
loss = loss1 + 0.4*loss2
else:#resnet执行的是这里
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
# 训练阶段更新权重
if phase == 'train':
loss.backward()
optimizer.step()
# 计算损失
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
time_elapsed = time.time() - since
print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
# 得到最好那次的模型
if phase == 'valid' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
state = {
'state_dict': model.state_dict(),
'best_acc': best_acc,
'optimizer' : optimizer.state_dict(),
}
torch.save(state, filename)
if phase == 'valid':
val_acc_history.append(epoch_acc)
valid_losses.append(epoch_loss)
scheduler.step(epoch_loss)
if phase == 'train':
train_acc_history.append(epoch_acc)
train_losses.append(epoch_loss)
print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
LRs.append(optimizer.param_groups[0]['lr'])
print()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))
# 训练完后用最好的一次当做模型最终的结果
model.load_state_dict(best_model_wts)
return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
#开始训练!、
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, is_inception=(model_name=="inception"))
训练之后就可以测试查看了:
边栏推荐
- SQLSERVER 2008 parses data in Json format
- ThreadPoolExecutor线程池原理
- 论文精度 —— 2017 ACM《Globally and Locally Consistent Image Completion》
- flinkcdc 读取pgsql 的时间被放大了 有大佬知道咋回事吗 gmt_create':1
- `id` bigint(20) unsigned NOT NULL COMMENT 'Database primary key',
- 大咖说·对话生态|当Confluent遇见云:实时流动的数据更有价值
- 接口调试还能这么玩?
- 常用工具系列 - 常用正则表达式
- Guys, is it normal that the oracle archive log grows by 3G in 20 minutes after running cdc?
- [Thesis Notes] Prototypical Contrast Adaptation for Domain Adaptive Semantic Segmentation
猜你喜欢
summer preschool assignments
【裴蜀定理】CF1055C Lucky Days
MySQL simple tutorial
Kubernetes:(十七)Helm概述、安装及配置
Depth of carding: prevent model fitting method
How to simulate the background API call scene, very detailed!
CORS跨域资源共享漏洞的原理与挖掘方法
canvas canvas drawing clock
手把手带你写嵌入式物联网的第一个项目
大咖说·对话生态|当Confluent遇见云:实时流动的数据更有价值
随机推荐
咨询cdc 2.0 for mysql不执行flush with read lock.怎么保证bin
`id` bigint(20) unsigned NOT NULL COMMENT 'Database primary key',
常用工具系列 - 常用正则表达式
FPGA engineer interview questions collection 1~10
一篇文章带你搞懂什么是幂等性问题?如何解决幂等性问题?
Practical skills 19: Several postures of List to Map List
`id` bigint(20) unsigned NOT NULL COMMENT '数据库主键',
k-近邻实现手写数字识别
栈与队列 | 用栈实现队列 | 用队列实现栈 | 基础理论与代码原理
虚拟土地价格暴跌85% 房地产泡沫破裂?依托炒作的暴富游戏需谨慎参与
应用在智能触摸遥控器中的触摸芯片
oracle cdc时,设置并行度2插槽数1,最终任务只有一个tm,是不是因为oracle不支持并发
FPGA工程师面试试题集锦11~20
Concurrency tool class - introduction and use of CountDownLatch, CyclicBarrier, Semaphore, Exchanger
Transforming into a product, is it reliable to take the NPDP test?
How to improve product quality from the code layer
Error when installing oracle rac 11g and executing root.sh
每周推荐短视频:探索AI的应用边界
Important transformation and upgrading
8.STM32F407之HAL库——PWM笔记