Jamie-Stirling/RetNet

Q, K and D device difference

Closed · 1 comment

leffff commented

ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0)

Q and K end up on whatever device the model is moved to because they are model parameters, while D is created in SimpleRetention._get_D and is never moved to any device. So if you train on CUDA, Q and K are on the GPU while D stays on the CPU, and the multiplication above raises a device-mismatch error.
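One possible way to avoid the mismatch is to build D directly on the device of the tensors it gets multiplied with. This is only a minimal sketch, not necessarily the fix that was merged: `get_D` is a hypothetical stand-in for SimpleRetention._get_D, and the decay-mask definition (gamma^(n-m) for n >= m, 0 otherwise), the `device` argument and the tensor shapes are assumptions made for illustration.

```python
import torch


def get_D(sequence_length: int, gamma: float, device=None) -> torch.Tensor:
    """Hypothetical stand-in for SimpleRetention._get_D.

    Builds the (L, L) decay mask directly on `device`, so it never has to be
    moved afterwards. The mask definition is an assumption for this sketch.
    """
    n = torch.arange(sequence_length, device=device).unsqueeze(1)  # (L, 1)
    m = torch.arange(sequence_length, device=device).unsqueeze(0)  # (1, L)
    # Clamp the exponent so entries above the diagonal do not blow up
    # before they are zeroed out by the causal mask.
    exponent = (n - m).clamp(min=0).float()
    return (gamma ** exponent) * (n >= m).float()


# Illustrative shapes: (batch, sequence_length, hidden)
Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)

# Create D on the same device as Q, whatever that device is.
D = get_D(Q.shape[1], gamma=0.9, device=Q.device)

ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0)
```

Another option would be to register D as a buffer with `register_buffer`, so that `model.to(device)` moves it together with the parameters, or simply to call `D.to(Q.device)` right before the multiplication.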

Thanks for raising and fixing this.