shabie/docformer

Again Device issue

kmr2017 opened this issue · 3 comments

I am trying to run the code, but I face a problem when I execute the line below:
output = docformer(v_bar, t_bar, v_bar_s, t_bar_s) # shape (1, 512, 768)

I get this error.

/usr/local/lib/python3.7/dist-packages/torch/functional.py in einsum(*args)
328 return einsum(equation, *_operands)
329
--> 330 return _VF.einsum(equation, operands) # type: ignore[attr-defined]
331
332 # Wrapper around _histogramdd and _histogramdd_bin_edges needed due to (Tensor, Tensor[]) return type
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_bmm)

This time, the error is raised in wrapper_bmm.
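One common cause of this error is sketched below. The sketch assumes the culprit is a tensor stored as a plain module attribute, which is not confirmed for DocFormer. `nn.Module.to(device)` moves parameters and registered buffers, but it ignores plain tensor attributes, so those stay on the CPU while everything else moves to `cuda:0`. The class and helper names (`PlainAttrScale`, `BufferScale`, `tracked_tensor_names`) are hypothetical, for illustration only.

```python
import torch
from torch import nn


class PlainAttrScale(nn.Module):
    """Hypothetical module: scale stored as a plain tensor attribute."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Plain attribute: not tracked by nn.Module, so .to(device) ignores it.
        self.scale = torch.sqrt(torch.FloatTensor([embed_dim]))


class BufferScale(nn.Module):
    """Hypothetical module: scale registered as a buffer."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Registered buffer: moved together with the module by .to(device).
        self.register_buffer("scale", torch.sqrt(torch.tensor([float(embed_dim)])))


def tracked_tensor_names(module: nn.Module) -> list:
    """Names of the tensors that module.to(device) would move."""
    params = [name for name, _ in module.named_parameters()]
    buffers = [name for name, _ in module.named_buffers()]
    return params + buffers


device = "cuda" if torch.cuda.is_available() else "cpu"
plain = PlainAttrScale(768).to(device)
buffered = BufferScale(768).to(device)
print(tracked_tensor_names(plain))     # []
print(tracked_tensor_names(buffered))  # ['scale']
# On a GPU machine, plain.scale stays on cpu while buffered.scale is on cuda:0.
print(plain.scale.device, buffered.scale.device)
```

The same reasoning applies to the inputs: v_bar, t_bar, v_bar_s and t_bar_s are not moved by the model's .to(device) call and need to be placed on the model's device explicitly.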

I think this line may be responsible for the issue:

self.scale = torch.sqrt(torch.FloatTensor([embed_dim]))

The lines below it also have problems with .to(device), especially when the device is CUDA. I will modify them shortly and let you know if the problem persists. If possible, could you try changing the above line from:
self.scale = torch.sqrt(torch.FloatTensor([embed_dim])) to self.scale = embed_dim**0.5, remove all the .to(device) calls in the lines below it, and let us know the result?
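The suggested change can be sketched as follows. This is a minimal single-head attention written with torch.einsum, assuming inputs shaped (batch, seq, dim); `ScaledAttention` is a hypothetical class, not DocFormer's actual attention module. Because `self.scale` is a plain Python float, it has no device at all, so the only tensors passed to einsum are ones that already share a device. If a helper tensor is ever needed inside forward, creating it with `device=x.device` avoids hard-coded `.to(device)` calls.

```python
import torch
from torch import nn


class ScaledAttention(nn.Module):
    """Hypothetical single-head attention with a device-agnostic scale."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        # Python float instead of torch.FloatTensor: nothing to move between devices.
        self.scale = embed_dim ** 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q(x), self.k(x), self.v(x)
        # (batch, seq, dim) x (batch, seq, dim) -> (batch, seq, seq)
        scores = torch.einsum("bid,bjd->bij", q, k) / self.scale
        attn = scores.softmax(dim=-1)
        # (batch, seq, seq) x (batch, seq, dim) -> (batch, seq, dim)
        return torch.einsum("bij,bjd->bid", attn, v)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ScaledAttention(768).to(device)
# Create the input on the same device as the model's parameters.
x = torch.randn(1, 512, 768, device=device)
out = model(x)
print(out.shape)  # torch.Size([1, 512, 768])
```

Dividing a tensor by a Python float works on any device, which is why this version needs no .to(device) for the scale.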

I will try from my end as well and let you know. Thanks for pointing out this issue.

This may not be the cause, and there could be something else, but I will let you know soon. Also, could you share the whole code? The single line mentioned above is not enough for me to reproduce the bug.

Could you let me know whether the issue has been resolved? If not, could you help me reproduce the error on Google Colab? Then I can try to fix the bug and update this repo accordingly.

Regards,
Akarsh

Hi @kmr2017, sorry for the late reply. I ran into this issue just now, and I think I have managed to solve it. You can clone https://github.com/uakarsh/docformer, which should fix it (as far as I know). Please let me know whether that solves the issue. I will include the update in the main branch shortly.

Regards,