ndrplz/ConvLSTM_pytorch

torch split into the four gates

arvindmohan opened this issue · 2 comments

Hey, I'm not sure how the logic in line 49 works:

```python
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
```

Since each of the 4 gates applies its own weights to the inputs, how is the order of the split determined? Why not something like `cc_g, cc_f, cc_i, cc_o = torch.split(combined_conv, self.hidden_dim, dim=1)`? I'm a bit confused about how the LSTM equations in https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell are implemented here.

Thanks in advance.
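For reference, here is a sketch of how the four split tensors would map onto the LSTMCell equations, assuming the cell follows the standard LSTM update with the matrix products replaced by a convolution $*$ over the channel-wise concatenation $[x_t, h_{t-1}]$. The four gate kernels are stacked along the output-channel axis of a single $W$, and $\tilde{i}_t, \tilde{f}_t, \tilde{o}_t, \tilde{g}_t$ stand for `cc_i`, `cc_f`, `cc_o`, `cc_g`; $\odot$ is the elementwise product.

```latex
\begin{aligned}
\begin{bmatrix} \tilde{i}_t \\ \tilde{f}_t \\ \tilde{o}_t \\ \tilde{g}_t \end{bmatrix}
  &= W * [x_t, h_{t-1}] + b \\
i_t &= \sigma(\tilde{i}_t), \quad f_t = \sigma(\tilde{f}_t), \quad
o_t = \sigma(\tilde{o}_t), \quad g_t = \tanh(\tilde{g}_t) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```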

Indeed, it's up to you.

In the equations, each gate is usually written as the output of its own convolution.
Since the output channels of a convolution are independent of each other, and all four gates are computed from the same input, you can collapse the four convolutions into a single one and split its output afterwards.
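A minimal sketch of this equivalence, assuming PyTorch's `nn.Conv2d` and hypothetical sizes (`in_ch` stands in for the channels of the concatenated input and hidden state). One convolution with `4 * hidden_dim` output channels, split along the channel axis, gives the same tensors as four separate convolutions whose weights are the corresponding slices of the fused weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, for illustration only.
in_ch, hidden_dim, k = 3, 4, 3

# One fused convolution producing all four gate pre-activations at once.
fused = nn.Conv2d(in_ch, 4 * hidden_dim, kernel_size=k, padding=k // 2)

# Four separate convolutions, each copied from one slice of the fused
# weights and bias (slicing dim 0 selects output channels).
separate = []
for j in range(4):
    conv = nn.Conv2d(in_ch, hidden_dim, kernel_size=k, padding=k // 2)
    with torch.no_grad():
        conv.weight.copy_(fused.weight[j * hidden_dim:(j + 1) * hidden_dim])
        conv.bias.copy_(fused.bias[j * hidden_dim:(j + 1) * hidden_dim])
    separate.append(conv)

x = torch.randn(2, in_ch, 8, 8)  # (batch, channels, height, width)

with torch.no_grad():
    # Split the fused output into four chunks of hidden_dim channels each.
    fused_parts = torch.split(fused(x), hidden_dim, dim=1)
    separate_parts = [conv(x) for conv in separate]

# Compare each chunk with the matching separate convolution.
for a, b in zip(fused_parts, separate_parts):
    print(torch.allclose(a, b, atol=1e-5))
```

The fused version simply launches one larger convolution instead of four smaller ones.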

The order is arbitrary, the change you suggest should work as well.
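A sketch of why the order does not matter, assuming PyTorch and hypothetical sizes; `gates_ifog`, `gates_gfio` and `reorder_blocks` are illustrative helpers, not part of the repo. If the output-channel blocks of the convolution's weights and bias are rearranged to match the alternative split order, both splits produce exactly the same gate pre-activations. Both orders can therefore represent the same cells, and since the weights are learned, the split order only fixes which slice of the weights ends up playing which role.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: in_ch = input channels + hidden channels, already concatenated.
in_ch, hidden_dim = 5, 4

def gates_ifog(conv, x):
    """Split in the order used in the repo: i, f, o, g."""
    cc_i, cc_f, cc_o, cc_g = torch.split(conv(x), hidden_dim, dim=1)
    return cc_i, cc_f, cc_o, cc_g

def gates_gfio(conv, x):
    """Alternative order from the question: g, f, i, o (returned as i, f, o, g)."""
    cc_g, cc_f, cc_i, cc_o = torch.split(conv(x), hidden_dim, dim=1)
    return cc_i, cc_f, cc_o, cc_g

def reorder_blocks(t, order):
    """Rearrange hidden_dim-sized blocks of output channels (dim 0)."""
    blocks = torch.split(t, hidden_dim, dim=0)
    return torch.cat([blocks[j] for j in order], dim=0)

conv_a = nn.Conv2d(in_ch, 4 * hidden_dim, kernel_size=3, padding=1)
conv_b = nn.Conv2d(in_ch, 4 * hidden_dim, kernel_size=3, padding=1)

# conv_a blocks are [i, f, o, g]; gates_gfio expects [g, f, i, o],
# so block k of conv_b is block order[k] of conv_a.
order = [3, 1, 0, 2]
with torch.no_grad():
    conv_b.weight.copy_(reorder_blocks(conv_a.weight, order))
    conv_b.bias.copy_(reorder_blocks(conv_a.bias, order))

x = torch.randn(1, in_ch, 6, 6)
with torch.no_grad():
    out_a = gates_ifog(conv_a, x)
    out_b = gates_gfio(conv_b, x)

# Same (i, f, o, g) pre-activations from both split orders.
print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(out_a, out_b)))
```

Permuting the blocks is a one-to-one mapping between the two parameterizations, so any cell one order can express, the other can express too.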

Interesting... I still can't intuitively grasp how changing the order gives the same results, but I'll look for some papers that explain this in more depth. Thanks.