lucidrains/DALLE-pytorch

padding in PreShiftToken

sklin93 opened this issue · 2 comments

padding = seq_len - n + 1

It seems this line assumes n > text_len: it computes padding = img_seq_len - (n - text_len) = seq_len - n + 1. When n <= text_len (which can happen when using the dalle model's generate_text), the padding comes out larger than img_seq_len and gives rise to shape errors.
Should it be padding = img_seq_len - max(n - text_len, 0) instead?
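A minimal sketch of the difference, assuming the relation seq_len = text_len + img_seq_len - 1 that makes the two expressions agree for n > text_len (the names and the concrete lengths here are illustrative, not the exact values in dalle_pytorch):

```python
text_len = 256       # illustrative text sequence length
img_seq_len = 1024   # illustrative image sequence length
seq_len = text_len + img_seq_len - 1

def padding_current(n):
    # current expression: only valid when n > text_len
    return seq_len - n + 1

def padding_proposed(n):
    # proposed fix: clamp so padding never exceeds img_seq_len
    return img_seq_len - max(n - text_len, 0)

# the two agree once n is past the text portion
print(padding_current(300), padding_proposed(300))  # 980 980

# but during generate_text, n can still be inside the text portion
n = 10
print(padding_current(n))   # 1270 -> exceeds img_seq_len, shape error
print(padding_proposed(n))  # 1024 -> clamped to full image length
```

With the clamp, padding stays in [0, img_seq_len] regardless of where n falls in the sequence.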

oh shoot, someone else actually contributed the generate_text functionality, so i didn't account for that

do you want to try a pull request? :)

Sure! I created a pull request #406