pytorch-10. Convolutional Neural Networks
2022-08-10 05:32:00 【生信研究猿】
Example:
import torch
in_channels, out_channels = 5,10
width, height = 100,100
kernel_size = 3
batch_size = 1
input = torch.randn(batch_size,in_channels,width,height)
conv_layer = torch.nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size)
output = conv_layer(input)
print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)
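For reference, the printed shapes should be torch.Size([1, 5, 100, 100]) for the input, torch.Size([1, 10, 98, 98]) for the output, and torch.Size([10, 5, 3, 3]) for the weight (out_channels, in_channels, kH, kW): with no padding and stride 1, each spatial dimension shrinks by kernel_size - 1. A minimal sketch of the standard Conv2d output-size formula, with the helper name conv2d_out_size chosen here for illustration:
def conv2d_out_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    # standard Conv2d arithmetic: floor((in + 2*padding - dilation*(kernel-1) - 1) / stride) + 1
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
print(conv2d_out_size(100, 3))            # 98, matches the example above
print(conv2d_out_size(5, 3, padding=1))   # 5, the padding example below
print(conv2d_out_size(5, 3, stride=2))    # 2, the stride example below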
Padding example:
For example, padding=1 means that one extra ring of zero pixels is padded around the input image.
import torch
input = [
    3,4,5,6,7,
    2,4,6,8,2,
    1,6,7,8,4,
    9,7,4,6,2,
    3,7,5,4,1
]
input = torch.Tensor(input).view(1,1,5,5)  # (batch, channel, H, W) = (1,1,5,5)
conv_layer = torch.nn.Conv2d(1,1,kernel_size=3,padding=1,bias=False)  # padding=1 pads one ring of zeros around the image
kernel = torch.Tensor([1,2,3,4,5,6,7,8,9]).view(1,1,3,3)  # (out_channels, in_channels, kH, kW) = (1,1,3,3)
conv_layer.weight.data = kernel.data
output = conv_layer(input)
print(output)
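Because padding=1 with stride 1 keeps the 5x5 spatial size, the printed tensor has shape (1, 1, 5, 5). The same cross-correlation can also be computed without building a layer, through the functional API; a minimal sketch reusing the input and kernel tensors defined above:
import torch.nn.functional as F
# F.conv2d applies the given weight directly; this matches the Conv2d layer above
print(F.conv2d(input, kernel, padding=1))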
Stride example:
# stride
import torch
input = [
    3,4,5,6,7,
    2,4,6,8,2,
    1,6,7,8,4,
    9,7,4,6,2,
    3,7,5,4,1
]
input = torch.Tensor(input).view(1,1,5,5)  # (batch, channel, H, W) = (1,1,5,5)
conv_layer = torch.nn.Conv2d(1,1,kernel_size=3,stride=2,bias=False)  # stride=2: the kernel moves two pixels per step
kernel = torch.Tensor([1,2,3,4,5,6,7,8,9]).view(1,1,3,3)  # (out_channels, in_channels, kH, kW) = (1,1,3,3)
conv_layer.weight.data = kernel.data
output = conv_layer(input)
print(output)
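With a 5x5 input, kernel_size=3, stride=2 and no padding, the output is 2x2. Stride and padding can also be combined; a minimal sketch reusing the input and kernel above:
conv_layer = torch.nn.Conv2d(1,1,kernel_size=3,stride=2,padding=1,bias=False)
conv_layer.weight.data = kernel.data
print(conv_layer(input).shape)  # torch.Size([1, 1, 3, 3]): (5 + 2*1 - 3)//2 + 1 = 3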
MaxPooling example:
# MaxPooling
import torch
input = [
    3,4,5,6,
    2,4,6,8,
    1,6,7,8,
    9,7,4,6
]
input = torch.Tensor(input).view(1,1,4,4)
maxpooling_layer = torch.nn.MaxPool2d(kernel_size=2)  # with kernel_size=2 the stride defaults to 2 as well
output = maxpooling_layer(input)
print(output)
input2 = [
    3,4,5,6,7,8,
    2,4,6,8,5,6,
    1,6,7,8,3,9,
    9,7,4,6,5,5,
    6,2,8,4,7,6,
    8,7,8,3,9,8
]
input2 = torch.Tensor(input2).view(1,1,6,6)
maxpooling_layer = torch.nn.MaxPool2d(kernel_size=3)
output2 = maxpooling_layer(input2)
print(output2)
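With the default stride (equal to kernel_size), the 4x4 input becomes 2x2 and the 6x6 input also becomes 2x2. The stride can be set explicitly, in which case the pooling windows may overlap; a minimal sketch reusing input2 from above:
maxpooling_layer = torch.nn.MaxPool2d(kernel_size=3, stride=1)
print(maxpooling_layer(input2).shape)  # torch.Size([1, 1, 4, 4]): (6 - 3)//1 + 1 = 4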
Code that runs on the GPU:
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle = True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/',train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10,20,kernel_size=5)  # (in_channels, out_channels)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)  # 320 = 20 channels * 4 * 4 after two conv+pool stages on a 28x28 input
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.conv1(x)))
        x = F.relu(self.pooling(self.conv2(x)))
        x = x.view(batch_size,-1)  # flatten to (batch_size, 320)
        x = self.fc(x)
        return x  # no activation on the last layer; CrossEntropyLoss applies the log-softmax internally
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)  # move the batch to the GPU
        optimizer.zero_grad()
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d,%5d]loss:%.3f'%(epoch+1,batch_idx+1,running_loss/300))
            running_loss = 0
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)  # move the batch to the GPU
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim=1: rows are dimension 0, columns are dimension 1
            total += labels.size(0)
            correct += (predicted==labels).sum().item()
    print('Accuracy on test set:%d %%'%(100*correct/total))
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
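Where does the 320 in self.fc come from? Tracing the shapes of a 1x28x28 MNIST image: conv1 gives 10x24x24, pooling gives 10x12x12, conv2 gives 20x8x8, and the second pooling gives 20x4x4, so 20*4*4 = 320. A minimal sketch to check this with a dummy batch, using the Net class defined above:
net = Net()
dummy = torch.zeros(1, 1, 28, 28)
x = net.pooling(net.conv1(dummy))
print(x.shape)  # torch.Size([1, 10, 12, 12])
x = net.pooling(net.conv2(x))
print(x.shape)  # torch.Size([1, 20, 4, 4]) -> flattened to 320 features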
Result:
[1, 300]loss:0.568
[1, 600]loss:0.186
[1, 900]loss:0.138
Accuracy on test set:96 %
[2, 300]loss:0.108
[2, 600]loss:0.101
[2, 900]loss:0.091
Accuracy on test set:97 %
[3, 300]loss:0.076
[3, 600]loss:0.077
[3, 900]loss:0.073
Accuracy on test set:97 %
[4, 300]loss:0.065
[4, 600]loss:0.061
[4, 900]loss:0.062
Accuracy on test set:98 %
[5, 300]loss:0.056
[5, 600]loss:0.055
[5, 900]loss:0.052
Accuracy on test set:98 %
[6, 300]loss:0.049
[6, 600]loss:0.053
[6, 900]loss:0.047
Accuracy on test set:98 %
[7, 300]loss:0.046
[7, 600]loss:0.044
[7, 900]loss:0.044
Accuracy on test set:98 %
[8, 300]loss:0.043
[8, 600]loss:0.041
[8, 900]loss:0.042
Accuracy on test set:98 %
[9, 300]loss:0.038
[9, 600]loss:0.037
[9, 900]loss:0.037
Accuracy on test set:99 %
[10, 300]loss:0.033
[10, 600]loss:0.039
[10, 900]loss:0.033
Accuracy on test set:99 %
Process finished with exit code 0