Parallel forward
neverix opened this issue · 4 comments
The model's decoder currently supports only sequential decoding because of the way attn_state is implemented. A parallel forward pass could be implemented by setting attn_state to None and handling all cases inside the generation code.
This would help solve #58
I'm not sure what you mean. Are you saying parallel forward over the 256 image tokens? That wouldn't work because each token depends on the previous token. And if you meant parallel over the layers, that wouldn't work either, since each layer depends on the previous layer's output. Maybe you meant parallel backward?
Right now the code can't do a single forward pass over all tokens because of the caching implementation. It has to step through the tokens one at a time instead of processing the whole sequence at once and masking the attention.
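To illustrate the difference, here is a minimal NumPy sketch, not the repository's actual code. It compares a cached, token-by-token pass with a single masked pass over a toy single-head attention layer. The function names, the weight matrices, and the shapes are all hypothetical; the Python list cache only stands in for what attn_state does.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; masked (-inf) entries become exactly 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend_sequential(x, wq, wk, wv):
    """Process one token per step, appending keys/values to a cache."""
    d = wq.shape[1]
    keys, values, outputs = [], [], []
    for t in range(x.shape[0]):
        q = x[t] @ wq
        keys.append(x[t] @ wk)      # the cache grows by one entry per step
        values.append(x[t] @ wv)
        k = np.stack(keys)          # shape (t + 1, d)
        v = np.stack(values)
        weights = softmax(q @ k.T / np.sqrt(d))
        outputs.append(weights @ v)
    return np.stack(outputs)

def attend_parallel(x, wq, wk, wv):
    """Process all tokens at once, using a causal mask instead of a cache."""
    n, d = x.shape[0], wq.shape[1]
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(d)
    causal = np.tril(np.ones((n, n), dtype=bool))  # token t sees tokens <= t
    scores = np.where(causal, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))  # 8 toy tokens with 16 features each
wq, wk, wv = (rng.normal(size=(16, 16)) for _ in range(3))

seq_out = attend_sequential(x, wq, wk, wv)
par_out = attend_parallel(x, wq, wk, wv)
print(np.allclose(seq_out, par_out))
```

Both paths compute the same outputs when the full token sequence is known in advance. Sampling still needs the step-by-step path, because each new token is only available after the previous step has finished.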
Oh I see, this would be for when you want to do a forward pass over all tokens at once, instead of sampling one after the other.