Q, K and D on different devices
leffff commented
Line 45 in 2acf026
Q and K follow the model onto whatever device it is moved to, because they come from model parameters. D, however, is created in `SimpleRetention._get_D` and is never moved to any device. So if you train on CUDA, Q and K are on the GPU while D stays on the CPU, and combining them raises a device-mismatch error.
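One way to avoid the mismatch is to build D on the same device as the input. The snippet below is a minimal sketch, not the repository's actual code. It assumes a simplified single-head retention without xPos, and the class name `SimpleRetentionSketch`, the `W_Q`/`W_K`/`W_V` parameter names, and the extra `device` argument to `_get_D` are hypothetical. D follows the RetNet decay mask, D[n, m] = gamma^(n - m) for n >= m and 0 otherwise.

```python
import torch
import torch.nn as nn


class SimpleRetentionSketch(nn.Module):
    """Hypothetical, simplified stand-in for SimpleRetention (single head, no xPos)."""

    def __init__(self, hidden_size: int, gamma: float):
        super().__init__()
        self.gamma = gamma
        # nn.Parameter tensors are moved by model.to(device) / model.cuda().
        self.W_Q = nn.Parameter(torch.randn(hidden_size, hidden_size) / hidden_size)
        self.W_K = nn.Parameter(torch.randn(hidden_size, hidden_size) / hidden_size)
        self.W_V = nn.Parameter(torch.randn(hidden_size, hidden_size) / hidden_size)

    def _get_D(self, sequence_length: int, device: torch.device) -> torch.Tensor:
        # Create every intermediate tensor directly on `device` so D ends up
        # next to Q and K instead of defaulting to the CPU.
        n = torch.arange(sequence_length, device=device).unsqueeze(1)
        m = torch.arange(sequence_length, device=device).unsqueeze(0)
        # Clamp the exponent so entries above the diagonal cannot overflow;
        # the causal mask then zeroes them out.
        exponent = (n - m).clamp(min=0).float()
        return (n >= m).float() * (self.gamma ** exponent)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (batch, sequence_length, hidden_size)
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        D = self._get_D(X.shape[1], device=X.device)  # same device as Q and K
        retention = (Q @ K.transpose(-1, -2)) * D.unsqueeze(0)
        return retention @ V


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = SimpleRetentionSketch(hidden_size=8, gamma=0.9).to(device)
X = torch.randn(2, 5, 8, device=device)
out = layer(X)
print(out.shape, out.device)
```

Passing the device explicitly keeps D wherever the activations live, on CPU or GPU. An alternative is to precompute D once and store it with `register_buffer`, since `nn.Module.to()` moves buffers along with parameters. That only works if the maximum sequence length is known up front.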
Jamie-Stirling commented
Thanks for raising and fixing this.