Possible mistake in LSTM cell
JurijsNazarovs opened this issue · 4 comments
Hello,
I was reading your paper and noticed that the code in trellisnet.py (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py), lines 124-129, does not correspond to formula (12) in Section 5.1 of the paper. Could you clarify whether this is a real discrepancy, or whether I am misunderstanding something?
Thank you
Hi Jurijs,
Thanks for your interest in our work. What lines 124-129 do is basically the following:
it, ot, gt, ft = out.chunk(4, dim=1)
it, ot, gt, ft = torch.sigmoid(it), torch.sigmoid(ot), torch.tanh(gt), torch.sigmoid(ft)
ct = ft * ct_1 + it * gt
ht = ot * torch.tanh(ct)
This corresponds exactly to formula (12), where z_{t+1, 1} corresponds to ct and z_{t+1, 2} corresponds to ht above. I did permute the order of \hat{z}_{*, 1/2/3/4}, which correspond to ft, it, gt and ot, respectively. But this is a trivial change and doesn't affect the correctness of the code.
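To make the "trivial change" point concrete, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without PyTorch). The helper names `chunk` and `lstm_cell`, the gate labels, and the numeric values are all made up for illustration and are not taken from trellisnet.py. It only shows that reordering the four gate blocks gives the same ct and ht, provided each block is still read as the same gate.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def chunk(vec, n):
    """Split a flat list into n equal contiguous pieces (similar in spirit to tensor.chunk)."""
    size = len(vec) // n
    return [vec[k * size:(k + 1) * size] for k in range(n)]

def lstm_cell(pre, c_prev, order):
    """Apply LSTM gating to a pre-activation vector of length 4d.

    `order` names the gate stored in each of the four chunks, e.g. ("i", "o", "g", "f").
    """
    gates = dict(zip(order, chunk(pre, 4)))
    i = [sigmoid(x) for x in gates["i"]]
    o = [sigmoid(x) for x in gates["o"]]
    f = [sigmoid(x) for x in gates["f"]]
    g = [math.tanh(x) for x in gates["g"]]
    c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, g)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return c, h

# Hypothetical values with d = 2, so the pre-activation has 4 * d = 8 entries.
code_order = ("i", "o", "g", "f")    # chunk order used in the snippet above
paper_order = ("f", "i", "g", "o")   # order described for \hat{z}_{*, 1/2/3/4}
pre_code = [0.1, -0.2, 0.3, 0.4, -0.5, 0.6, 0.7, -0.8]
c_prev = [0.5, -1.0]

# Rearrange the very same blocks into the paper's order.
blocks = dict(zip(code_order, chunk(pre_code, 4)))
pre_paper = [x for name in paper_order for x in blocks[name]]

c_a, h_a = lstm_cell(pre_code, c_prev, code_order)
c_b, h_b = lstm_cell(pre_paper, c_prev, paper_order)
print(c_a == c_b, h_a == h_b)
```

In a trained network the blocks come from learned convolution weights, so choosing a different order simply means the corresponding output channels learn a different gate.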
Hope this helps!
Thanks for your prompt reply. I think the issue is still not clear to me. So, let's look at the following line:
ht = ot * torch.tanh(ct)
According to the paper, ot is supposed to be \sigma(\hat{z}_{*, 4}); however, according to the code, ot is \sigma(\hat{z}_{*, 2}). Do I understand correctly that you attribute this to the permutation, which you consider a trivial change?
I am wondering why this would not be an issue, because \hat{z}_{*, 2} corresponds to a different part of the truncated RNN than \hat{z}_{*, 4}, according to formula (9).
Oh, sorry, I think you may be confused. \hat{z}_t comes from formula (9) indeed, but I'm not permuting the layers. \hat{z}_t in formula (9) has dimension 4dL (it has L rows), which can be broken into L vectors, each with dimension 4d (i.e., a row in formula (9)). \hat{z}_t^{(i+1)} in formula (11) has dimension 4d. I'm simply permuting within this vector.
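The dimension bookkeeping in this reply can be sketched as follows. This is only an illustration in plain Python: the flat row-major layout, the helper names `split_rows` and `permute_within_row`, and the particular permutation are assumptions, not code from the repository. It shows a vector of length 4dL viewed as L rows of length 4d, with the four d-sized blocks reordered inside each row and no values moving between rows.

```python
def split_rows(z_hat, L, d):
    """View a flat vector of length 4*d*L as L rows of length 4*d."""
    assert len(z_hat) == 4 * d * L
    return [z_hat[r * 4 * d:(r + 1) * 4 * d] for r in range(L)]

def permute_within_row(row, d, perm):
    """Reorder the four d-sized blocks of one row; perm[k] is the source block for slot k."""
    blocks = [row[k * d:(k + 1) * d] for k in range(4)]
    return [x for k in perm for x in blocks[k]]

L, d = 3, 2                       # hypothetical sizes
z_hat = list(range(4 * d * L))    # placeholder values 0..23
rows = split_rows(z_hat, L, d)

# An arbitrary example permutation of the four blocks, applied to every row.
permuted = [permute_within_row(row, d, (3, 0, 2, 1)) for row in rows]

# Each permuted row still holds exactly the values of its original row.
same_rows = all(sorted(p) == sorted(r) for p, r in zip(permuted, rows))
print(permuted[0], same_rows)
```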
If you look at Figure 5 in the appendix of the paper, it, ot, gt and ft correspond to the violet-color blocks. The permutation doesn't matter as long as we make sure we consistently use, for instance, the first d channels for ft, channels d-2d for it, channels 2d-3d for gt and channels 3d-4d for ot.
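One way to keep such a convention consistent is to define the channel layout once and look every gate up by name. The sketch below is hypothetical: `gate_slices` is not a function from the repository, and reading "channels d-2d" as the half-open range [d, 2d) is an assumption.

```python
def gate_slices(d, order=("f", "i", "g", "o")):
    """Map each gate name to the slice of channels it occupies in a 4d-channel output."""
    return {name: slice(k * d, (k + 1) * d) for k, name in enumerate(order)}

d = 3                             # hypothetical hidden size
layout = gate_slices(d)
channels = list(range(4 * d))     # channel indices 0..11

f_channels = channels[layout["f"]]   # first d channels
o_channels = channels[layout["o"]]   # channels 3d up to 4d
print(f_channels, o_channels)
```

Passing a different `order` changes which channels each gate reads, which is harmless as long as every place that touches the gates uses the same layout.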
Shaojie, thanks for your explanation. I understand that there is no mistake in the code, so the issue can be closed.
I think I got it. Do I understand correctly that the vector of 4 elements in formula (11) does not relate to formula (9), but that you instead define the output of the trellis convolution as a vector of 4 elements in order to integrate the LSTM cell?