lucidrains/block-recurrent-transformer-pytorch

Question

Closed this issue · 10 comments

YHL04 commented

Is it supposed to detach?

Inside block_recurrent_transformer_pytorch.py, line 815:

if exists(layer_next_states):
    next_states.append(layer_next_states.detach())

How would the gradients flow through the states?

[Screenshot attached: Screen Shot 2023-03-31 at 3.46.29 PM]
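For concreteness, here is a minimal sketch of what the detach does to the gradient graph. This is not the repo's actual forward pass and the names are illustrative: the carried state still feeds the current segment's computation, but no gradient flows back into the segments that produced it.

import torch

dim = 4
proj = torch.nn.Linear(dim, dim)

state = torch.zeros(1, dim)              # recurrent state carried across segments
for segment in torch.randn(3, 1, dim):
    out = proj(segment + state)          # the state participates in this segment
    out.sum().backward()                 # gradients reach `proj` only through the
                                         # current segment, because ...
    state = out.detach()                 # ... the carried state is cut from the graph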

lucidrains commented

yea, i had the same question, but i don't think they are propagating the gradients through the cache. would be interesting to port over some of the ideas from the memformer paper and try a differentiable cache though
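For contrast, a differentiable cache in the memformer spirit would simply skip the detach, so a loss computed at a later segment backpropagates through the states of all earlier segments. This is an illustrative sketch only, not the Memformer architecture:

import torch

dim = 4
cell = torch.nn.GRUCell(dim, dim)

state = torch.zeros(1, dim)
for segment in torch.randn(3, 1, dim):
    state = cell(segment, state)         # the state is never detached

state.sum().backward()                   # gradients flow through every segment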

lucidrains commented

@YHL04 if they were doing TBPTT, what is their cutoff number of steps? i don't see that in the paper anywhere
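For reference, if they were doing TBPTT, a cutoff would look something like the hypothetical sketch below: the state stays differentiable for k segments, then the graph is truncated. Neither the repo nor the paper specifies this; k is an assumed parameter.

import torch

dim, k = 4, 2                            # k = hypothetical TBPTT cutoff
cell = torch.nn.GRUCell(dim, dim)

state = torch.zeros(1, dim)
for step, segment in enumerate(torch.randn(6, 1, dim)):
    state = cell(segment, state)         # the state accumulates graph history
    if (step + 1) % k == 0:
        state.sum().backward()           # backprop through the last k segments ...
        state = state.detach()           # ... then truncate the graph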

lucidrains commented

@YHL04 the way i interpreted the N and W is as the max_seq_len and block_width in the code here

@YHL04 maybe it will be faster to just email the first author 😄
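Under that reading (N as max_seq_len, W as block_width), a segment of N tokens is consumed W tokens at a time, with the state carried from block to block. A rough sketch with illustrative names, not the repo's code:

import torch

max_seq_len, block_width, dim = 1024, 512, 8      # N, W, model width
tokens = torch.randn(1, max_seq_len, dim)

state = torch.zeros(1, dim)
for block in tokens.split(block_width, dim = 1):  # N // W blocks per segment
    state = state + block.mean(dim = 1)           # stand-in for the real state update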

YHL04 commented

@lucidrains Pls let me know what they say

lucidrains commented

@YHL04 hey! i sent the author an email, and verified that the cache is not differentiable, so the detach is supposed to be there.

however, during the exchange, i realized the source of the confusion is that i was updating the state only once, attending to the entire input segment, rather than one block at a time. can you let me know if the new code changes look more reasonable?
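To make the change concrete, a rough contrast of the two behaviors being discussed, illustrative only and not the actual attention code:

import torch

dim, block_width = 8, 4
segment = torch.randn(1, 12, dim)                 # one full input segment
update = torch.nn.Linear(dim, dim)                # stand-in for the state update

# previous behavior: a single state update attending over the whole segment
state = update(segment.mean(dim = 1))

# revised behavior: walk the segment one block at a time, updating the state each block
state = torch.zeros(1, dim)
for block in segment.split(block_width, dim = 1):
    state = update(block.mean(dim = 1) + state)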

@YHL04 just confirmed with the author that the new changes are correct

thank you again for pressing on this!

YHL04 commented

@lucidrains Sounds good, I will look into it and make some changes myself.