padding in PreShiftToken
sklin93 opened this issue · 2 comments
sklin93 commented
DALLE-pytorch/dalle_pytorch/transformer.py
Line 104 in 459c46a
It seems the code here assumes n > text_len, using padding = img_seq_len - (n - text_len) = seq_len - n + 1, which gives rise to shape errors when using the DALL-E model's generate_text (where the sequence is still text-only, so n <= text_len).
Should it be padding = img_seq_len - max(n - text_len, 0) instead?
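A minimal sketch of the arithmetic (the concrete text_len and img_seq_len values below are made up for illustration, not taken from the repo's config):

```python
# Hypothetical sequence lengths, chosen only to illustrate the issue.
text_len = 256
img_seq_len = 1024

def padding_current(n):
    # Current expression: implicitly assumes n > text_len.
    return img_seq_len - (n - text_len)

def padding_proposed(n):
    # Proposed fix: clamp at zero, so a text-only sequence
    # pads the full image sequence length.
    return img_seq_len - max(n - text_len, 0)

# During generate_text the sequence is still text-only, so n <= text_len:
n = 10
print(padding_current(n))   # 1270, larger than img_seq_len -> shape mismatch
print(padding_proposed(n))  # 1024, the full image sequence length

# Once image tokens are being generated (n > text_len), both agree:
n = 300
print(padding_current(n), padding_proposed(n))  # 980 980
```

With the clamp, the padding is identical to the current behaviour whenever n > text_len and only changes the text-only case that generate_text hits.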
lucidrains commented
oh shoot, someone else actually contributed the generate_text functionality, so i didn't account for that
do you want to try a pull request? :)