RuntimeError: Expected hidden[0] size (2, 1, 512), got [2, 128, 512] - Seq2Seq Model with PreTrained BERT Model
Ninja16180 opened this issue · 0 comments
Hi Ben,
Your suggestion helped resolve the error mentioned in this issue:
#161 (comment)
However, while training the seq2seq model I have now run into the following RuntimeError.
My notebook is in this GitHub repo for your reference:
https://github.com/Ninja16180/BERT/blob/main/Training_Seq2Seq_Model_using_Pre-Trained_BERT_Model.ipynb
RuntimeError Traceback (most recent call last)
<ipython-input-63-472071541d41> in <module>()
8 start_time = time.time()
9
---> 10 train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
11 valid_loss = evaluate(model, valid_iterator, criterion)
12
8 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in check_hidden_size(self, hx, expected_hidden_size, msg)
221 msg: str = 'Expected hidden size {}, got {}') -> None:
222 if hx.size() != expected_hidden_size:
--> 223 raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
224
225 def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]):
RuntimeError: Expected hidden[0] size (2, 1, 512), got [2, 128, 512]
I am not quite sure why this error occurs, since the expected dimension of the hidden state is:
[n layers * n directions, batch size, hid dim]
and a batch size of 128 is passed, as shown below:
import torch
from torchtext.legacy.data import BucketIterator, TabularDataset

BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BucketIterator groups examples of similar length into batches
train_iterator, valid_iterator = BucketIterator.splits(
    (train_data, valid_data),
    batch_size = BATCH_SIZE,
    device = device)
So I am not sure why, during training, it expects the batch size to be 1.
Could you please help me with this?
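To make the symptom concrete, here is a minimal sketch (toy shapes, not your actual model) of PyTorch's LSTM hidden-size rule: the hidden state must be [n layers * n directions, batch size, hid dim], and the batch size is read from the input tensor's layout. One plausible way to get this exact message is a batch_first mismatch between the input and the layer, shown below as an assumption, not a diagnosis:

```python
import torch
import torch.nn as nn

# LSTM with batch_first=False (the default): input is [seq_len, batch, emb_dim]
rnn = nn.LSTM(input_size=768, hidden_size=512, num_layers=2)

src = torch.randn(30, 128, 768)   # [seq_len, batch, emb_dim] -- matches the default layout
h0 = torch.zeros(2, 128, 512)     # [n_layers * n_directions, batch, hid_dim]
c0 = torch.zeros(2, 128, 512)
out, _ = rnn(src, (h0, c0))       # OK: input and hidden both see batch_size = 128

# If a layer is built with batch_first=True but fed a [1, batch, emb_dim] tensor,
# it reads batch_size = 1 and expects a [2, 1, 512] hidden -- passing the
# encoder's [2, 128, 512] hidden then raises exactly the error above.
rnn_bf = nn.LSTM(input_size=768, hidden_size=512, num_layers=2, batch_first=True)
try:
    rnn_bf(torch.randn(1, 128, 768), (h0, c0))
except RuntimeError as e:
    print(e)  # Expected hidden[0] size (2, 1, 512), got [2, 128, 512]
```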
Also, I am curious about the value of 'hidden_size', which is used in both the Encoder and Decoder classes:
emb_dim = bert.config.to_dict()['hidden_size']
I understand that this is the embedding dimension taken from the transformer via its config attribute, but what is its actual value?
And does it have anything to do with the runtime error I am currently facing?
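I tried checking the value myself with the snippet below (assuming the transformers library's BertConfig, whose defaults match bert-base-uncased, so no weights need to be downloaded) and got 768, but please confirm:

```python
from transformers import BertConfig

# The default BertConfig mirrors bert-base-uncased; no model download is needed.
config = BertConfig()
emb_dim = config.to_dict()['hidden_size']
print(emb_dim)  # 768
```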
Thanks in advance!