pytorch-10. Convolutional Neural Networks (homework)
2022-08-10 05:56:00 【Shengxin Research Ape】

Hand-derived shape flow: 1 × 28 × 28 -> 10 × 24 × 24 (conv, kernel 5) -> 10 × 12 × 12 (max pooling) -> 20 × 10 × 10 (conv, kernel 3) -> 20 × 5 × 5 (max pooling) -> 30 × 4 × 4 (conv, kernel 2) -> 30 × 2 × 2 (max pooling) -> flatten to 120 -> 64 (Linear) -> 32 (Linear) -> 10 (Linear)
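The hand derivation above follows from two output-size rules. A stride-1 convolution with no padding gives H_out = H_in - k + 1. A 2×2 max-pool with its default stride of 2 gives H_out = floor(H_in / 2).

The pure-Python sketch below replays those rules for this network's kernel sizes (5, 3, 2) and confirms the flattened size of 120 that fc1 expects. The helper names `conv_out`, `pool_out` and `trace_shapes` are hypothetical illustrations, not PyTorch functions.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Conv2d output size (dilation=1): floor((H + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=None):
    # MaxPool2d output size; stride defaults to kernel_size, as in torch.nn.MaxPool2d
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


def trace_shapes(size=28, in_channels=1):
    # (out_channels, kernel_size) for conv1..conv3 of the Net class below
    convs = [(10, 5), (20, 3), (30, 2)]
    shapes = [(in_channels, size, size)]
    channels = in_channels
    for channels, k in convs:
        size = conv_out(size, k)
        shapes.append((channels, size, size))  # after the convolution
        size = pool_out(size)
        shapes.append((channels, size, size))  # after 2x2 max pooling
    return shapes, channels * size * size  # flattened feature count fed to fc1


shapes, flat = trace_shapes()
print(" -> ".join("%dx%dx%d" % s for s in shapes), "->", flat)
```

If the result disagrees with the in_features of the first Linear layer, `x.view(batch_size, -1)` will produce a shape that fc1 rejects at runtime. Running a check like this before training catches that early.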
```python
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std (one channel)
])

train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)  # evaluate on the held-out test split


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)  # (in_channels, out_channels)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=3)
        self.conv3 = torch.nn.Conv2d(20, 30, kernel_size=2)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc1 = torch.nn.Linear(120, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 10)
        # Shape flow:
        # 1×28×28 -> 10×24×24 (conv1) -> 10×12×12 (pool) -> 20×10×10 (conv2) -> 20×5×5 (pool)
        # -> 30×4×4 (conv3) -> 30×2×2 (pool) -> flatten 120 -> 64 (fc1) -> 32 (fc2) -> 10 (fc3)

    def forward(self, x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = F.relu(self.pooling(self.conv3(x)))
        x = x.view(batch_size, -1)  # flatten to (batch_size, 120)
        # No activation between fc1, fc2 and fc3, so these three layers compose into one affine map
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x  # raw logits; CrossEntropyLoss applies log-softmax internally


model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)  # move to GPU
        optimizer.zero_grad()

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:  # report the mean loss over the last 300 batches
            print('[%d,%5d]loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)  # move to GPU
            outputs = model(images)
            # dim=1: dimension 0 indexes samples (rows), dimension 1 indexes classes (columns)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
```

Results:
[1, 300]loss:2.235
[1, 600]loss:0.707
[1, 900]loss:0.239
Accuracy on test set:94 %
[2, 300]loss:0.164
[2, 600]loss:0.127
[2, 900]loss:0.120
Accuracy on test set:96 %
[3, 300]loss:0.103
[3, 600]loss:0.103
[3, 900]loss:0.095
Accuracy on test set:97 %
[4, 300]loss:0.088
[4, 600]loss:0.080
[4, 900]loss:0.079
Accuracy on test set:97 %
[5, 300]loss:0.072
[5, 600]loss:0.073
[5, 900]loss:0.070
Accuracy on test set:97 %
[6, 300]loss:0.074
[6, 600]loss:0.062
[6, 900]loss:0.060
Accuracy on test set:98 %
[7, 300]loss:0.060
[7, 600]loss:0.057
[7, 900]loss:0.060
Accuracy on test set:98 %
[8, 300]loss:0.050
[8, 600]loss:0.061
[8, 900]loss:0.048
Accuracy on test set:98 %
[9, 300]loss:0.051
[9, 600]loss:0.049
[9, 900]loss:0.048
Accuracy on test set:98 %
[10, 300]loss:0.046
[10, 600]loss:0.045
[10, 900]loss:0.050
Accuracy on test set:98 %
Process finished with exit code 0
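Each epoch in the log shows exactly three loss lines (at batches 300, 600 and 900). MNIST's training split has 60,000 images. With `batch_size=64` and the DataLoader default of keeping the final partial batch, that gives 938 batches per epoch. `train()` prints whenever `batch_idx % 300 == 299`.

The sketch below reproduces that arithmetic. The helper names are hypothetical, and the split sizes (60,000 train and 10,000 test) are the standard MNIST sizes.

```python
import math


def batches_per_epoch(num_samples, batch_size):
    # DataLoader keeps the final partial batch by default (drop_last=False)
    return math.ceil(num_samples / batch_size)


def log_points(num_batches, every=300):
    # train() prints when batch_idx % 300 == 299, i.e. after batches 300, 600, ...
    return [i + 1 for i in range(num_batches) if i % every == every - 1]


train_batches = batches_per_epoch(60000, 64)  # MNIST training split
test_batches = batches_per_epoch(10000, 64)   # MNIST test split
print(train_batches, log_points(train_batches))  # 938 [300, 600, 900]
print(test_batches)  # 157
```

The remaining 38 batches after batch 900 still update the weights, but their loss is never printed. This is because `running_loss` is reset at the start of each `train()` call.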