当前位置:网站首页>pytorch从零搭建神经网络实现多分类(训练自己的数据集)
pytorch从零搭建神经网络实现多分类(训练自己的数据集)
2022-08-09 14:52:00 【pomelo33】
简介
本文介绍如何使用pytorch搭建基础的神经网络,解决多分类问题。主要介绍了两个模型:①全连接层网络;②VGG11卷积神经网络模型(下次介绍)。为了演示方便,使用了Fashion-Mnist服装分类数据集(10分类数据集,介绍可以去网上搜一下,这里不赘述),也可以在自己的制作的数据集上训练(后面会稍作介绍)。在文章最后会附上完整的可运行的代码。
1 全连接层网络
全连接层网络包括输入层、隐藏层以及输出层。其中隐藏层中可以包括多个全连接层,理论上可以加无数层,加的越多,网络的深度越深。每个全连接层中可以包含多个节点,理论上也可以无数多,节点数越多,网络宽度越宽。但实际上,网络深和宽并不意味着性能越好,需要视情况而定。
一般每一层输出后还要使用激活函数,以及一些正则化手段如dropout。
1.1 搭建模型
class FCNet(nn.Module):#全连接网络
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784,512)
self.fc2 = nn.Linear(512,256)
self.fc3 = nn.Linear(256,128)
self.fc4 = nn.Linear(128,64)
self.fc5 = nn.Linear(64,10)
self.dropout = nn.Dropout(p=0.2)
def forward(self,x):
x = x.view(x.shape[0],-1)
x_1 = self.dropout(F.relu(self.fc1(x)))
x_2 = self.dropout(F.relu(self.fc2(x_1)))
x_3 = self.dropout(F.relu(self.fc3(x_2)))
x_4 = self.dropout(F.relu(self.fc4(x_3)))
x_out = F.softmax(self.fc5(x_4),1)
return x_out
可以看出全连接网络的搭建十分简单,很容易理解。首先创建一个类,继承Module类。初始化后定义各个全连接层,此处的定义并不一定要按照顺序,但为了容易理解,一般按顺序定义。
self.fc1 = nn.Linear(784,512)#第一层全连接层,节点数为512
self.fc2 = nn.Linear(512,256)#第二层全连接层,节点数为256
由于Fashion-Mnist数据集的每个样本的特征点数为784(28*28的图片),因此第一层全连接层的输入节点数为784,512则代表该全连接层的输出节点数(即该全连接层有512个节点)。以此类推,若下一层全连接层的节点数为256,则将输入节点数改为512,输出改为256。
①:
self.fc1 = nn.Linear(784,2048)#第一层全连接层,节点数为2048
self.fc2 = nn.Linear(2048,10)#第二层全连接层,节点数为10
②:
self.fc1 = nn.Linear(784,256)#第一层全连接层,节点数为256
self.fc2 = nn.Linear(256,128)#第二层全连接层,节点数为128
self.fc3 = nn.Linear(128,10)#第三层全连接层,节点数为10
实际上,对于一个全连接层网络,只需要固定输入节点数(784)和输出节点数(10)。其内部的全连接层节点数可以任意设定,例如以上①②所示都是可以的,只是效果会有差异。
self.dropout = nn.Dropout(p=0.2)
网络定义的时候,还定义了一个dropout层,因为全连接网络的节点数较多,而相邻层的每一个节点都两两相连,因此造成网络参数量较大,随着网络的深度和宽度加大,网络容易出现过拟合的现象。因此要采用正则化的手段,dropout为其中一种手段。其他正则化手段可以到以下链接稍作了解:正则化原理的简单分析(L1/L2正则化).
1.2 前向传播
def forward(self,x):
x = x.view(x.shape[0],-1)
x_1 = self.dropout(F.relu(self.fc1(x)))
x_2 = self.dropout(F.relu(self.fc2(x_1)))
x_3 = self.dropout(F.relu(self.fc3(x_2)))
x_4 = self.dropout(F.relu(self.fc4(x_3)))
x_out = F.softmax(self.fc5(x_4),1)
return x_out
前向传播的代编写也十分简单,首先要用.view()函数对每个输入数据样本展平为1 x 784的数据,才能传入该全连接网络模型。步骤:
①将输入x输入全连接层:self.fc(x)
②使用激活函数激活:relu(self.fc1(x))
③self.dropout(F.relu(self.fc1(x)))
最后得到第一层全连接层的输出x_1。以此类推,将x_1作为第二层的输入,继续正向传播。经过最后一层全连接层的10个节点后,使用softmax层输出分类结果。softmax层的作用是输出每个类别的概率。在此网络中的应用为十分类,因此输出为1 x 10个概率。简单说,几分类则最后一层全连接层的节点数就为几。
1.3 训练
读取数据
##预处理 将图片转换为tensor 归一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,)),transforms.Resize])
##获取数据集
train = datasets.FashionMNIST('dataset/',download=True,train=True,transform=transform)
test = datasets.FashionMNIST('dataset/',download=True,train=False,transform=transform)
##批量载入
batch_size = 64
train_iter = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)#num_workers要设置为你的cpu线程数
test_iter = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=4)
这里使用小批量训练,batch_size设置为64,可以视情况修改。num_workers要设置为你的cpu线程数
接下来看训练代码,首先要实例化网络;定义损失函数,这里使用的是交叉熵损失函数(交叉熵函数十分适合分类问题);定义优化器,这里使用Adam优化器,再设置epoch(训练轮数)。
具体的训练步骤参考注释:
net = FCNet()#实例化网络
lossFunc = nn.CrossEntropyLoss()#定义损失函数
optimizer = optim.Adam(net.parameters(),lr = 0.0001)#定义优化器,设置学习率
epochs = 20#训练轮数
train_loss, test_loss = [], []
print("开始训练FCNet")
for e in range(epochs):
running_loss = 0
for images,labels in train_iter: #小批量读取数据
optimizer.zero_grad() #将梯度清零
y_hat = net(images) #将数据输入网络
loss = lossFunc(y_hat,labels) #计算loss值
loss.backward() #误差反向传播
optimizer.step() #参数更新
running_loss += loss.item()## 将每轮的loss求和
test_runningloss = 0
test_acc = 0
with torch.no_grad(): #验证时不记录梯度
net.eval() #评估模式
for images,labels in test_iter:
y_hat = net(images)
test_runningloss += lossFunc(y_hat,labels)
ps = torch.exp(y_hat)
top_p,top_class = ps.topk(1,dim=1)
equals = top_class == labels.view(*top_class.shape)
test_acc += torch.mean(equals.type(torch.FloatTensor))
net.train()
train_loss.append(running_loss/len(train_iter))
test_loss.append(test_runningloss/len(test_iter))
print("训练集学习次数: {}/{}.. ".format(e + 1, epochs),
"训练误差: {:.3f}.. ".format(running_loss / len(train_iter)),
"测试误差: {:.3f}.. ".format(test_runningloss / len(test_iter)),
"模型分类准确率: {:.3f}".format(test_acc / len(test_iter)))
##训练结果可视化
plt.plot(train_loss,label='train loss')
plt.plot(test_loss,label='test loss')
plt.legend()
plt.show()
由loss可视化图中可以看出,loss值在20个epoch稳定下降。
完整代码
import torch
from torch import nn,optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
##预处理 将图片转换为tensor 归一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,)),transforms.Resize])
##获取数据集
train = datasets.FashionMNIST('dataset/',download=True,train=True,transform=transform)
test = datasets.FashionMNIST('dataset/',download=True,train=False,transform=transform)
##批量载入
batch_size = 64
train_iter = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
test_iter = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=4)
class PomeloFCNet(nn.Module):#全连接网络
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784,512)
self.fc2 = nn.Linear(512,256)
self.fc3 = nn.Linear(256,128)
self.fc4 = nn.Linear(128,64)
self.fc5 = nn.Linear(64,10)
self.dropout = nn.Dropout(p=0.2)
def forward(self,x):
x = x.view(x.shape[0],-1)
x_1 = self.dropout(F.relu(self.fc1(x)))
x_2 = self.dropout(F.relu(self.fc2(x_1)))
x_3 = self.dropout(F.relu(self.fc3(x_2)))
x_4 = self.dropout(F.relu(self.fc4(x_3)))
x_out = F.softmax(self.fc5(x_4),1)
return x_out
net = PomeloFCNet()#实例化网络
lossFunc = nn.CrossEntropyLoss()#定义损失函数
optimizer = optim.Adam(net.parameters(),lr = 0.0001)#定义优化器,设置学习率
epochs = 20#训练轮数
train_loss, test_loss = [], []
print("开始训练PomeloFCNet")
for e in range(epochs):
running_loss = 0
for images,labels in train_iter:
optimizer.zero_grad()
y_hat = net(images)
loss = lossFunc(y_hat,labels)
loss.backward()
optimizer.step()
running_loss += loss.item()## 将每轮的loss求和
test_runningloss = 0
test_acc = 0
with torch.no_grad():
net.eval()
for images,labels in test_iter:
y_hat = net(images)
test_runningloss += lossFunc(y_hat,labels)
ps = torch.exp(y_hat)
top_p,top_class = ps.topk(1,dim=1)
equals = top_class == labels.view(*top_class.shape)
test_acc += torch.mean(equals.type(torch.FloatTensor))
net.train()
train_loss.append(running_loss/len(train_iter))
test_loss.append(test_runningloss/len(test_iter))
print("训练集学习次数: {}/{}.. ".format(e + 1, epochs),
"训练误差: {:.3f}.. ".format(running_loss / len(train_iter)),
"测试误差: {:.3f}.. ".format(test_runningloss / len(test_iter)),
"模型分类准确率: {:.3f}".format(test_acc / len(test_iter)))
plt.plot(train_loss,label='train loss')
plt.plot(test_loss,label='test loss')
plt.legend()
plt.show()
训练自己的数据集
若想要训练其他数据集,则需要修改数据读取部分代码。例如数据集为.npy格式的时候:
##读取数据
x_train = np.load('train_data.npy')#训练数据
x_train = torch.from_numpy(x_train)
x_train.float()
y_train = np.load('train_label.npy')#训练标签
y_train = torch.from_numpy(y_train)
y_train.float()
x_test = np.load('eval_data.npy')#验证数据
x_test = torch.from_numpy(x_test)
x_test.float()
y_test = np.load('eval_label.npy')#验证标签
y_test = torch.from_numpy(y_test)
y_test.float()
接下来创建dataset,使用TensorDataset()函数,将数据和标签传进去即可。num_workers数记得修改。
datasets_train = Data.TensorDataset(x_train,y_train)
train_iter = Data.DataLoader(datasets_train,batch_size=batch_size,shuffle=True,num_workers=16)
datasets_test = Data.TensorDataset(x_test,y_test)
test_iter = Data.DataLoader(datasets_test,batch_size=batch_size,shuffle=True,num_workers=16)
边栏推荐
- Analysis: Which method is used to build a stock quantitative trading database?
- [MySql] implement multi-table query - one-to-one, one-to-many
- Shell functions and arrays
- 基于FPGA的FIR滤波器的实现(2)—采用kaiserord & fir2 & firpm函数设计
- redis6在centos7的安装
- 程序化交易规则对于整个交易系统有什么意义?
- DBCO-PEG-DSPE,磷脂-聚乙二醇-二苯并环辛炔,在无铜离子的催化下反应
- 各种程序员线学习学习教程收集
- Redis6.2.1配置文件详解
- 在量化交易过程中,散户可以这样做
猜你喜欢
随机推荐
常微分方程的幂级数解法
多线程学习
What are the implications of programmatic trading rules for the entire trading system?
How to use and execute quantitative programmatic trading?
Servlet的生命周期
How to flexibly use the advantages of the quantitative trading interface to complement each other?
Shell编程之正则表达式
相干光(光学)
一种基于视频帧差异视频卡顿检测方案
What is the difference between the four common resistors?
C语言——void指针、NULL指针、指向指针的指针、常量和指针
[MySql]实现多表查询-一对一,一对多
机器学习--数学库--概率统计
防汛添利器,数字技术筑起抗洪“大堤”
Analysis: Which method is used to build a stock quantitative trading database?
约束性统计星号‘*’
EasyExcel的应用
Redis6.2.1配置文件详解
[MySql] implement multi-table query - one-to-one, one-to-many
突然想分析下房贷利率及利息计算