jadore801120/attention-is-all-you-need-pytorch

d_k not equal to d_v gives issues

luffycodes opened this issue · 0 comments

Hey, if I set d_k=8 and d_v=64, it creates issues. The following is the bug trace.

```
Traceback (most recent call last):
  File "/train.py", line 338, in <module>
    main()
  File "/train.py", line 273, in main
    train(transformer, training_data, validation_data, optimizer, device, opt)
  File "/train.py", line 160, in train
    model, training_data, optimizer, opt, device, smoothing=opt.label_smoothing)
  File "/train.py", line 84, in train_epoch
    pred = model(src_seq, trg_seq)
  File "/home/zoro/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/transformer/Models.py", line 169, in forward
    enc_output, *_ = self.encoder(src_seq, src_mask)
  File "/home/zoro/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/transformer/Models.py", line 75, in forward
    enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
  File "/home/zoro/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/transformer/Layers.py", line 20, in forward
    enc_input, enc_input, enc_input, mask=slf_attn_mask)
  File "/home/zoro/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/transformer/SubLayers.py", line 54, in forward
    q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
  File "/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/miniconda3/envs/vaswani/lib/python3.6/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [4608 x 512], m2: [64 x 512] at /tmp/pip-req-build-ocx5vxk7/aten/src/THC/generic/THCTensorMathBlas.cu:290
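For reference, `d_k != d_v` is legal in multi-head attention as long as the projection shapes are kept consistent: q and k must share `d_k` (they are dotted together), while v and the output projection use `d_v`. Below is a minimal sketch of projection shapes that support `d_k != d_v` (my own illustrative module with a placeholder in place of the actual attention step, not this repo's code; the class name `MultiHeadProjections` is hypothetical):

```python
import torch
import torch.nn as nn

class MultiHeadProjections(nn.Module):
    """Hypothetical sketch showing projection shapes when d_k != d_v."""

    def __init__(self, n_head, d_model, d_k, d_v):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v
        # q and k project to n_head * d_k (their dot product requires equal dims).
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        # v projects to n_head * d_v, which may differ from n_head * d_k.
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        # The output projection maps the concatenated heads (n_head * d_v) back to d_model.
        self.fc = nn.Linear(n_head * d_v, d_model)

    def forward(self, q, k, v):
        sz_b, len_q = q.size(0), q.size(1)
        q = self.w_qs(q).view(sz_b, len_q, self.n_head, self.d_k)
        k = self.w_ks(k).view(sz_b, k.size(1), self.n_head, self.d_k)
        v = self.w_vs(v).view(sz_b, v.size(1), self.n_head, self.d_v)
        # ... scaled dot-product attention over the heads would go here ...
        attn_out = v  # placeholder: the attention output carries d_v per head
        return self.fc(attn_out.reshape(sz_b, len_q, self.n_head * self.d_v))
```

With `n_head=8, d_model=512, d_k=8, d_v=64`, a `(batch, seq, 512)` input passes through and comes back out as `(batch, seq, 512)`, so the mismatch in the trace suggests one of the linear layers in the repo is not shaped this way for the non-default `d_k`.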