pytorch-10. Convolutional Neural Networks (Homework)
2022-08-10 05:32:00 【生信研究猿】

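Shape derivation by hand (convolutions use no padding and stride 1; max pooling is 2 × 2): 1 × 28 × 28 -> 10 × 24 × 24 (conv, kernel 5) -> 10 × 12 × 12 (maxpooling) -> 20 × 10 × 10 (conv, kernel 3) -> 20 × 5 × 5 (maxpooling) -> 30 × 4 × 4 (conv, kernel 2) -> 30 × 2 × 2 (maxpooling) -> flatten to 120 -> 120,64 (Linear) -> 64,32 (Linear) -> 32,10 (Linear)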
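The hand derivation above can be checked with the usual output-size formula for a convolution, out = (in + 2 × padding - kernel) // stride + 1, together with integer division by 2 for the 2 × 2 max pooling. The following is a minimal sketch in plain Python; the helper names `conv_out`, `pool_out`, and `trace_shapes` are made up for illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (dilation assumed to be 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial output size of max pooling whose stride equals its kernel size."""
    return size // kernel


def trace_shapes(size=28, layers=((10, 5), (20, 3), (30, 2))):
    """Return (channels, height, width) after every conv and every pooling step.

    `layers` lists (output_channels, kernel_size) for conv1, conv2, conv3.
    """
    shapes = []
    for channels, kernel in layers:
        size = conv_out(size, kernel)
        shapes.append((channels, size, size))
        size = pool_out(size)
        shapes.append((channels, size, size))
    return shapes


shapes = trace_shapes()
for shape in shapes:
    print(shape)

# The last shape is what gets flattened before the first Linear layer.
channels, height, width = shapes[-1]
flat_features = channels * height * width
print("features after flattening:", flat_features)
```

The flattened size is the value that has to match the input dimension of the first fully connected layer, so a helper like this is handy whenever a kernel size or the number of conv layers changes.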
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64
# Convert images to tensors, then normalize with the MNIST mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
# Evaluate on test_dataset (held-out data), not on train_dataset
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Conv2d arguments: input channels, output channels
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=3)
        self.conv3 = torch.nn.Conv2d(20, 30, kernel_size=2)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc1 = torch.nn.Linear(120, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)
        # Shapes:
        # 1 × 28 × 28 -> 10 × 24 × 24 (conv) -> 10 × 12 × 12 (maxpooling)
        # -> 20 × 10 × 10 (conv) -> 20 × 5 × 5 (maxpooling)
        # -> 30 × 4 × 4 (conv) -> 30 × 2 × 2 (maxpooling)
        # -> 120,64 (Linear) -> 64,32 (Linear) -> 32,10 (Linear)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = F.relu(self.pooling(self.conv3(x)))
        x = x.view(batch_size, -1)  # flatten to (batch_size, 120)
        # There is no nonlinearity between fc1, fc2 and fc3,
        # so together they act as a single affine map.
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
        return x

model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)  # move the batch to the selected device
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the mean loss over the last 300 batches
        if batch_idx % 300 == 299:
            print('[%d,%5d]loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)  # move the batch to the selected device
            outputs = model(images)
            # dim=1 takes the max over the class scores of each sample:
            # dim 0 indexes rows (samples), dim 1 indexes columns (classes)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

Results from the author's run:
[1, 300]loss:2.235
[1, 600]loss:0.707
[1, 900]loss:0.239
Accuracy on test set:94 %
[2, 300]loss:0.164
[2, 600]loss:0.127
[2, 900]loss:0.120
Accuracy on test set:96 %
[3, 300]loss:0.103
[3, 600]loss:0.103
[3, 900]loss:0.095
Accuracy on test set:97 %
[4, 300]loss:0.088
[4, 600]loss:0.080
[4, 900]loss:0.079
Accuracy on test set:97 %
[5, 300]loss:0.072
[5, 600]loss:0.073
[5, 900]loss:0.070
Accuracy on test set:97 %
[6, 300]loss:0.074
[6, 600]loss:0.062
[6, 900]loss:0.060
Accuracy on test set:98 %
[7, 300]loss:0.060
[7, 600]loss:0.057
[7, 900]loss:0.060
Accuracy on test set:98 %
[8, 300]loss:0.050
[8, 600]loss:0.061
[8, 900]loss:0.048
Accuracy on test set:98 %
[9, 300]loss:0.051
[9, 600]loss:0.049
[9, 900]loss:0.048
Accuracy on test set:98 %
[10, 300]loss:0.046
[10, 600]loss:0.045
[10, 900]loss:0.050
Accuracy on test set:98 %
Process finished with exit code 0
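The accuracy line in the log is printed with `%d`, which drops the fractional part, so any value from 98.0 up to just under 99 shows as 98 %. That is why the last five epochs look identical even if the model is still improving slightly. A small sketch of the difference, using made-up counts (`correct` and `total` below are illustrative numbers, not taken from the run above):

```python
correct, total = 9876, 10000  # hypothetical counts, for illustration only

accuracy = 100 * correct / total

truncated = 'Accuracy on test set:%d %%' % accuracy    # format used in the script
detailed = 'Accuracy on test set:%.2f %%' % accuracy   # keeps two decimal places

print(truncated)
print(detailed)
```

Switching the format string in `test()` to `%.2f` would be enough to see changes between epochs that the integer format hides.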