sunwng/SpyGR-segmentation

训练过程出现NaN/Inf

Opened this issue · 1 comments

你好,我将这个代码里面的GR_Module修改之后加入到我的模型之中,发现在训练过程中一次反向传播之后会出现Nan/Inf的情况。我使用的AdamW优化器,LR=1e-5。尝试改小LR并不能解决问题。在训练代码中加入了梯度裁剪也没有解决问题。

训练过程中打印梯度发现主要是phi_conv这个卷积层的梯度出现nan导致模型训练出错。

不知道你之前有没有遇到类似的问题,如果有,你是怎么解决的。希望得到你的回复,谢谢。

最后附上我修改之后的模块代码。

class GRModule(nn.Module):
    def __init__(self, channel, graph_feature=64):
        super(GRModule, self).__init__()

        self.channel = channel
        self.M = graph_feature

        self.phi_conv = nn.Sequential(
            nn.Conv2d(channel, self.M, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
        )
        self.glob_pool_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Conv2d(channel, self.M, kernel_size=1, stride=1, padding=0, bias=False)
        )

        self.graph_weight = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        bs = x.shape[0]
        # 计算得到 $\phi x$ 和 $\phi^T x$
        x_phi_conv = self.phi_conv(x)
        x_phi = x_phi_conv.view([x_phi_conv.shape[0], -1, self.M])
        x_phi_T = x_phi_conv.view([x_phi_conv.shape[0], self.M, -1])

        # 计算得到 $\Lambda$
        x_glob_pool_conv = self.glob_pool_conv(x)
        x_glob_diag = torch.zeros(bs, self.M, self.M).to('cuda' if x_phi.is_cuda else 'cpu')
        for i in range(bs):
            x_glob_diag[i, :, :] = torch.diag(x_glob_pool_conv[i, :, :, :].reshape(1, self.M))
        
        # $\tilde A = \phi \Lambda \phi^T$
        A_tilde = torch.matmul(torch.matmul(x_phi, x_glob_diag), x_phi_T)
        
        # $\tilde D_{ii} = \sum_j A_{ij}$
        D_sqrt_inv = torch.zeros_like(A_tilde).to('cuda' if A_tilde.is_cuda else 'cpu')
        diag_sum = torch.sum(A_tilde, 2)

        for i in range(bs):
            diag_sqrt = 1.0 / torch.sqrt(diag_sum[i, :])
            diag_sqrt[torch.isnan(diag_sqrt)] = 0
            diag_sqrt[torch.isinf(diag_sqrt)] = 0
            D_sqrt_inv[i, :, :] = torch.diag(diag_sqrt)

        # $I$
        I = torch.eye(D_sqrt_inv.shape[1]).to('cuda' if A_tilde.is_cuda else 'cpu')
        I = I.repeat(bs, 1, 1)

        # $\tilde L = I - \tilde D_{-\frac{1}{2}} \tilde A \tilde D_{-\frac{1}{2}}$
        L_tilde = I - torch.matmul(torch.matmul(D_sqrt_inv, A_tilde), D_sqrt_inv)

        # $\sigma(\tilde L X W)$
        out = torch.matmul(L_tilde, x.reshape(bs, -1, self.channel))
        out = out.reshape(bs, self.channel, x.shape[2], x.shape[3])
        out = self.graph_weight(out)

        return out

用我自己的训练脚本和数据直接调用你这里写好的SpyGR也会出现inf/nan的问题。我在怀疑有没有可能是数据的问题。