当前位置:网站首页>空间金字塔池化 -Spatial Pyramid Pooling(含源码)

空间金字塔池化 -Spatial Pyramid Pooling(含源码)

2022-08-11 05:35:00 KPer_Yang

目录

参考:

1、Spatial Pyramid Pooling解决的问题

2、Spatial Pyramid Pooling实现原理

3、Spatial Pyramid Pooling的代码实现


参考:

《Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition》

论文链接:[1406.4729] Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition (arxiv.org)

1、Spatial Pyramid Pooling解决的问题

空间金字塔池化主要用来解决输入图片的分辨率不一致的问题。之前解决图片分辨率不一致使用的是图片缩放或者裁剪,这样容易导致图片信息丢失。两种解决图片分辨率不一致问题的方法的区别如图1.1所示:

图1.1 裁剪、缩放和Spatial Pyramid Pooling的区别

2、Spatial Pyramid Pooling实现原理

如图2.1所示,SPP-Net的实现是由多种不同大小的池化层对特征图进行池化,然后进行向量展平和拼接。文中使用的是16*16、4*4、1*1的池化层,在具体应用到自己的任务中时,可以根据特征图的大小等因素进行更改。同时,当特征图不是长宽相等,需要进行padding操作,并且16*16、4*4都是按照划分网格的方式进行池化,跟普通的池化层的操作有区别。

图2.1  Spatial Pyramid Pooling实现原理图示

3、Spatial Pyramid Pooling的代码实现

yueruchen/sppnet-pytorch: A simple Spatial Pyramid Pooling layer which could be added in CNN (github.com)

import math
def spatial_pyramid_pool(self,previous_conv, num_sample, previous_conv_size, out_pool_size):
    '''
    previous_conv: a tensor vector of previous convolution layer
    num_sample: an int number of image in the batch
    previous_conv_size: an int vector [height, width] of the matrix features size of previous convolution layer
    out_pool_size: a int vector of expected output size of max pooling layer
    
    returns: a tensor vector with shape [1 x n] is the concentration of multi-level pooling
    '''    
    # print(previous_conv.size())
    for i in range(len(out_pool_size)):
        # print(previous_conv_size)
        h_wid = int(math.ceil(previous_conv_size[0] / out_pool_size[i]))
        w_wid = int(math.ceil(previous_conv_size[1] / out_pool_size[i]))
        h_pad = (h_wid*out_pool_size[i] - previous_conv_size[0] + 1)/2
        w_pad = (w_wid*out_pool_size[i] - previous_conv_size[1] + 1)/2
        maxpool = nn.MaxPool2d((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))
        x = maxpool(previous_conv)
        if(i == 0):
            spp = x.view(num_sample,-1)
            # print("spp size:",spp.size())
        else:
            # print("size:",spp.size())
            spp = torch.cat((spp,x.view(num_sample,-1)), 1)
    return 

原网站

版权声明
本文为[KPer_Yang]所创,转载请带上原文链接,感谢
https://blog.csdn.net/KPer_Yang/article/details/125902734