当前位置:网站首页>pytorch-09. Multi-classification problem
pytorch-09. Multi-classification problem
2022-08-10 05:56:00 【Shengxin Research Ape】
NLLLoss(Negative Log Likelihood Loss),最大似然函数. 把Label对应的输出log_result值拿出来,求和取平均. --------------------------------------------------------------------------------------- CrossEntropyLoss交叉熵损失函数. 一步执行完:softmax+log+NLLLoss合并起来了.
NLLLoss
CrossEntropyLoss
softmax:
CrossEntropyLoss示例:
#torch.LongTensor是64位整型 #torch.Tensor默认torch.FloatTensor,是32位浮点类型数据. #torch.tensor是一个类,Used to generate a single-precision floating-point tensor.
import torch
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])
Y_pred1 = torch.Tensor([
[0.1,0.2,0.9],
[1.1,0.1,0.2],
[0.2,2.1,0.1]
])
Y_pred2 = torch.Tensor([
[0.8,0.2,0.3],
[0.2,0.3,0.5],
[0.2,0.2,0.5]
])
l1 = criterion(Y_pred1,Y)
l2 = criterion(Y_pred2,Y)
print("Loss1 = ",l1.data.item(),"\nLoss2 = ",l2.data.item())
mnist数据集实践
#minstThe mean of the dataset is0.1307,标准差是0.3081
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307),(0.3081))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle = True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/',train=False,download=True,transform=transform)
test_loader = DataLoader(train_dataset,shuffle = False,batch_size=batch_size)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784,512)
self.l2 = torch.nn.Linear(512,256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10)
def forward(self,x):
x = x.view(-1,784)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x) #最后一层不做激活,不进行非线性变换
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
def train(epoch):
running_loss = 0.0
for batch_idx,data in enumerate(train_loader,0):
inputs, target = data
optimizer.zero_grad()
#forward + backward + update
outputs = model(inputs)
loss = criterion(outputs,target)
loss.backward()
optimizer.step()
running_loss +=loss.item()
if batch_idx % 300 ==299:
print('[%d,%5d]loss:%.3f'%(epoch+1,batch_idx+1,running_loss/300))
running_loss = 0
def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images,labels = data
outputs = model(images)
_,predicted = torch.max(outputs.data,dim=1) #dim=1维度1,行是第0个维度,列是第1个维度
total +=labels.size(0)
correct +=(predicted==labels).sum().item()
print('Accuracy on test set:%d %%'%(100*correct/total) )
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()
结果:
[1, 300]loss:2.223
[1, 600]loss:0.923
[1, 900]loss:0.435
Accuracy on test set:89 %
[2, 300]loss:0.328
[2, 600]loss:0.272
[2, 900]loss:0.239
Accuracy on test set:94 %
[3, 300]loss:0.188
[3, 600]loss:0.175
[3, 900]loss:0.158
Accuracy on test set:96 %
[4, 300]loss:0.126
[4, 600]loss:0.130
[4, 900]loss:0.121
Accuracy on test set:97 %
[5, 300]loss:0.098
[5, 600]loss:0.099
[5, 900]loss:0.097
Accuracy on test set:97 %
[6, 300]loss:0.078
[6, 600]loss:0.079
[6, 900]loss:0.081
Accuracy on test set:97 %
[7, 300]loss:0.066
[7, 600]loss:0.063
[7, 900]loss:0.064
Accuracy on test set:98 %
[8, 300]loss:0.050
[8, 600]loss:0.056
[8, 900]loss:0.051
Accuracy on test set:98 %
[9, 300]loss:0.041
[9, 600]loss:0.043
[9, 900]loss:0.043
Accuracy on test set:99 %
[10, 300]loss:0.034
[10, 600]loss:0.034
[10, 900]loss:0.037
Accuracy on test set:99 %
Process finished with exit code 0
边栏推荐
猜你喜欢
One step ahead, don't miss it again, the chain reading APP will be launched soon!
Day1 微信小程序-小程序代码的构成
LeetCode 面试题17.14 最小k个数(中等)
LeetCode 剑指offer 21.调整数组顺序使奇数位于偶数前面(简单)
卷积神经网络(CNN)实现服装图像分类
栈和队列
复杂的“元宇宙”,为您解读,链读APP即将上线!
【List练习】遍历集合并且按照价格从低到高排序,
Chained Picks: Starbucks looks at digital collectibles and better engages customers
深度学习TensorFlow入门环境配置
随机推荐
LeetCode 1351.统计有序矩阵中的负数(简单)
Count down the six weapons of the domestic interface collaboration platform!
学生管理系统以及其简单功能的实现
win12 modify dns script
el-dropdown drop-down menu style modification, remove the small triangle
opencv
深度学习TensorFlow入门环境配置
Chain Reading | The latest and most complete digital collection calendar-07.28
图片批量添加水印批量缩放图片到指定大小
LeetCode 162.寻找峰值(中等)
String common methods
细说MySql索引原理
LeetCode 面试题17.14 最小k个数(中等)
Content related to ZigBee network devices
cesium listens to map zoom or zoom to control whether the content added on the map is displayed
I use this recruit let the team to improve the development efficiency of 100%!
ORACLE系统表空间SYSTEM占满无法扩充表空间问题解决过程
2022李宏毅机器学习hw1--COVID-19 Cases Prediction
数据库 笔记 创建数据库、表 备份
MySQL中MyISAM为什么比InnoDB查询快