pytorch-09. Multi-class classification problem
2022-08-10 05:56:00 【Shengxin Research Ape】
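NLLLoss (Negative Log Likelihood Loss): the input is expected to be log-probabilities. For each sample it takes the log-probability at the label's index, negates it, and averages over the batch (default reduction='mean'). Minimizing this loss is equivalent to maximum-likelihood estimation.

CrossEntropyLoss (cross-entropy loss): does everything in one step by combining softmax + log + NLLLoss (internally log_softmax followed by NLLLoss). It therefore takes the network's raw, unnormalized outputs (logits) directly.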
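Written out for a batch of N samples with C classes, logits z, labels y, and log-probabilities x, the two losses look roughly like this. This sketch assumes the default reduction='mean' and no class weights. CrossEntropyLoss is NLLLoss applied to x = log softmax(z):

```latex
\begin{aligned}
\ell_{\text{NLL}} &= -\frac{1}{N}\sum_{n=1}^{N} x_{n,\,y_n}, \qquad x_{n,c} = \log p_{n,c} \\
\ell_{\text{CE}}  &= -\frac{1}{N}\sum_{n=1}^{N} \log \frac{e^{z_{n,\,y_n}}}{\sum_{c=1}^{C} e^{z_{n,c}}}
\end{aligned}
```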
NLLLoss
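The sketch below reproduces what NLLLoss computes by hand. It assumes the default reduction='mean' and no class weights. The logits and labels are illustrative values borrowed from the CrossEntropyLoss example further down.

```python
import torch
import torch.nn.functional as F

# Illustrative batch: 3 samples, 3 classes (raw scores, not probabilities)
logits = torch.tensor([[0.1, 0.2, 0.9],
                       [1.1, 0.1, 0.2],
                       [0.2, 2.1, 0.1]])
target = torch.tensor([2, 0, 1])  # int64 class indices

# NLLLoss expects log-probabilities, so apply log_softmax first
log_probs = F.log_softmax(logits, dim=1)

# Library version (default reduction='mean')
nll = torch.nn.NLLLoss()(log_probs, target)

# Manual version: pick the log-probability at each label index,
# negate it, and average over the batch
picked = log_probs[torch.arange(len(target)), target]
manual = -picked.mean()

print(nll.item(), manual.item())
```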
CrossEntropyLoss
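The sketch below checks the "softmax + log + NLLLoss" description by computing the same loss three ways on illustrative data. The torch.log(torch.softmax(...)) route only mirrors the description. F.log_softmax is the numerically safer way to get log-probabilities.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.1, 0.2, 0.9],
                       [1.1, 0.1, 0.2],
                       [0.2, 2.1, 0.1]])
target = torch.tensor([2, 0, 1])

# One step: CrossEntropyLoss takes raw logits directly
ce = torch.nn.CrossEntropyLoss()(logits, target)

# Two steps: log_softmax followed by NLLLoss
two_step = torch.nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

# Three steps, as in the text: softmax, then log, then NLLLoss
three_step = torch.nn.NLLLoss()(torch.log(torch.softmax(logits, dim=1)), target)

print(ce.item(), two_step.item(), three_step.item())
```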
softmax:
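softmax turns a row of scores z into probabilities: softmax(z)_i = exp(z_i) / Σ_j exp(z_j). The hand-written sketch below uses a function named softmax, which is our own helper and not a PyTorch API. It subtracts the row maximum before exponentiating so that exp cannot overflow, and its result can be compared with torch.softmax.

```python
import torch

def softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Subtract the per-row max before exp for numerical stability;
    # softmax is shift-invariant, so the result is unchanged
    z = z - z.max(dim=dim, keepdim=True).values
    e = torch.exp(z)
    return e / e.sum(dim=dim, keepdim=True)

z = torch.tensor([[0.2, 0.1, -0.1],
                  [1000.0, 1001.0, 1002.0]])  # a naive exp() would overflow on row 2
p = softmax(z, dim=1)
print(p)
print(p.sum(dim=1))  # each row sums to 1
```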
CrossEntropyLoss example:
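# torch.LongTensor is a 64-bit integer tensor (class-index targets are expected to be int64)
# torch.Tensor defaults to torch.FloatTensor, i.e. 32-bit floating-point data
# torch.tensor is a function, not a class: it infers the dtype from the data (float32 for floats, int64 for integers)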
import torch
criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])
Y_pred1 = torch.Tensor([
[0.1,0.2,0.9],
[1.1,0.1,0.2],
[0.2,2.1,0.1]
])
Y_pred2 = torch.Tensor([
[0.8,0.2,0.3],
[0.2,0.3,0.5],
[0.2,0.2,0.5]
])
l1 = criterion(Y_pred1,Y)
l2 = criterion(Y_pred2,Y)
print("Loss1 = ", l1.item(), "\nLoss2 = ", l2.item())
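The default reduction='mean' averages over the batch, so looking at per-sample losses can make the comparison clearer. The sketch below assumes reduction='none'. It also uses torch.tensor, which infers int64 for the labels and float32 for the scores, matching the LongTensor/Tensor types above. In every row of Y_pred1 the largest score sits at the label index, while in Y_pred2 it never does, so each per-sample loss is smaller for Y_pred1.

```python
import torch

criterion = torch.nn.CrossEntropyLoss(reduction='none')  # keep one loss per sample
Y = torch.tensor([2, 0, 1])  # inferred int64, same as LongTensor
Y_pred1 = torch.tensor([[0.1, 0.2, 0.9],
                        [1.1, 0.1, 0.2],
                        [0.2, 2.1, 0.1]])  # inferred float32
Y_pred2 = torch.tensor([[0.8, 0.2, 0.3],
                        [0.2, 0.3, 0.5],
                        [0.2, 0.2, 0.5]])

per_sample1 = criterion(Y_pred1, Y)
per_sample2 = criterion(Y_pred2, Y)
print(Y.dtype, Y_pred1.dtype)           # torch.int64 torch.float32
print(per_sample1, per_sample1.mean())  # the mean equals the default reduction='mean'
print(per_sample2, per_sample2.mean())
```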
MNIST dataset in practice
# The MNIST dataset has a mean of 0.1307 and a standard deviation of 0.3081
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # (mean,), (std,) for the single grayscale channel
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle = True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/',train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # flatten each 1x28x28 image into a 784-dim vector
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # no activation (no nonlinearity) on the last layer: CrossEntropyLoss expects raw logits
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:  # print the average loss every 300 batches
            print('[%d,%5d]loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # dim=1: take the max along dimension 1 (dim 0 indexes rows/samples, dim 1 indexes columns/classes)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
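The following sanity check needs no dataset download. It pushes a fake batch through an MLP with the same layer sizes as Net, then counts correct predictions the way test() does. For brevity the sketch rebuilds the model with torch.nn.Sequential and uses torch.nn.Flatten in place of x.view(-1, 784); that substitution is ours. The random "images" go through the same (x - mean) / std arithmetic that transforms.Normalize applies. The labels are random, so the printed accuracy only shows the counting works.

```python
import torch

torch.manual_seed(0)  # reproducible fake data

# Fake batch shaped like the MNIST DataLoader output: [batch, channel, height, width]
raw = torch.rand(64, 1, 28, 28)       # pixel values in [0, 1], as after ToTensor()
batch = (raw - 0.1307) / 0.3081       # same arithmetic as Normalize((0.1307,), (0.3081,))
labels = torch.randint(0, 10, (64,))  # random class indices 0..9

# Same layer sizes as Net; Flatten plays the role of x.view(-1, 784)
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, 256), torch.nn.ReLU(),
    torch.nn.Linear(256, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 10),  # raw logits, no activation
)

with torch.no_grad():
    logits = model(batch)
    _, predicted = torch.max(logits, dim=1)  # index of the largest logit in each row
    correct = (predicted == labels).sum().item()

print(logits.shape)  # torch.Size([64, 10])
print('accuracy on fake batch: %.3f' % (correct / labels.size(0)))
```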
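Results (produced by the script as originally written, where test_loader wrapped train_dataset, so the "Accuracy on test set" lines actually measure accuracy on the training images):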
[1, 300]loss:2.223
[1, 600]loss:0.923
[1, 900]loss:0.435
Accuracy on test set:89 %
[2, 300]loss:0.328
[2, 600]loss:0.272
[2, 900]loss:0.239
Accuracy on test set:94 %
[3, 300]loss:0.188
[3, 600]loss:0.175
[3, 900]loss:0.158
Accuracy on test set:96 %
[4, 300]loss:0.126
[4, 600]loss:0.130
[4, 900]loss:0.121
Accuracy on test set:97 %
[5, 300]loss:0.098
[5, 600]loss:0.099
[5, 900]loss:0.097
Accuracy on test set:97 %
[6, 300]loss:0.078
[6, 600]loss:0.079
[6, 900]loss:0.081
Accuracy on test set:97 %
[7, 300]loss:0.066
[7, 600]loss:0.063
[7, 900]loss:0.064
Accuracy on test set:98 %
[8, 300]loss:0.050
[8, 600]loss:0.056
[8, 900]loss:0.051
Accuracy on test set:98 %
[9, 300]loss:0.041
[9, 600]loss:0.043
[9, 900]loss:0.043
Accuracy on test set:99 %
[10, 300]loss:0.034
[10, 600]loss:0.034
[10, 900]loss:0.037
Accuracy on test set:99 %
Process finished with exit code 0