lanwuwei/SPM_toolkit

Some questions about DecAtt

Closed this issue · 3 comments

Thanks for sharing your great work!
I am trying to use DecAtt with the Quora dataset, but the model does not train the way yours does: the predictions are always 0. So I printed the attention matrix:

```python
raw_attentions = torch.matmul(repr1, repr2)
print(raw_attentions)
```
At the beginning it is filled with large values like these:
```
tensor([[[ 141.6384,    3.2162,    3.8512,  ...,    3.0097,  146.9306,  152.0328],
         [   3.8926,    0.2856,    0.3139,  ...,    0.2734,    5.9284,    6.0491],
         [   3.8910,    0.2930,    0.3183,  ...,    0.2677,    7.0301,    6.9951],
         ...,
```
After many batches, it surprisingly becomes all zeros:
```
tensor([[[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         ...,
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
         ...,
```
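A minimal sketch of a per-batch check that narrows down where the zeros first appear (the helper is hypothetical; it assumes repr1 has shape (batch, len1, dim) and repr2 is already transposed to (batch, dim, len2), so that torch.matmul gives the attention scores):

```python
import torch

def inspect_batch(repr1, repr2):
    """Hypothetical debugging helper: report whether the zeros come from
    the representations themselves or only from their matrix product."""
    raw_attentions = torch.matmul(repr1, repr2)  # (batch, len1, len2)
    for name, t in (("repr1", repr1), ("repr2", repr2),
                    ("raw_attentions", raw_attentions)):
        print(f"{name}: min={t.min().item():.4g}, max={t.max().item():.4g}, "
              f"any_nan={torch.isnan(t).any().item()}, "
              f"all_zero={bool((t == 0).all().item())}")
    return raw_attentions
```

If repr1 and repr2 are themselves all zeros, the problem is upstream of the attention (the embedding lookup or the projection layers) rather than in the attention computation itself.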

Hi caoxu, the values in raw_attentions should look like this:
```
Variable containing:
(0 ,.,.) =
5.3814e+02 2.0416e+03 1.1151e+03 ... 1.4076e+03 6.3590e+02 6.7285e+02
1.5274e+03 5.8082e+03 3.1728e+03 ... 4.0065e+03 1.8074e+03 1.9189e+03
1.1055e+03 4.2067e+03 2.2978e+03 ... 2.9015e+03 1.3091e+03 1.3885e+03
... ⋱ ...
1.8057e+03 6.8718e+03 3.7522e+03 ... 4.7419e+03 2.1378e+03 2.2699e+03
1.6039e+02 6.1006e+02 3.3317e+02 ... 4.2068e+02 1.8991e+02 2.0130e+02
5.8381e+02 2.2293e+03 1.2178e+03 ... 1.5392e+03 6.9328e+02 7.3975e+02

(1 ,.,.) =
4.5535e+02 1.0488e+03 6.9031e+02 ... 8.3667e+02 1.0885e+02 6.1863e+02
7.5872e+02 1.7398e+03 1.1399e+03 ... 1.3920e+03 1.7650e+02 1.0247e+03
7.3954e+02 1.7294e+03 1.1695e+03 ... 1.3535e+03 1.9768e+02 1.0320e+03
... ⋱ ...
1.9571e+02 4.5160e+02 2.9675e+02 ... 3.6145e+02 4.7209e+01 2.6582e+02
7.6155e+02 1.7811e+03 1.2092e+03 ... 1.3889e+03 2.0445e+02 1.0656e+03
8.2073e+02 1.8940e+03 1.2585e+03 ... 1.4979e+03 2.0167e+02 1.1235e+03

(2 ,.,.) =
5.4431e+02 1.8314e+03 1.3088e+03 ... 1.0579e+03 5.1971e+02 6.9764e+02
2.3324e+03 7.8572e+03 5.6125e+03 ... 4.5347e+03 2.2297e+03 2.9916e+03
2.1125e+02 7.2039e+02 5.0812e+02 ... 4.1160e+02 2.0584e+02 2.8302e+02
... ⋱ ...
1.3373e+03 4.5043e+03 3.2176e+03 ... 2.6003e+03 1.2789e+03 1.7189e+03
7.2755e+02 2.4492e+03 1.7498e+03 ... 1.4139e+03 6.9587e+02 9.3610e+02
6.3130e+02 2.1277e+03 1.5178e+03 ... 1.2268e+03 6.0562e+02 8.1883e+02
...

(29,.,.) =
5.0561e+02 2.7038e+03 1.8770e+02 ... 9.5561e+02 3.8252e+02 6.9260e+02
1.1009e+03 5.8888e+03 4.1470e+02 ... 2.0726e+03 8.3484e+02 1.5186e+03
1.0668e+03 5.6989e+03 4.0530e+02 ... 2.0070e+03 8.2106e+02 1.4735e+03
... ⋱ ...
4.0813e+02 2.1711e+03 1.6243e+02 ... 7.6040e+02 3.2915e+02 5.7295e+02
6.7366e+02 3.6003e+03 2.5612e+02 ... 1.2683e+03 5.1753e+02 9.3004e+02
5.5651e+02 2.9685e+03 2.1236e+02 ... 1.0455e+03 4.3329e+02 7.6979e+02

(30,.,.) =
6.0563e+02 9.5121e+02 1.5719e+03 ... 1.0785e+02 6.7842e+02 7.1051e+02
1.0173e+03 1.6053e+03 2.6459e+03 ... 1.8447e+02 1.1478e+03 1.1903e+03
5.6602e+02 8.9856e+02 1.4739e+03 ... 1.0601e+02 6.3954e+02 6.7012e+02
... ⋱ ...
2.2659e+01 3.7745e+01 5.9815e+01 ... 5.2804e+00 2.6416e+01 2.8450e+01
8.0991e+02 1.2881e+03 2.1093e+03 ... 1.5265e+02 9.1358e+02 9.6303e+02
7.4484e+02 1.1842e+03 1.9399e+03 ... 1.3997e+02 8.4062e+02 8.8444e+02

(31,.,.) =
5.8181e+02 2.4455e+03 3.3938e+03 ... 4.8929e+02 5.4869e+02 6.5787e+02
1.8570e+03 7.8038e+03 1.0841e+04 ... 1.5629e+03 1.7500e+03 2.1087e+03
2.8775e+03 1.2082e+04 1.6771e+04 ... 2.4198e+03 2.7110e+03 3.2494e+03
... ⋱ ...
1.0387e+03 4.3725e+03 6.0737e+03 ... 8.7433e+02 9.8003e+02 1.1845e+03
5.4160e+02 2.2821e+03 3.1693e+03 ... 4.5642e+02 5.1145e+02 6.1883e+02
5.5070e+02 2.3263e+03 3.2337e+03 ... 4.6426e+02 5.2059e+02 6.3572e+02
[torch.FloatTensor of size (32,19,9)]
```
The zero matrix must be wrong. Can you check the inputs (e.g., repr1/repr2 or sent1/sent2)? I got the results above with my original code (you may need to change the data path and GloVe path). For the first batch, I got a training accuracy of 0.683.
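As a rough illustration of that suggestion, a sketch for checking sent1/sent2 directly (the helper, the pad_idx default, and the embedding argument are all assumptions, not code from the repository):

```python
def check_token_inputs(sent1, sent2, embedding, pad_idx=0):
    """Hypothetical sanity check of the index tensors that feed DecAtt.
    `embedding` stands for the nn.Embedding loaded from the pretrained
    vectors and `pad_idx` for the padding/unknown index (assumptions)."""
    for name, sent in (("sent1", sent1), ("sent2", sent2)):
        pad_ratio = (sent == pad_idx).float().mean().item()
        emb = embedding(sent)                        # (batch, len, dim)
        zero_rows = (emb.abs().sum(dim=-1) == 0).float().mean().item()
        print(f"{name}: pad/unk ratio = {pad_ratio:.2%}, "
              f"all-zero embedding rows = {zero_rows:.2%}")
```

If most embedding rows come back as zero, the attention product will be zero regardless of the model weights.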

@lanwuwei
Hi lanwuwei, thanks for your quick reply.
I am working on a Chinese corpus, and when I use my data with DecAtt the problem above appears, but there is no problem with PWIM.
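If the embedding file is still the English GloVe mentioned above, its coverage of a Chinese vocabulary would be close to zero, which is worth ruling out first. A rough sketch, with vocab and embedding_path as placeholders:

```python
def embedding_coverage(vocab, embedding_path):
    """Rough sketch (vocab and embedding_path are placeholders): what
    fraction of the corpus vocabulary appears in the pretrained vectors?"""
    known = set()
    with open(embedding_path, encoding="utf-8") as f:
        for line in f:
            known.add(line.split(" ", 1)[0])  # first field is the token
    covered = sum(1 for word in vocab if word in known)
    print(f"coverage: {covered}/{len(vocab)} = {covered / len(vocab):.2%}")
    return covered / len(vocab)
```

Very low coverage would mean nearly every token falls back to the out-of-vocabulary initialization, leaving DecAtt with little lexical signal to attend over.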
Could we connect on WeChat? I am doing research at the Beijing Institute of Technology.
Thanks a lot! My WeChat: caoenjun_

Sure!