ot.emd2 output is on CPU
hagasam commented
Hi there,
I am using POT==0.9.4, and for some reason in the following:
M = ot.dist(a, b) # where a and b are torch tensors on the GPU
loss = ot.emd2(ab, ab, M)
loss.backward()
M is on the GPU but loss is not, which triggers the following error:
RuntimeError Traceback (most recent call last)
File <timed exec>:16
File ~/research/ai-env/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
477 if has_torch_function_unary(self):
478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
(...)
485 inputs=inputs,
486 )
--> 487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File ~/research/ai-env/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: Function ValFunctionBackward returned an invalid gradient at index 3 - expected device cuda:0 but got cpu
Can anyone please advise?
Thanks
rflamary commented
Thanks for the issue. This should indeed not be the case.
Could you please give us a short, full example where it fails (with the to(device) calls and everything)? We can run it on our end and try to debug.
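For reference, a minimal self-contained reproducer along those lines might look like the sketch below. The sample shapes, the uniform weight vectors a and b, and the variable names are assumptions made for illustration; they are not taken from the original report.

import torch
import ot

device = torch.device("cuda:0")

# Two small point clouds on the GPU (shapes chosen arbitrarily for this sketch).
x = torch.randn(50, 3, device=device, requires_grad=True)
y = torch.randn(60, 3, device=device)

# Uniform histogram weights, also placed on the GPU (an assumption; the report
# passes an undefined variable "ab" here).
a = torch.full((50,), 1.0 / 50, device=device)
b = torch.full((60,), 1.0 / 60, device=device)

M = ot.dist(x, y)        # cost matrix, expected on cuda:0
loss = ot.emd2(a, b, M)  # exact OT loss; reported to come back on the CPU
print(M.device, loss.device)

loss.backward()          # reportedly fails with the device-mismatch error above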