Question
Is the state supposed to be detached here?
Inside block_recurrent_transformer_pytorch.py, line 815:

if exists(layer_next_states):
    next_states.append(layer_next_states.detach())
How would the gradients flow through the states?
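To make the concern concrete, here is a toy example of my own (not the library's code) showing that a detached state carries no gradient back to the steps that produced it:

```python
import torch

# toy recurrent cell: h_t = tanh(W @ h_{t-1} + x_t)
W = torch.randn(4, 4, requires_grad = True)
h = torch.zeros(4)

for t in range(3):
    x = torch.randn(4)
    # detaching the carried state means the gradient from this step
    # stops here and never reaches the steps that produced h
    h = torch.tanh(W @ h.detach() + x)

h.sum().backward()
# W.grad reflects only the last step; nothing flowed back through the carried state
print(W.grad)
```

So with the detach, the loss at a later block cannot influence how the state was computed at earlier blocks, which is why I'm asking whether that is intended.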
@YHL04 if they were doing TBPTT (truncated backpropagation through time), what would their cutoff number of steps be? i don't see that in the paper anywhere
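(by cutoff i mean how many steps gradients are allowed to flow back before the state is detached — schematically something like the following, where the RNNCell and k = 4 are just stand-ins, not anything from the paper or this repo:)

```python
import torch
from torch import nn

k = 4                          # hypothetical cutoff, not a number from the paper
cell = nn.RNNCell(8, 8)
h = torch.zeros(1, 8)

for t, x in enumerate(torch.randn(12, 1, 8)):
    h = cell(x, h)
    if (t + 1) % k == 0:
        h = h.detach()         # truncate: gradient flow is cut every k steps
```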
@YHL04 maybe it will be faster to just email the first author 😄
@lucidrains Pls let me know what they say
@YHL04 hey! i sent the author an email, and verified that the cache is not differentiable, so the detach is supposed to be there.
however, during the exchange, i realized the source of the confusion: i was updating the state only once per segment, attending to the entire input at once, rather than one block at a time. can you let me know if the new code changes (sketched roughly below) look more reasonable?
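roughly the shape of the change, as a toy sketch (Linear stands in for attention here, and all names and shapes are made up rather than taken from the actual diff):

```python
import torch
from torch import nn

dim, block_size = 8, 4
to_out   = nn.Linear(dim * 2, dim)      # stand-in for attending over (tokens, state)
to_state = nn.Linear(dim * 2, dim)      # stand-in for producing the next state

segment = torch.randn(1, 16, dim)       # (batch, seq, dim) = 4 blocks of 4 tokens
state   = torch.zeros(1, dim)

outputs = []
# new behavior: walk the segment one block at a time, updating the state per block
# (previously the state was updated only once, after attending to the whole segment)
for block in segment.split(block_size, dim = 1):
    expanded = state.unsqueeze(1).expand(-1, block.shape[1], -1)
    inp = torch.cat((block, expanded), dim = -1)
    outputs.append(to_out(inp))
    # the next state summarizes this block; the detach keeps the cache non-differentiable
    state = to_state(inp).mean(dim = 1).detach()

out = torch.cat(outputs, dim = 1)       # same length as the input segment
```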
@YHL04 just confirmed with the author that the new changes are correct
thank you again for pressing on this!