RuntimeError: expected scalar type Long but found Float
Closed this issue · 2 comments
bakszero commented
After training the model successfully, I ran the translate command and got the following runtime error:
  File "translate.py", line 36, in <module>
    main(opt)
  File "translate.py", line 21, in main
    attn_debug=opt.attn_debug
  File "/mnt/disks/disk-huge/bakhtiyar/ITDD/onmt/translate/translator.py", line 205, in translate
    batch, data, attn_debug, fast=self.fast
  File "/mnt/disks/disk-huge/bakhtiyar/ITDD/onmt/translate/translator.py", line 309, in translate_batch
    return self._translate_batch(batch, data)
  File "/mnt/disks/disk-huge/bakhtiyar/ITDD/onmt/translate/translator.py", line 706, in _translate_batch
    beam_attn.data[j, :, :memory_lengths[j]])
  File "/mnt/disks/disk-huge/bakhtiyar/ITDD/onmt/translate/beam.py", line 139, in advance
    self.attn.append(attn_out.index_select(0, prev_k))
RuntimeError: expected scalar type Long but found Float
I'm using PyTorch version 1.7.1 and torchtext 0.4.0 on Ubuntu 18.04.
I think prev_k holds float values in this case, which causes the error. If I change that line to just self.attn.append(attn_out), I then get the same error further on, at ITDD/onmt/translate/translator.py", line 712, in <lambda>:
    lambda state, dim: state.index_select(dim, select_indices))
RuntimeError: expected scalar type Long but found Float
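The failure mode described above can be reproduced in isolation. The sketch below is only an illustration and not the ITDD code: the tensor values and the helper as_long_index are hypothetical, and it assumes the index tensor really does arrive as floating point. One possible reason, not confirmed in this thread, is that newer PyTorch releases perform true division when `/` is applied to integer tensors, so an index computed that way is no longer Long.

```python
import torch


def as_long_index(index):
    """Hypothetical helper: return `index` as an int64 tensor.

    Floating-point values are floored first so that a value such as 2.0
    (or 2.6 produced by a true division) maps to the integer index 2.
    """
    if index.is_floating_point():
        index = torch.floor(index)
    return index.long()


# Stand-in for attention weights: 4 beams, 3 source positions each.
attn_out = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Assumed situation from the report: the index tensor holds floats.
prev_k = torch.tensor([1.0, 0.0, 3.0])

# index_select only accepts integer index tensors, so this call fails.
try:
    attn_out.index_select(0, prev_k)
except (RuntimeError, IndexError, TypeError) as err:
    print("float index rejected:", err)

# Converting the index to int64 first makes the same call succeed.
selected = attn_out.index_select(0, as_long_index(prev_k))
print(selected)
```

If the assumption holds, the conversion would be needed wherever such an index is produced rather than at a single call site, which would explain why removing one index_select call only moves the error to the next one.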
Please could you help resolve the issue? Thanks!
lizekang commented
You can try torch 1.0.0. We used the version from two years ago.
bakszero commented
Looks like it, retraining with 1.0.0 now. Thanks 👍