当前位置:网站首页>动手学深度学习_VggNet
动手学深度学习_VggNet
2022-08-04 18:37:00 【CV小Rookie】
VggNet 是 AlexNet 的改进,将 AlexNet 网络中的 5x5 或 11x11 的卷积核全都替换成 3x3 的大小;maxpoolin 改成了 2x2 的大小;另外对于部分卷积和Pooling进行块的封装,使用循环来构建复杂的框架。

上图是Vgg16的结构图,可以看到一共有 5 块,前两块是2层卷积+一层maxpooling,后三层是3层卷积+一层maxpooling。
具体设计网络的时候可以参考下面的表格,十分清晰。

def Conv3x3BNReLU(in_channels,out_channels):
return nn.Sequential(
nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU6(inplace=True)
)
class VGG(nn.Module):
def __init__(self, block_nums,num_classes=10):
super(VGG, self).__init__()
self.stage1 = self._make_layers(in_channels=1, out_channels=64, block_num=block_nums[0])
self.stage2 = self._make_layers(in_channels=64, out_channels=128, block_num=block_nums[1])
self.stage3 = self._make_layers(in_channels=128, out_channels=256, block_num=block_nums[2])
self.stage4 = self._make_layers(in_channels=256, out_channels=512, block_num=block_nums[3])
self.stage5 = self._make_layers(in_channels=512, out_channels=512, block_num=block_nums[4])
self.classifier = nn.Sequential(
nn.Linear(in_features=512*7*7,out_features=4096),
nn.Dropout(p=0.2),
nn.Linear(in_features=4096, out_features=4096),
nn.Dropout(p=0.2),
nn.Linear(in_features=4096, out_features=num_classes)
)
self._init_params()
def _make_layers(self, in_channels, out_channels, block_num):
layers = []
layers.append(Conv3x3BNReLU(in_channels,out_channels))
for i in range(1,block_num):
layers.append(Conv3x3BNReLU(out_channels,out_channels))
layers.append(nn.MaxPool2d(kernel_size=2,stride=2, ceil_mode=False))
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = x.view(x.size(0),-1)
out = self.classifier(x)
return out
def VGG16():
block_nums = [2, 2, 3, 3, 3]
model = VGG(block_nums)
return model
def VGG19():
block_nums = [2, 2, 4, 4, 4]
model = VGG(block_nums)
return model打印一下看看网络具体结构吧
Sequential output shape: torch.Size([1, 64, 112, 112]) Sequential output shape: torch.Size([1, 128, 56, 56]) Sequential output shape: torch.Size([1, 256, 28, 28]) Sequential output shape: torch.Size([1, 512, 14, 14]) Sequential output shape: torch.Size([1, 512, 7, 7]) Flatten output shape: torch.Size([1, 25088]) Linear output shape: torch.Size([1, 4096]) ReLU output shape: torch.Size([1, 4096]) Dropout output shape: torch.Size([1, 4096]) Linear output shape: torch.Size([1, 4096]) ReLU output shape: torch.Size([1, 4096]) Dropout output shape: torch.Size([1, 4096]) Linear output shape: torch.Size([1, 10])
边栏推荐
猜你喜欢
随机推荐
win10 uwp 修改Pivot Header 颜色
面试官:MVCC是如何实现的?
Short-term reliability and economic evaluation of resilient microgrids under incentive-based demand response programs (Matlab code implementation)
网络运维管理从基础到实战-自用笔记(1)构建综合园区网、接入互联网
单行、多行文本超出显示省略号
如何进行自动化测试?
unity中实现ue眼球的渲染
Develop those things: How to obtain the traffic statistics of the monitoring site through the EasyCVR platform?
数据库SqlServer迁移PostgreSql实践
Day018 继承
【AI+医疗】斯坦福大学最新博士论文《深度学习在医学影像理解中的应用》,205页pdf
Understanding of margin collapse and coincidence
在线生成接口文档
数据集成:holo数据同步至redis。redis必须是集群模式?
mood swings
EuROC 数据集格式及相关代码
Go language Go language, understand Go language file operation in one article
powershell和cmd对比
How can test engineers break through career bottlenecks?
curl命令的那些事










