
NaN appears during model training

2022-08-11 08:52:00 小乐快乐

[Function Module] The complete code is in the attachment; the dataset can also be provided if needed.

from collections import OrderedDict

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class EmbeddingImagenet(nn.Cell):
    def __init__(self,emb_size,cifar_flag=False):
        super(EmbeddingImagenet, self).__init__()
        # set size
        self.hidden = 64
        self.last_hidden = self.hidden * 25 if not cifar_flag else self.hidden * 4
        self.emb_size = emb_size
        self.out_dim = emb_size

        # set layers
        self.conv_1 = nn.SequentialCell(nn.Conv2d(in_channels=3,
                                              out_channels=self.hidden,
                                              kernel_size=3,
                                              padding=1,
                                              pad_mode='pad',
                                              has_bias=False),
                                    nn.BatchNorm2d(num_features=self.hidden),
                                    nn.MaxPool2d(kernel_size=2,stride=2),
                                    nn.LeakyReLU(alpha=0.2))
        self.conv_2 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden,
                                              out_channels=int(self.hidden*1.5),
                                              kernel_size=3,
                                              padding=1,
                                              pad_mode='pad',
                                              has_bias=False),
                                    nn.BatchNorm2d(num_features=int(self.hidden*1.5)),
                                    nn.MaxPool2d(kernel_size=2,stride=2),
                                    nn.LeakyReLU(alpha=0.2))
        self.conv_3 = nn.SequentialCell(nn.Conv2d(in_channels=int(self.hidden*1.5),
                                              out_channels=self.hidden*2,
                                              kernel_size=3,
                                              padding=1,
                                              pad_mode='pad',
                                              has_bias=False),
                                    nn.BatchNorm2d(num_features=self.hidden * 2),
                                    nn.MaxPool2d(kernel_size=2,stride=2),
                                    nn.LeakyReLU(alpha=0.2),
                                    nn.Dropout(0.6))
        self.conv_4 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden*2,
                                              out_channels=self.hidden*4,
                                              kernel_size=3,
                                              padding=1,
                                              pad_mode='pad',
                                              has_bias=False),
                                    nn.BatchNorm2d(num_features=self.hidden * 4),    # 16 * 64 * (5 * 5)
                                    nn.MaxPool2d(kernel_size=2,stride=2),
                                    nn.LeakyReLU(alpha=0.2),
                                    nn.Dropout(0.5))
        # self.layer_last = nn.SequentialCell(nn.Dense(in_channels=self.last_hidden * 4,
        #                                       out_channels=self.emb_size, has_bias=True),
        #                                 nn.BatchNorm1d(self.emb_size))
        self.layer_last = nn.Dense(in_channels=self.last_hidden * 4,out_channels=self.emb_size, has_bias=True)
        #self.bn = nn.BatchNorm1d(self.emb_size)

    def construct(self, input_data):
        #print("img:",input_data[0])
        x = self.conv_1(input_data)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        #x = ops.Reshape()(x,(x.shape[0],-1))
        print("feat:", input_data[0])
        #x = self.layer_last(x)
        x = self.layer_last(x.view(x.shape[0],-1))
        print("last--------------------------------:",x[0])
        return x
class NodeUpdateNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 1],
                 dropout=0.0):
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout

        self.eye = ops.Eye()
        self.bmm = ops.BatchMatMul()
        self.cat = ops.Concat(-1)
        self.split = ops.Split(1,2)
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        self.squeeze = ops.Squeeze()
        self.transpose = ops.Transpose()


        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):

            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],)
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)

            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1-self.dropout)

        self.network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # get size
        num_tasks = node_feat.shape[0]
        num_data = node_feat.shape[1]

        # get eye matrix (batch_size x 2 x node_size x node_size)
        diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(num_data,num_data,ms.float32),0),0),(num_tasks,2,1,1))

        # set diagonal to zero and normalize (the original paper uses L1 normalization)
        # edge_feat = edge_feat * diag_mask
        # edge_feat = edge_feat / ops.clip_by_value(ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1),Tensor(0,ms.float32),Tensor(num_data,ms.float32))

        edge_feat = ops.L2Normalize(-1)(edge_feat * diag_mask)

        # compute attention and aggregate
        aggr_feat = self.bmm(self.squeeze(ops.Concat(2)(self.split(edge_feat))),node_feat)
        node_feat = self.cat([node_feat,self.cat(ops.Split(1, 2)(aggr_feat))]).swapaxes(1,2)
        #node_feat = self.transpose(self.cat([node_feat,self.cat(ops.Split(1, 2)(aggr_feat))]),(0,2,1))

        node_feat = self.network(self.unsqueeze(node_feat,(-1))).swapaxes(1,2).squeeze()
        #node_feat = self.squeeze(self.transpose(self.network(self.unsqueeze(node_feat,(-1))),(0,2,1,3)))

        return node_feat


class EdgeUpdateNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False,
                 dropout=0.0):
        super(EdgeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout

        self.eye = ops.Eye()
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()


        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            # set layer
            layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,
                                                       out_channels=self.num_features_list[l],
                                                       kernel_size=1,
                                                       has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
                                                            )
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)

            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1-self.dropout)

        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1,
                                           kernel_size=1)
        self.sim_network = nn.SequentialCell(layer_list)


    def construct(self, node_feat, edge_feat):
        # compute abs(x_i, x_j)

        x_i = ops.ExpandDims()(node_feat,2)
        x_j = x_i.swapaxes(1,2)
        #x_j = ops.Transpose()(x_i,(0,2,1,3))
        #x_ij = (x_i-x_j)**2
        x_ij = ops.Abs()(x_i-x_j)
        #print("x_ij:",x_ij[0,0,:,:])
        x_ij = ops.Transpose()(x_ij,(0,3,2,1))
        sim_val = self.sim_network(x_ij)

        sim_val = ops.Sigmoid()(sim_val)
        #print("sim_val", sim_val[0, 0, :, :])

        dsim_val = 1.0 - sim_val

        diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(node_feat.shape[1],node_feat.shape[1],ms.float32),0),0),(node_feat.shape[0],2,1,1))
        edge_feat = edge_feat * diag_mask
        merge_sum = ops.ReduceSum(keep_dims=True)(edge_feat,-1)
        # set diagonal as zero and normalize
        # edge_feat = ops.Concat(1)([sim_val,dsim_val])*edge_feat
        # edge_feat = edge_feat / ops.clip_by_value((ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1)),Tensor(0,ms.float32),Tensor(num_data,ms.float32))
        # edge_feat = edge_feat*merge_sum

        edge_feat = ops.L2Normalize(-1)(ops.Concat(1)([sim_val,dsim_val])*edge_feat)*merge_sum

        force_edge_feat = self.repeat(self.unsqueeze(ops.Concat(0)([self.unsqueeze(self.eye(node_feat.shape[1],node_feat.shape[1],ms.float32),0),self.unsqueeze(ops.Zeros()((node_feat.shape[1],node_feat.shape[1]),ms.float32),0)]),0),(node_feat.shape[0],1,1,1))

        edge_feat = edge_feat + force_edge_feat
        edge_feat = edge_feat + 1e-6
        #print("sum_edge",self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat,1),1),(1,2,1,1))[0,0])
        edge_feat = edge_feat / self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat,1),1),(1,2,1,1))

        return edge_feat


class GraphNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 node_features,
                 edge_features,
                 num_layers,
                 dropout=0.0
                 ):
        super(GraphNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers = nn.CellList()
        # for each layer
        for l in range(self.num_layers):
            # set edge to node
            edge2node_net = NodeUpdateNetwork(in_features=self.in_features if l == 0 else self.node_features,
                                              num_features=self.node_features,
                                              dropout=self.dropout if l < self.num_layers-1 else 0.0)

            # set node to edge
            node2edge_net = EdgeUpdateNetwork(in_features=self.node_features,
                                              num_features=self.edge_features,
                                              separate_dissimilarity=False,
                                              dropout=self.dropout if l < self.num_layers-1 else 0.0)
            self.layers.append(nn.CellList([edge2node_net,node2edge_net]))
    # forward
    def construct(self, node_feat, edge_feat):
        # for each layer
        edge_feat_list = []
        #print("node_feat---------------------------------------------------------- -1", node_feat[0, 0, :])
        for l in range(self.num_layers):
            # (1) edge to node
            node_feat = self.layers[l][0](node_feat, edge_feat)
            # (2) node to edge
            edge_feat = self.layers[l][1](node_feat, edge_feat)
            # save edge feature
            edge_feat_list.append(edge_feat)

        return edge_feat_list

[Steps & Observed Problem]

Our code extracts image features with 4 convolution layers plus one fully connected layer, then treats each image's feature vector as a node of a graph and runs a GNN on it. (The code is in the attachment.)
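For reference, a minimal sketch of how these cells would be wired together; the batch size, node count, 84x84 input size, and feature widths below are illustrative assumptions rather than values taken from the attachment:

import numpy as np
import mindspore as ms
import mindspore.ops as ops

# Hypothetical sizes: 4 tasks per batch, 10 nodes (images) per task.
num_tasks, num_nodes, emb_size = 4, 10, 128

backbone = EmbeddingImagenet(emb_size=emb_size)          # 4 conv blocks + 1 dense layer
gnn = GraphNetwork(in_features=emb_size, node_features=96,
                   edge_features=96, num_layers=3, dropout=0.1)

# Embeddings: (num_tasks * num_nodes, 3, 84, 84) images -> (num_tasks, num_nodes, emb_size)
images = ms.Tensor(np.random.rand(num_tasks * num_nodes, 3, 84, 84), ms.float32)
node_feat = backbone(images).view(num_tasks, num_nodes, emb_size)

# Initial edge features: (num_tasks, 2, num_nodes, num_nodes); uniform 0.5 here for illustration.
edge_feat = ms.Tensor(np.full((num_tasks, num_nodes, num_nodes), 0.5), ms.float32)
edge_feat = ops.Tile()(ops.ExpandDims()(edge_feat, 1), (1, 2, 1, 1))

edge_feat_list = gnn(node_feat, edge_feat)   # one edge map per GNN layer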

1. After training for many batches, the extracted features (after the 4 convolution layers and the fully connected layer) contain extremely large values, and a few batches later they become NaN; before the fully connected layer the feature values are still normal. (A per-batch check like the sketch below can help locate where the values first blow up.)

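One way to narrow the problem down is to verify, on every batch, that the embedding is still finite and reasonably bounded right before and after layer_last. A minimal sketch, assuming PyNative mode (asnumpy() cannot be called inside a graph-compiled construct); the 1e6 threshold is an arbitrary choice:

import numpy as np

def check_finite(tag, tensor, limit=1e6):
    """Warn when a tensor contains inf/NaN or unusually large values."""
    arr = tensor.asnumpy()
    if not np.isfinite(arr).all():
        print(f"[{tag}] contains inf/NaN")
    elif np.abs(arr).max() > limit:
        print(f"[{tag}] is exploding, max abs = {np.abs(arr).max():.3e}")

# e.g. inside EmbeddingImagenet.construct:
#     check_finite("conv_4 output", x)
#     x = self.layer_last(x.view(x.shape[0], -1))
#     check_finite("layer_last output", x)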

[Screenshot Information]

These are the features printed by the code:

last--------------------------------: [ 1.918492   -0.8280923   2.0575197   0.3089749  -1.0514854   0.5368729
  0.14135109  1.5270222  -1.4794292  -1.4336827   1.0335447  -0.7093582
 -0.41919574 -0.5667086  -0.3535831   1.5567536   0.5002996  -1.4093596
  0.9674009  -0.18156137  0.14888959  0.6358457   1.406878   -0.03820777
 -0.24577822 -0.25783274  0.5756687  -1.4558431  -1.1002262   0.68062806
 -1.6467474   0.88712454  0.3551372  -1.3449378  -1.7011788  -0.8629771
 -0.92482185  0.9867192  -1.5548937   1.340383   -2.299356   -0.3421743
  1.3239275  -1.3792732  -0.31955895 -0.58364254 -3.7381008  -1.2121737
 -0.75104207 -0.7562581   0.04980466  0.45131734 -1.2448095  -0.33418307
  0.86268485 -1.3601649   1.2753168   2.469506   -1.7358601  -2.9104383
 -0.07392117 -0.73263663  0.11657254 -0.05724781  0.34374043 -0.31884825
  0.13456154  2.3561432  -0.18908082  0.5410311   1.7249999   0.9508886
 -0.30631644  1.6836481   1.1513023  -0.33672807 -0.889638   -0.76715356
 -0.7316199   1.597606   -1.6586273   0.4502733   0.5224928  -3.5851111
 -2.906651   -1.5284328   0.83426046  1.354644   -1.4453334   2.0504599
 -1.3200179  -0.50427496  0.97681373  0.30048305  0.17170379  0.8179815
 -0.92994857  1.333491   -1.2931286  -0.3569969   2.7953048  -3.352736
  1.878619    2.018083   -1.1191074  -1.1341975   1.4532931  -0.66957355
  2.3269157  -0.4198427   0.7148121   0.5458231  -1.3050007  -0.34666243
  2.519589    0.804219    0.91191477  1.3088121   0.6767241   2.1667008
  0.24471135  1.2600335  -1.8683847   2.5641935  -0.9636249  -1.0340385
 -0.32570755 -1.7694132 ]
------------------------------------------
------------------------------------------------------------------------------- 1 0.7806913
---------------------------------------------
feat: [[[-1.6726604  -1.6897851  -1.7069099  ...  0.43368444  0.46793392
    0.41655967]
  [-1.7069099  -1.7069099  -1.7069099  ...  0.5364329   0.5193082
    0.4850587 ]
  [-1.7240347  -1.7240347  -1.7069099  ...  0.60493195  0.5535577
    0.4850587 ]
  ...
  [-0.6622999  -0.8335474  -0.8677969  ... -0.02868402  0.00556549
   -0.02868402]
  [-0.6622999  -0.69654936 -0.69654936 ... -0.11430778 -0.11430778
   -0.14855729]
  [-0.95342064 -0.8335474  -0.78217316 ... -0.26843056 -0.30268008
   -0.31980482]]

 [[-1.7556022  -1.7731092  -1.7906162  ... -0.617647   -0.582633
   -0.635154  ]
  [-1.7906162  -1.7906162  -1.7906162  ... -0.512605   -0.512605
   -0.565126  ]
  [-1.8081232  -1.8081232  -1.7906162  ... -0.460084   -0.495098
   -0.565126  ]
  ...
  [-0.28501397 -0.37254897 -0.40756297 ... -1.0028011  -0.9677871
   -1.0203081 ]
  [-0.26750696 -0.33753496 -0.32002798 ... -1.12535    -1.1428571
   -1.160364  ]
  [-0.53011197 -0.53011197 -0.44257697 ... -1.2829131  -1.317927
   -1.317927  ]]

 [[-1.68244    -1.6998693  -1.7172985  ... -1.490719   -1.4558606
   -1.490719  ]
  [-1.7172985  -1.7172985  -1.7172985  ... -1.4732897  -1.4384314
   -1.4732897 ]
  [-1.7347276  -1.7347276  -1.7172985  ... -1.4558606  -1.4732897
   -1.5255773 ]
  ...
  [-1.3338562  -1.4210021  -1.4210021  ... -1.6127234  -1.5430065
   -1.5604358 ]
  [-1.2815686  -1.3512855  -1.3338562  ... -1.6127234  -1.5952941
   -1.6127234 ]
  [-1.5081482  -1.4732897  -1.4210021  ... -1.5778649  -1.6127234
   -1.6301525 ]]]
last--------------------------------: [-9.7715964e+37 -1.3229437e+37 -1.5262715e+38 -2.5811514e+38
  3.2964988e+38 -7.1266450e+37 -7.2963347e+37 -3.0699307e+38
 -1.6108344e+38  5.8011444e+37 -3.9925391e+37 -9.5891957e+37
 -1.7783365e+38  2.2280316e+38 -4.4186918e+37  3.4825655e+37
  5.8457292e+37  7.2160006e+37  1.4259578e+38  9.4037617e+37
  7.4650717e+37  1.8146209e+37 -2.5143476e+38  2.4387442e+38
 -7.5397363e+37  1.4157064e+38 -1.1084308e+38  1.9522180e+38
  2.5864164e+37 -8.5381704e+37  3.3140050e+36 -1.2379668e+38
 -3.3449897e+37  1.6203643e+38  1.4627435e+38  6.6909600e+37
  6.0661751e+37 -1.2335753e+38  1.3377397e+38 -3.7530971e+37
  3.5314601e+37 -1.4393099e+37           -inf -6.0411279e+37
 -7.0721061e+37  1.5951782e+38  9.0163464e+37  1.3680580e+37
 -1.2254094e+37  1.0919689e+38 -1.5229139e+37 -3.4862508e+36
 -8.9739065e+37  2.8713203e+38  9.4768839e+37  7.8658815e+37
 -2.6619306e+38 -7.8224467e+37  6.8780734e+37            inf
 -9.8889302e+37 -1.9009123e+38 -1.4562352e+38 -4.5324568e+37
 -2.6728082e+38  1.0300855e+38 -5.7767852e+37  1.3662499e+37
 -4.0048543e+37 -3.1911765e+37 -1.9702732e+38 -6.5395945e+37
  1.0223747e+38 -2.8775531e+38 -1.1156091e+38 -1.8772822e+38
  1.2472896e+38  1.2465860e+38 -6.7286062e+37 -8.9167649e+37
 -2.8327554e+37 -2.7379526e+37 -1.5994879e+37  1.1577176e+38
  1.1864721e+38  1.7089999e+38 -1.5323652e+37 -1.5374746e+38
  1.2187025e+38 -8.9546139e+37  1.7550813e+38 -5.7048014e+37
 -8.5996788e+37 -5.2310546e+36 -1.4450948e+37 -1.9950120e+37
  4.2429252e+37 -1.4849557e+38  1.0697206e+38 -7.6313524e+37
           -inf  1.7437526e+38 -1.0569269e+38 -1.5577321e+38
 -7.8117285e+37  6.4801082e+37 -3.3032475e+37 -6.4655517e+36
 -2.3770844e+38  1.0880277e+38  3.6430118e+37 -6.9370110e+37
  8.5146681e+37  1.1550550e+38 -2.5614073e+38 -2.1489826e+38
 -8.3233807e+37  2.7233982e+37 -1.3777926e+38 -9.6201629e+37
 -2.1125345e+38 -1.4252791e+36  3.6633845e+37  2.6106833e+37
  9.6643025e+37 -1.4538810e+37 -1.3660478e+38  1.9220696e+38]

1. Use warmup to adjust the learning rate, with the maximum learning rate set to 0.01;

2. Apply gradient clipping as a safeguard (a combined sketch of points 1 and 2 follows this list);

3. Check whether the final output is normalized; its value range is probably not within 0-1.
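A minimal sketch of points 1 and 2 in MindSpore; the warmup length, total step count, clipping norm, optimizer, and the name net_with_loss are illustrative assumptions that depend on the attached training script:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

# 1) Warmup schedule: ramp linearly up to max_lr=0.01, then hold it for the remaining steps.
def warmup_schedule(max_lr=0.01, warmup_steps=500, total_steps=20000):
    warm = np.linspace(0.0, max_lr, warmup_steps, endpoint=False)
    rest = np.full(total_steps - warmup_steps, max_lr)
    return ms.Tensor(np.concatenate([warm, rest]), ms.float32)

# 2) TrainOneStepCell variant that clips the global gradient norm before the update.
class ClippedTrainOneStep(nn.TrainOneStepCell):
    def __init__(self, network, optimizer, clip_norm=1.0):
        super().__init__(network, optimizer)
        self.clip_norm = clip_norm

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = ops.Fill()(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = ops.clip_by_global_norm(grads, self.clip_norm)
        grads = self.grad_reducer(grads)
        loss = ops.depend(loss, self.optimizer(grads))
        return loss

# Usage (net_with_loss = model + loss wrapper from the training script):
#     opt = nn.Adam(net_with_loss.trainable_params(), learning_rate=warmup_schedule())
#     train_step = ClippedTrainOneStep(net_with_loss, opt, clip_norm=1.0)

Regarding point 3, note that layer_last currently has no normalization after it (the BatchNorm1d is commented out), so the embedding range is unbounded.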


Copyright notice
This post was written by [小乐快乐]; please include a link to the original when reposting. Thank you.
https://blog.csdn.net/weixin_45666880/article/details/126270389