locuslab/trellisnet

Possible mistake in LSTM cell

JurijsNazarovs opened this issue · 4 comments

Hello,

I was reading your paper and noticed that the code in trellisnet.py (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py), lines 124-129, does not correspond to formula (12) in Section 5.1 of the paper. Could you clarify whether this is a real discrepancy, or whether I am misunderstanding something?

Thank you

Hi Jurijs,

Thanks for your interest in our work. What lines 124-129 do is basically the following:

it, ot, gt, ft = out.chunk(4, dim=1)
it, ot, gt, ft = torch.sigmoid(it), torch.sigmoid(ot), torch.tanh(gt), torch.sigmoid(ft)
ct = ft * ct_1 + it * gt
ht = ot * torch.tanh(ct)

This corresponds exactly to formula (12), where z_{t+1, 1} corresponds to ct and z_{t+1, 2} corresponds to ht above. I did permute the order of \hat{z}_{*, 1/2/3/4}, which correspond to ft, it, gt and ot, respectively. But this is a trivial change and doesn't affect the correctness of the code.
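To illustrate why the ordering is arbitrary, here is a minimal dependency-free sketch (not the repository's code). It assumes the four gate pre-activations come from a single linear map; the weights `W_paper`, the input `x`, and the helpers `matvec`, `chunk`, and `lstm_gates` are hypothetical names made up for this example. Rearranging the rows of the weight matrix and reading the chunks in the matching order gives the same ct and ht.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def matvec(W, x):
    # Plain matrix-vector product on nested lists.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def chunk(v, n):
    # Split a list into n equal, contiguous pieces (like tensor.chunk).
    size = len(v) // n
    return [v[k * size:(k + 1) * size] for k in range(n)]

def lstm_gates(i_pre, o_pre, g_pre, f_pre, c_prev):
    # Standard LSTM update: ct = ft * ct_1 + it * gt, ht = ot * tanh(ct).
    c = [sigmoid(f) * cp + sigmoid(i) * math.tanh(g)
         for i, g, f, cp in zip(i_pre, g_pre, f_pre, c_prev)]
    h = [sigmoid(o) * math.tanh(cv) for o, cv in zip(o_pre, c)]
    return c, h

d = 2                      # hidden size (illustrative)
x = [0.5, -1.0, 2.0]       # hypothetical input
c_prev = [0.1, -0.3]       # hypothetical previous cell state

# Hypothetical weights with 4*d rows, laid out in the paper's order (f, i, g, o).
W_paper = [[0.1 * (r + 1) * (-1) ** (r + c) + 0.05 * c for c in range(len(x))]
           for r in range(4 * d)]

# The same rows rearranged into the code's order (i, o, g, f).
f_rows, i_rows, g_rows, o_rows = chunk(W_paper, 4)
W_code = i_rows + o_rows + g_rows + f_rows

# Paper layout: chunks are read as f, i, g, o.
f_p, i_p, g_p, o_p = chunk(matvec(W_paper, x), 4)
c_paper, h_paper = lstm_gates(i_p, o_p, g_p, f_p, c_prev)

# Code layout: chunks are read as i, o, g, f.
i_c, o_c, g_c, f_c = chunk(matvec(W_code, x), 4)
c_code, h_code = lstm_gates(i_c, o_c, g_c, f_c, c_prev)

same = (all(math.isclose(a, b) for a, b in zip(c_paper, c_code))
        and all(math.isclose(a, b) for a, b in zip(h_paper, h_code)))
print(same)
```

Since the weights are learned, which block of d channels ends up playing which gate is only a naming convention, as long as the same convention is used everywhere.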

Hope this helps!

Thanks for your prompt reply. I think the issue is still not clear to me. So, let's look at the following line:

ht = ot * torch.tanh(ct)

According to the paper, ot is supposed to be \sigma(\hat{z}_{*, 4}); however, according to the code, ot is \sigma(\hat{z}_{*, 2}). Do I understand correctly that you attribute this to the permutation, which is a trivial change?

I am wondering why this would not be an issue, because \hat{z}_{*, 2} corresponds to a different part of the truncated RNN than \hat{z}_{*, 4} according to formula (9).

Oh, sorry, I think you may be confused. \hat{z}_t comes from formula (9) indeed, but I'm not permuting the layers. \hat{z}_t in formula (9) has dimension 4dL (it has L rows), which can be broken into L vectors, each with dimension 4d (i.e., a row in formula (9)). \hat{z}_t^{(i+1)} in formula (11) has dimension 4d. I'm simply permuting within this vector.

If you look at Figure 5 in the appendix of the paper, it, ot, gt and ft correspond to the violet-color blocks. The permutation doesn't matter as long as we make sure we consistently use, for instance, the first d channels for ft, channels d to 2d for it, channels 2d to 3d for gt and channels 3d to 4d for ot.
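The shapes described above can be sketched as follows. This is only an illustration under stated assumptions: the sizes `d` and `L`, the placeholder values in `z_hat`, the row-by-row flat storage, and the helper `split_gates` are all made up for the example, and the gate layout is the "for instance" ordering mentioned above (f, i, g, o), not necessarily the one used in the repository.

```python
d, L = 3, 4                       # illustrative hidden size and sequence length

# Placeholder for \hat{z}_t: a flat vector of dimension 4*d*L.
z_hat = list(range(4 * d * L))

# Break it into L rows, each of dimension 4*d (assumed stored row by row).
rows = [z_hat[t * 4 * d:(t + 1) * 4 * d] for t in range(L)]

def split_gates(row, d):
    # Slice one 4d-dimensional row into four d-dimensional gate blocks.
    # The permutation happens only inside this row, never across rows.
    return {
        "f": row[0:d],
        "i": row[d:2 * d],
        "g": row[2 * d:3 * d],
        "o": row[3 * d:4 * d],
    }

gates = [split_gates(row, d) for row in rows]
print(len(rows), len(rows[0]), gates[0])
```

Choosing a different assignment of blocks to gates only changes the keys in `split_gates`; the rows themselves are untouched.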

Shaojie, thanks for your explanation. I understand that there is no mistake in the code, so the issue can be closed.

I think I got it. Do I understand correctly that the vector of 4 elements in formula (11) does not relate to formula (9), and that you define the output of the trellis convolution as a vector of 4 elements in order to integrate the LSTM cell?