AndersonJo/dqn-pytorch

Missing batch_first attribute for LSTM model.


By default, BATCH_SIZE = 32.

The input to the LSTM from the CNN has shape (32, 64, 16).
By default, nn.LSTM interprets its input as (seq_len, batch_size, input_size).

But the actual input format here is (batch_size, seq_len, input_size).
To correct this, batch_first=True needs to be passed when creating the LSTM model:

```python
self.lstm = nn.LSTM(16, LSTM_MEMORY, 1, batch_first=True)
```
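
For reference, a minimal sketch of why this bug is easy to miss (LSTM_MEMORY = 128 is an assumed value here, not necessarily the repo's): both variants accept the same tensor and return outputs of the same shape, so nothing crashes; only the meaning of the first two dimensions differs.

```python
import torch
import torch.nn as nn

LSTM_MEMORY = 128  # assumption: stand-in for the repo's hidden-size constant

# Default: the first input dim is read as seq_len.
lstm_default = nn.LSTM(16, LSTM_MEMORY, 1)
# Fixed: the first input dim is read as the batch.
lstm_fixed = nn.LSTM(16, LSTM_MEMORY, 1, batch_first=True)

x = torch.randn(32, 64, 16)  # (batch_size, seq_len, input_size) from the CNN

out_default, _ = lstm_default(x)  # silently treats 32 as seq_len and 64 as batch
out_fixed, _ = lstm_fixed(x)      # correctly treats 32 as the batch

print(out_default.shape)  # torch.Size([32, 64, 128])
print(out_fixed.shape)    # torch.Size([32, 64, 128]) -- same shape, so no error is raised
```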

The astonishing part is that the model is still learning despite this error.

Wow! As you said, it is astonishing haha :)
I will fix this error soon!

Cool. I am fixing it as well. Will try for a PR in a couple of days.

Good :) I will wait for your PR. I think your contribution is more valuable than my fix.
After reviewing your PR, I will close this issue. :) Good job!

When checking the batch sizes of the input to the LSTM's forward method using print(x.shape),
the following is obtained:

```
torch.Size([32, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([31, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([31, 4, 84, 84])
torch.Size([32, 4, 84, 84])
torch.Size([30, 4, 84, 84])
torch.Size([32, 4, 84, 84])
```

This shows that the batch size changes across inputs. This would break the code in the forward method, as hidden_state and cell_state are initialized by the init_states method with a fixed BATCH_SIZE, which would be 32 (currently it even uses 64 as the batch size).

Is there any way to make the batch size of the input consistent?
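
If the 31- and 30-sized batches come from sampling an under-filled replay memory (an assumption on my part; I have not traced the sampling code), one common way to keep the batch size fixed is to skip the training step until enough transitions have been collected. A hypothetical sketch, where replay_memory and BATCH_SIZE stand in for the repo's actual names:

```python
import random

def sample_batch(replay_memory, BATCH_SIZE=32):
    # Train only once the memory holds a full batch; otherwise keep acting.
    if len(replay_memory) < BATCH_SIZE:
        return None
    return random.sample(replay_memory, BATCH_SIZE)
```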

Also, could you let me know the sources of inspiration for this code? That might help in fixing the issue quicker.
Thanks.

Another issue that needs to be looked at is the batch_size used by the init_states method when initializing hidden_state and cell_state.

Hidden/Cell state semantics: (n_layers, batch_size, hidden_size)

Since during training the batch_size would be 1 (one sample is added to the replay memory at a time), train_hidden_state and train_cell_state would use batch_size=1 in these dimension semantics, while dqn_hidden_state and test_hidden_state would use batch_size=32.

The init_states method would need to be modified to accept batch_size as an argument and return hidden_state and cell_state of the corresponding shape.
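
A minimal sketch of that change (the device argument and the LSTM_MEMORY constant are assumptions, not the repo's exact code; n_layers is 1 to match the nn.LSTM above):

```python
import torch

def init_states(self, batch_size, device="cpu"):
    # Shape follows the LSTM state semantics: (n_layers, batch_size, hidden_size)
    hidden_state = torch.zeros(1, batch_size, LSTM_MEMORY, device=device)
    cell_state = torch.zeros(1, batch_size, LSTM_MEMORY, device=device)
    return hidden_state, cell_state

# Callers would then request the shape they actually need, e.g.:
# train_hidden_state, train_cell_state = model.init_states(batch_size=1)
# dqn_hidden_state, dqn_cell_state = model.init_states(batch_size=BATCH_SIZE)
```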