syncdoth/RetNet

somewhere that needs to be modified

liujuncn opened this issue · 1 comment

def get_parallel_decay_mask(self, length):

The tensor created in the function above is placed on the CPU by default. When training on CUDA, this raises an error like the one below:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

You are correct. The recent commit (48a3984) should solve this issue. Thanks!
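
For readers hitting the same error, here is a minimal sketch of the usual kind of fix: build the index tensor on the same device as a tensor the module already owns. It is not necessarily what the commit above does. The class name `DecayMaskSketch`, the `decay` buffer, and the decay values are assumptions made up for illustration; only the method name `get_parallel_decay_mask(self, length)` comes from this issue.

```python
import torch
from torch import nn


class DecayMaskSketch(nn.Module):
    """Hypothetical stand-in for the module that owns get_parallel_decay_mask."""

    def __init__(self, num_heads: int = 4):
        super().__init__()
        # Assumed per-head decay rates. Registering them as a buffer means
        # .to(device) / .cuda() moves them together with the module.
        decay = 1 - 2.0 ** (-5 - torch.arange(num_heads, dtype=torch.float32))
        self.register_buffer("decay", decay)

    def get_parallel_decay_mask(self, length: int) -> torch.Tensor:
        # Create the index tensor on the decay buffer's device instead of
        # relying on the CPU default.
        idx = torch.arange(length, device=self.decay.device)
        # exponent[n, m] = n - m; clamp so the upper triangle stays at 0
        # before it is masked out.
        exponent = (idx[:, None] - idx[None, :]).clamp(min=0)
        # Broadcast (heads, 1, 1) against (length, length).
        mask = self.decay.view(-1, 1, 1) ** exponent
        # Keep only the lower triangle (causal part) of each head's mask.
        return torch.tril(mask)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    module = DecayMaskSketch().to(device)
    mask = module.get_parallel_decay_mask(8)
    # The mask follows the module's device, so no cuda/cpu mismatch occurs.
    print(mask.shape, mask.device)
```

The point of the design is that the method never names a device itself. It follows whatever device the module was moved to, so the same code runs on CPU and on CUDA.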