当前位置:网站首页>DeamNet代码学习||网络框架核心代码 逐句查找学习

DeamNet代码学习||网络框架核心代码 逐句查找学习

2022-08-10 16:47:00 Claire_Shang

目录

DeamNet网络架构图

1.  定义各种类(class)

2.  定义编码-解码块

3. DEAM 模块和 NLO子网络

 4. 图像域转换与逆转换

5. 整体DeamNet

DeamNet网络架构图

1.  定义各种类(class)

nn.Module类的基本定义

在定义网络时,需要继承nn.Module类,并重新实现构造函数__init__()forward这两个方法。在构造函数__init__()中使用super(Model, self).init()来调用父类的构造函数,forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
1.一般把网络中具有可学习参数的层(如全连接层、卷积层)放在构造函数__init__()中。
2.一般把不具有可学习参数的层(如ReLU、dropout)可放在构造函数中,也可不放在构造函数中(在forward中使用
nn.functional来调用)。

import torch
import torch.nn as nn


class ConvLayer1(nn.Module):  #ConvLayer1为子类,nn.Module为父类

    def __init__(self, in_channels, out_channels, kernel_size, stride):
          #  in_channel: 输入数据的通道数,out_channel: 输出数据的通道数,stride 步长
        super(ConvLayer1, self).__init__()   # 调用父类的构造函数
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride= stride)
          # nn.Conv2d 二维卷积可处理二维数据
        nn.init.xavier_normal_(self.conv2d.weight.data)
        # nn.init. 参数初始化方法  xavier初始化方式 normal_ 正态分布  
        # .weight.data:得到的是一个Tensor的张量(向量),不可训练的类型
    def forward(self, x): 
        # out = self.reflection_pad(x)
        # out = self.conv2d(out)
        return self.conv2d(x)


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        padding = (kernel_size - 1) // 2
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride),
            nn.ReLU()
        )
     # self.属性名 = 属性的初始值.  添加属性并赋值  
  # nn.Sequential是一个有序的容器,神经网络模块按照传入构造器的顺序被添加到计算图中执行
        nn.init.xavier_normal_(self.block[0].weight.data)

    def forward(self, x):
        return self.block(x)


class line(nn.Module):
    def __init__(self):
        super(line, self).__init__()
        #randn(*size, out=None, dtype=None) 返回一个张量,包含了从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的一组随机数,张量的形状由sizes定义
        self.delta = nn.Parameter(torch.randn(1, 1))
    
    # torch.mul(input, other, *, out=None) 输入:两个张量矩阵;输出:他们的点乘运算结果
    def forward(self, x, y):
        return torch.mul((1 - self.delta), x) + torch.mul(self.delta, y)

torch.nn.Parameter()

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

2.  定义编码-解码块

class Encoding_block(nn.Module):
    def __init__(self, base_filter, n_convblock):
        super(Encoding_block, self).__init__()
        self.n_convblock = n_convblock
        modules_body = []   # 空列表 代表list列表数据类型
        for i in range(self.n_convblock - 1):
            modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))
         #  append()函数用于在列表末尾添加新的对象
        modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=2))
        self.body = nn.Sequential(*modules_body)
      # nn.Sequential的定义来看,输入遇到list,必须用*号进行转化,否则会报错

    def forward(self, x):
       for i in range(self.n_convblock - 1):
            x = self.body[i](x)
        ecode = x
        x = self.body[self.n_convblock - 1](x)
        return ecode, x


class UpsampleConvLayer(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        self.conv2d = ConvLayer(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
             # torch.nn.functional.interpolate实现插值和上采样
            x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample)
        out = self.conv2d(x_in)
        return out


class upsample1(nn.Module):
    def __init__(self, base_filter):
        super(upsample1, self).__init__()
        self.conv1 = ConvLayer(base_filter, base_filter, 3, stride=1)
        self.ConvTranspose = UpsampleConvLayer(base_filter, base_filter, kernel_size=3, stride=1, upsample=2) # 转置卷积
        self.cat = ConvLayer1(base_filter * 2, base_filter, kernel_size=1, stride=1)

    def forward(self, x, y):
        y = self.ConvTranspose(y)
        x = self.conv1(x)
          # torch.cat 在给定维度上对输入的张量序列seq进行拼接。
        return self.cat(torch.cat((x, y), dim=1))


class Decoding_block2(nn.Module):
    def __init__(self, base_filter, n_convblock):
        super(Decoding_block2, self).__init__()
        self.n_convblock = n_convblock
        self.upsample = upsample1(base_filter)
        modules_body = []
        for i in range(self.n_convblock - 1):
            modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))
        modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x, y):
        x = self.upsample(x, y)
        for i in range(self.n_convblock):
            x = self.body[i](x)
        return x

3. DEAM 模块和 NLO子网络

# Corresponds to DEAM Module in NLO Sub-network
class Attention_unet(nn.Module):
    # 注意力机制 参数reduction为缩减率
    def __init__(self, channel, reduction=16):
        super(Attention_unet, self).__init__()
   #  // 的用法还没搜到,猜测是换行
        self.conv_du = nn.Sequential(
            ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1),
            nn.Sigmoid()
        )
        self.cat = ConvLayer1(in_channels=channel * 2, out_channels=channel, kernel_size=1, stride=1)
        self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1)
        self.ConvTranspose = UpsampleConvLayer(channel, channel, kernel_size=3, stride=1, upsample=2)  # up-sampling

    def forward(self, x, g):
        up_g = self.ConvTranspose(g)  # 对应文中SA上采样模块
        weight = self.conv_du(self.cat(torch.cat([self.C(x), up_g], 1)))
        rich_x = torch.mul((1 - weight), up_g) + torch.mul(weight, x)
        return rich_x  # 返回Deam模块的输出

self.conv_du 对应文中的weights mapping模块,即图4中的3维卷积+Relu+3维卷积+Sigmoid

前面加上一个1维卷积连接起来得到(WM模块),weight 即生成的加权张量α

# Corresponds to NLO Sub-network
class ziwangluo1(nn.Module):
    def __init__(self, base_filter, n_convblock_in, n_convblock_out):
        super(ziwangluo1, self).__init__()
       # ConvLayer(in_channels, out_channels, kernel_size, stride)
        self.conv_dila1 = ConvLayer1(64, 64, 3, 1)
        self.conv_dila2 = ConvLayer1(64, 64, 5, 1)
        self.conv_dila3 = ConvLayer1(64, 64, 7, 1)
       
        #  nn.Conv2d的参数dilation:膨胀卷积. Pytoch中dilation默认为1,但是实际为不膨胀
        self.cat1 = torch.nn.Conv2d(in_channels=64 * 3, out_channels=64, kernel_size=1, stride=1, padding=0,
                                    dilation=1, bias=True)
        nn.init.xavier_normal_(self.cat1.weight.data)
        self.e3 = Encoding_block(base_filter, n_convblock_in)
        self.e2 = Encoding_block(base_filter, n_convblock_in)
        self.e1 = Encoding_block(base_filter, n_convblock_in)
        self.e0 = Encoding_block(base_filter, n_convblock_in)
   # 文中迭代次数K=4
        self.attention3 = Attention_unet(base_filter)
        self.attention2 = Attention_unet(base_filter)
        self.attention1 = Attention_unet(base_filter)
        self.attention0 = Attention_unet(base_filter)
        # 定义的ConvLayer 包含Relu(线性修正单元)
        self.mid = nn.Sequential(ConvLayer(base_filter, base_filter, 3, 1),
                                 ConvLayer(base_filter, base_filter, 3, 1))
        self.de3 = Decoding_block2(base_filter, n_convblock_out)
        self.de2 = Decoding_block2(base_filter, n_convblock_out)
        self.de1 = Decoding_block2(base_filter, n_convblock_out)
        self.de0 = Decoding_block2(base_filter, n_convblock_out)

       # 定义的ConvLayer1 不包含Relu
        self.final = ConvLayer1(base_filter, base_filter, 3, stride=1)

    def forward(self, x):
        _input = x
        encode0, down0 = self.e0(x)
        encode1, down1 = self.e1(down0)
        encode2, down2 = self.e2(down1)
        encode3, down3 = self.e3(down2)

        # media_end = self.Encoding_block_end(down3)
        media_end = self.mid(down3)

        g_conv3 = self.attention3(encode3, media_end)
        up3 = self.de3(g_conv3, media_end)
        g_conv2 = self.attention2(encode2, up3)
        up2 = self.de2(g_conv2, up3)

        g_conv1 = self.attention1(encode1, up2)
        up1 = self.de1(g_conv1, up2)

        g_conv0 = self.attention0(encode0, up1)
        up0 = self.de0(g_conv0, up1)

        final = self.final(up0)

        return _input + final


class line(nn.Module):

    def __init__(self):
        super(line, self).__init__()
        self.delta = nn.Parameter(torch.randn(1, 1))

    def forward(self, x, y):
        return torch.mul((1 - self.delta), x) + torch.mul(self.delta, y)
# 对应 DEAM 模块
class SCA(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SCA, self).__init__()
        self.conv_du = nn.Sequential(
            ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.conv_du(x)
        return y


class Weight(nn.Module):
    def __init__(self, channel):
        super(Weight, self).__init__()
        self.cat = ConvLayer1(in_channels=channel * 2, out_channels=channel, kernel_size=1, stride=1)
        self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1)
        self.weight = SCA(channel)

    def forward(self, x, y):
        delta = self.weight(self.cat(torch.cat([self.C(y), x], 1)))
        return delta

 4. 图像域转换与逆转换

根据文中介绍转换特征域(FD)与像素域的模块得到下面代码的网络结构

class transform_function(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(transform_function, self).__init__()
        self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1)
        self.pre = torch.nn.Sequential(
            ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1),

        )

    def forward(self, x):
        y = self.ext(x)
        return y + self.pre(y)

# 图像域变换与逆变换中定义的self.pre通道数不同

class Inverse_transform_function(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Inverse_transform_function, self).__init__()
        self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1)
        self.pre = torch.nn.Sequential(
            ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1),
        )

    def forward(self, x):
        x = self.pre(x) + x
        x = self.ext(x)
        return x

5. 整体DeamNet

class Deam(nn.Module):
    def __init__(self, Isreal):
        super(Deam, self).__init__()
        if Isreal:
        # 由以下代码可知像素域通道数为3(或1,即彩色图像与灰度图),特征域通道数为64
            self.transform_function = transform_function(3, 64)
            self.inverse_transform_function = Inverse_transform_function(64, 3)
        else:
            self.transform_function = transform_function(1, 64)
            self.inverse_transform_function = Inverse_transform_function(64, 1)
       
        # 从网络结构看,转换后的X输入到NLO子网络和DEAM模块中
        self.line11 = Weight(64)
        self.line22 = Weight(64)
        self.line33 = Weight(64)
        self.line44 = Weight(64)

        self.net2 = ziwangluo1(64, 3, 2)   # DeamNet中NLO子网络的参数共享

    def forward(self, x):
        x = self.transform_function(x)
        y = x

        # Corresponds to NLO Sub-network
        x1 = self.net2(y)
        # Corresponds to DEAM Module
        delta_1 = self.line11(x1, y)
       # 这里y对应低分辨率的分支,x1对应高分辨率分支
        x1 = torch.mul((1 - delta_1), x1) + torch.mul(delta_1, y)

        x2 = self.net2(x1)
        delta_2 = self.line22(x2, y)
        x2 = torch.mul((1 - delta_2), x2) + torch.mul(delta_2, y)

        x3 = self.net2(x2)
        delta_3 = self.line33(x3, y)
        x3 = torch.mul((1 - delta_3), x3) + torch.mul(delta_3, y)

        x4 = self.net2(x3)
        delta_4 = self.line44(x4, y)
        x4 = torch.mul((1 - delta_4), x4) + torch.mul(delta_4, y)
        x4 = self.inverse_transform_function(x4)
        return x4


def print_network(net):
    num_params = 0
    for param in net.parameters():
        # += 先将运算符左边和右边的变量值相加,然后将相加的结果赋值给左边的变量
        # param.numel()  返回param中元素的数量
        num_params += param.numel()
    print(net) # 打印的是网络名,没有网络结构
    #  字符串输出: %d 有符号的十进制整数
    print('Total number of parameters: %d' % num_params)

net.parameters() 

逐列表项输出列表元素。构建好神经网络后,网络的参数都保存在parameters()函数当中

与net.named_parameters()的输出相对比,net.parameters()的输出里只包含参数的值,不包含参数的所属信息。

原网站

版权声明
本文为[Claire_Shang]所创,转载请带上原文链接,感谢
https://blog.csdn.net/Claire_wanqing/article/details/126167604