pytorch-10. Convolutional Neural Networks (Homework)
2022-08-10 05:32:00 【生信研究猿】
Hand-derived shapes: 1 × 28 × 28 -> 10 × 24 × 24 (conv) -> 10 × 12 × 12 (max pooling) -> 20 × 10 × 10 (conv) -> 20 × 5 × 5 (max pooling) -> 30 × 4 × 4 (conv) -> 30 × 2 × 2 (max pooling) -> flatten to 120 -> 120, 64 (Linear) -> 64, 32 (Linear) -> 32, 10 (Linear)
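The hand derivation can be double-checked with a few lines of plain Python. This is only a sketch: the helper names `conv_out`, `pool_out`, and `trace_shapes` are made up for illustration, and it assumes stride-1 convolutions without padding and 2 × 2 max pooling, which matches the defaults used by the layers in this post.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output side length of max pooling whose stride equals its kernel size."""
    return size // kernel


def trace_shapes(side=28, channels=1):
    """Follow (channels, height, width) through the three conv + pool stages."""
    shapes = [(channels, side, side)]
    # (out_channels, conv kernel size) for conv1, conv2, conv3; each stage pools by 2
    for out_channels, kernel in [(10, 5), (20, 3), (30, 2)]:
        side = conv_out(side, kernel)
        shapes.append((out_channels, side, side))
        side = pool_out(side, 2)
        shapes.append((out_channels, side, side))
    return shapes


shapes = trace_shapes()
for c, h, w in shapes:
    print(f"{c} x {h} x {w}")

c, h, w = shapes[-1]
flat_features = c * h * w  # input size expected by the first Linear layer
print("flattened:", flat_features)
```

The last shape is 30 × 2 × 2, so the flattened vector has 120 features, which is why the first fully connected layer takes 120 inputs.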
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST mean and standard deviation, passed as one-element tuples
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# evaluate on the held-out test split
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=3)  # input_channel, output_channel
        self.conv3 = torch.nn.Conv2d(20, 30, kernel_size=2)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc1 = torch.nn.Linear(120, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)

    # 1 × 28 × 28 -> 10 × 24 × 24 (conv) -> 10 × 12 × 12 (max pooling)
    # -> 20 × 10 × 10 (conv) -> 20 × 5 × 5 (max pooling) -> 30 × 4 × 4 (conv)
    # -> 30 × 2 × 2 (max pooling) -> 120, 64 (Linear) -> 64, 32 (Linear) -> 32, 10 (Linear)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = F.relu(self.pooling(self.conv3(x)))
        x = x.view(batch_size, -1)  # flatten to (batch_size, 120)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)  # move the batch to the selected device (GPU if available)
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # report the mean loss over every 300 batches
        if batch_idx % 300 == 299:
            print('[%d,%5d]loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)  # move the batch to the selected device (GPU if available)
            outputs = model(images)
            # dim=1 takes the max across the columns (classes) of each row; rows are dim 0, columns are dim 1
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
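The prediction step in test() picks, for each row of the output, the index of the largest score and compares it with the label. A tiny standalone illustration follows; the `logits` and `labels` values are invented for the example and are not taken from the model above.

```python
import torch

# Two hypothetical samples with three class scores each (made-up numbers).
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.3, 0.2]])
labels = torch.tensor([1, 2])

# dim=1 reduces over the columns, giving one (max value, index) pair per row.
values, predicted = torch.max(logits, dim=1)

# Count how many predicted indices agree with the labels.
correct = (predicted == labels).sum().item()
accuracy = 100 * correct / labels.size(0)

print(predicted.tolist(), correct, accuracy)
```

Here the first sample is classified correctly and the second is not, so one of two predictions matches.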
Results:
[1, 300]loss:2.235
[1, 600]loss:0.707
[1, 900]loss:0.239
Accuracy on test set:94 %
[2, 300]loss:0.164
[2, 600]loss:0.127
[2, 900]loss:0.120
Accuracy on test set:96 %
[3, 300]loss:0.103
[3, 600]loss:0.103
[3, 900]loss:0.095
Accuracy on test set:97 %
[4, 300]loss:0.088
[4, 600]loss:0.080
[4, 900]loss:0.079
Accuracy on test set:97 %
[5, 300]loss:0.072
[5, 600]loss:0.073
[5, 900]loss:0.070
Accuracy on test set:97 %
[6, 300]loss:0.074
[6, 600]loss:0.062
[6, 900]loss:0.060
Accuracy on test set:98 %
[7, 300]loss:0.060
[7, 600]loss:0.057
[7, 900]loss:0.060
Accuracy on test set:98 %
[8, 300]loss:0.050
[8, 600]loss:0.061
[8, 900]loss:0.048
Accuracy on test set:98 %
[9, 300]loss:0.051
[9, 600]loss:0.049
[9, 900]loss:0.048
Accuracy on test set:98 %
[10, 300]loss:0.046
[10, 600]loss:0.045
[10, 900]loss:0.050
Accuracy on test set:98 %
Process finished with exit code 0
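The flattened size of 120 that fc1 expects can also be confirmed by pushing a dummy batch through the convolutional part of the network. The sketch below rebuilds the same conv, pooling, and ReLU stack with torch.nn.Sequential so that it runs on its own; the variable names and the batch of four blank images are assumptions for illustration, and the weights are random because only the shapes matter.

```python
import torch

# Same layer order as Net.forward: conv -> max pooling -> ReLU, three times.
features = torch.nn.Sequential(
    torch.nn.Conv2d(1, 10, kernel_size=5),
    torch.nn.MaxPool2d(2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(10, 20, kernel_size=3),
    torch.nn.MaxPool2d(2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(20, 30, kernel_size=2),
    torch.nn.MaxPool2d(2),
    torch.nn.ReLU(),
)

dummy = torch.zeros(4, 1, 28, 28)  # hypothetical batch of 4 MNIST-sized images
with torch.no_grad():
    out = features(dummy)
flat = out.view(out.size(0), -1)  # same flattening as x.view(batch_size, -1)

print(tuple(out.shape))
print(tuple(flat.shape))
```

If a kernel size or the pooling window is changed, rerunning this check shows the new flattened size that the in_features argument of the first Linear layer must match.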