NaN/Inf during training
YangLeiSX commented
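Hi, I modified the GR_Module from this code and added it to my own model, and after a single backward pass during training the values become NaN/Inf. I am using the AdamW optimizer with LR=1e-5. Lowering the LR did not solve the problem, and neither did adding gradient clipping to the training code.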
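Gradient clipping by norm cannot repair NaN/Inf gradients: `torch.nn.utils.clip_grad_norm_` derives its scaling factor from the total gradient norm, and if any gradient is non-finite that norm is non-finite too. A common stopgap (it hides the symptom, it does not fix the cause) is to check the returned norm and skip the update. A minimal sketch, assuming a standard PyTorch training loop; `clip_and_maybe_step` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def clip_and_maybe_step(model, optimizer, max_norm=1.0):
    # clip_grad_norm_ returns the total norm computed before clipping; NaN/Inf
    # there means at least one gradient is non-finite and clipping cannot fix it.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad(set_to_none=True)  # drop the bad gradients
        return False                           # skip this update
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return True

# Usage sketch with a tiny stand-in model and a hand-made NaN gradient.
model = nn.Linear(2, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
before = model.weight.detach().clone()
model.weight.grad = torch.full_like(model.weight, float("nan"))
model.bias.grad = torch.zeros_like(model.bias)
stepped = clip_and_maybe_step(model, opt)
print(stepped)  # False: the weights were left untouched
```

Skipping steps only buys time; if most steps are skipped, the NaN source still has to be found.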
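Printing the gradients during training shows that the NaNs appear mainly in the gradients of the phi_conv convolution layer, and this is what breaks training.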
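Printing gradients shows which parameter ends up with NaN, but not which operation produced it first. PyTorch's anomaly detection mode (`torch.autograd.detect_anomaly`) raises an error at the first backward function that returns NaN. A small sketch on a made-up "degree" vector; `first_nan_backward` and the toy values are assumptions for illustration only:

```python
import torch

def first_nan_backward():
    """Return the anomaly-mode error message, or None if backward is clean."""
    try:
        with torch.autograd.detect_anomaly():
            deg = torch.tensor([4.0, 0.0], requires_grad=True)
            d = 1.0 / torch.sqrt(deg)   # [0.5, inf]
            d[torch.isinf(d)] = 0       # forward output is now finite
            d.sum().backward()          # backward still evaluates 0 * inf
    except RuntimeError as err:
        return str(err)
    return None

print(first_nan_backward())
```

Anomaly mode slows training noticeably, so it is usually enabled only for a few iterations while debugging.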
Have you run into a similar problem before? If so, how did you solve it? I'd appreciate a reply, thanks.
Below is my modified module code.
import torch
import torch.nn as nn


class GRModule(nn.Module):
    def __init__(self, channel, graph_feature=64):
        super(GRModule, self).__init__()
        self.channel = channel
        self.M = graph_feature
        # phi: projects every pixel onto M graph features
        self.phi_conv = nn.Sequential(
            nn.Conv2d(channel, self.M, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
        )
        # Global pooling + 1x1 conv: the M diagonal entries of Lambda
        self.glob_pool_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(channel, self.M, kernel_size=1, stride=1, padding=0, bias=False)
        )
        # W followed by the activation sigma
        self.graph_weight = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        bs, _, h, w = x.shape
        # Compute $\phi^T x$ (bs, M, N) and $\phi x$ (bs, N, M), with N = h * w.
        # A transpose is required: view(bs, -1, M) would reinterpret memory, not transpose.
        x_phi_conv = self.phi_conv(x)
        x_phi_T = x_phi_conv.reshape(bs, self.M, -1)
        x_phi = x_phi_T.transpose(1, 2)
        # Compute $\Lambda$ as a batch of (M, M) diagonal matrices
        x_glob_pool_conv = self.glob_pool_conv(x)
        x_glob_diag = torch.diag_embed(x_glob_pool_conv.reshape(bs, self.M))
        # $\tilde A = \phi \Lambda \phi^T$, shape (bs, N, N)
        A_tilde = torch.matmul(torch.matmul(x_phi, x_glob_diag), x_phi_T)
        # $\tilde D_{ii} = \sum_j \tilde A_{ij}$; build $\tilde D^{-\frac{1}{2}}$, zeroing NaN/Inf entries
        diag_sum = torch.sum(A_tilde, 2)
        diag_sqrt = 1.0 / torch.sqrt(diag_sum)
        diag_sqrt[torch.isnan(diag_sqrt)] = 0
        diag_sqrt[torch.isinf(diag_sqrt)] = 0
        D_sqrt_inv = torch.diag_embed(diag_sqrt)
        # $I$, broadcast over the batch
        I = torch.eye(A_tilde.shape[1], device=x.device, dtype=A_tilde.dtype).expand(bs, -1, -1)
        # $\tilde L = I - \tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}}$
        L_tilde = I - torch.matmul(torch.matmul(D_sqrt_inv, A_tilde), D_sqrt_inv)
        # $\sigma(\tilde L X W)$: X is (bs, N, C), so flatten pixels and move channels last
        x_flat = x.reshape(bs, self.channel, -1).transpose(1, 2)
        out = torch.matmul(L_tilde, x_flat)
        out = out.transpose(1, 2).reshape(bs, self.channel, h, w)
        out = self.graph_weight(out)
        return out
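A plausible explanation for the NaN gradients in `phi_conv` (a hypothesis, not confirmed in this thread): the module zeroes NaN/Inf entries *after* computing `1.0 / torch.sqrt(diag_sum)`. That makes the forward output finite, but autograd still differentiates `1/sqrt` at those entries, and a zero upstream gradient multiplied by an infinite or NaN local derivative is NaN. A row sum of $\tilde A$ can be zero (the ReLU in `phi_conv` can zero every feature of a pixel) or negative (`glob_pool_conv` has no activation, so entries of $\Lambda$ can be negative). The sketch compares that pattern with a variant that masks the *input* via `torch.where` before calling `torch.rsqrt`; the helper names and the `eps` threshold are assumptions for illustration:

```python
import torch

def naive_inv_sqrt(deg):
    # Same pattern as the module above: compute first, then mask NaN/Inf.
    d = 1.0 / torch.sqrt(deg)
    d[torch.isnan(d)] = 0
    d[torch.isinf(d)] = 0
    return d

def safe_inv_sqrt(deg, eps=1e-6):
    # Replace non-positive degrees with 1 *before* rsqrt, so the backward pass
    # never evaluates a derivative at 0 or at a negative value.
    valid = deg > eps
    safe_deg = torch.where(valid, deg, torch.ones_like(deg))
    return torch.where(valid, torch.rsqrt(safe_deg), torch.zeros_like(deg))

def run(fn):
    deg = torch.tensor([4.0, 0.0, -1.0], requires_grad=True)
    out = fn(deg)
    out.sum().backward()
    return out.detach(), deg.grad

print(run(naive_inv_sqrt))  # same forward output, but NaN in deg.grad[1:]
print(run(safe_inv_sqrt))   # same forward output, finite deg.grad
```

Zeroing rows with non-positive degree is only one option. Constraining $\Lambda$ to be non-negative (for example, a ReLU or softplus after `glob_pool_conv`) would rule out negative degrees but could still leave zeros, so the input masking would still be needed. Neither choice is confirmed as what the SpyGR authors intended.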
YangLeiSX commented
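Calling the SpyGR implemented here directly, with my own training script and data, also produces Inf/NaN. I'm starting to suspect that the data might be the cause.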
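One way to test the data hypothesis is to check every batch, and the loss, for non-finite values before calling `backward()`. If the inputs and the loss are finite but the gradients still turn into NaN, the problem more likely lies inside the model than in the data. A minimal sketch; `nonfinite_names` is a hypothetical helper, and `images`, `labels`, `criterion` and `model` are placeholders:

```python
import torch

def nonfinite_names(**tensors):
    """Return the names of the given tensors that contain NaN or Inf."""
    return [name for name, t in tensors.items() if not torch.isfinite(t).all()]

# Usage sketch inside a training loop:
#   bad = nonfinite_names(images=images, labels=labels)
#   if bad:
#       raise ValueError(f"non-finite values in {bad}")
#   loss = criterion(model(images), labels)
#   if nonfinite_names(loss=loss):
#       raise ValueError("loss is non-finite before backward()")

images = torch.randn(2, 3, 4, 4)
labels = torch.tensor([0, 1])
print(nonfinite_names(images=images, labels=labels))  # []
images[0, 0, 0, 0] = float("nan")
print(nonfinite_names(images=images, labels=labels))  # ['images']
```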