PythonOT/POT

ot.emd2 outputs on CPU


Hi there,

I am using POT==0.9.4. In the following code:

M = ot.dist(a, b)  # where a and b are torch tensors on the GPU
loss = ot.emd2(ab, ab, M)
loss.backward()

M is on the GPU but loss is not, which triggers the following error:

RuntimeError                              Traceback (most recent call last)
File <timed exec>:16

File ~/research/ai-env/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/research/ai-env/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: Function ValFunctionBackward returned an invalid gradient at index 3 - expected device cuda:0 but got cpu

Can anyone please advise?
Thanks
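
For reference, a minimal self-contained sketch that should reproduce this setup (the sample size, the random point clouds, and the uniform weight vector ab are illustrative assumptions, not from the original report; it assumes a CUDA device is available):

import torch
import ot

device = torch.device("cuda")

n = 64
a = torch.randn(n, 2, device=device, requires_grad=True)  # source samples on the GPU
b = torch.randn(n, 2, device=device)                      # target samples on the GPU

ab = torch.ones(n) / n  # uniform marginal weights; note they are created on the CPU

M = ot.dist(a, b)          # cost matrix, on the GPU
loss = ot.emd2(ab, ab, M)  # loss comes back on the CPU here
loss.backward()            # raises the RuntimeError above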

Thanks for the issue. This should indeed not be the case.

Could you please give us a short, complete example where it fails (with the to(device) calls and everything)? We can then run it on our end and try to debug.

@rflamary thanks for the reply. My bad: I did not check whether the tensor ab was on the GPU, which is what caused the code to crash. Things are working properly now. Thanks again for this nice library; it is quite powerful! Loving it!
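
For anyone hitting the same error: the fix is simply to create the marginal weights on the same device as the samples and the cost matrix. A minimal sketch (the variable names and sizes are illustrative):

import torch
import ot

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n = 64
a = torch.randn(n, 2, device=device, requires_grad=True)  # source samples
b = torch.randn(n, 2, device=device)                      # target samples

ab = torch.ones(n, device=device) / n  # weights on the same device as a, b, and M

M = ot.dist(a, b)          # cost matrix on `device`
loss = ot.emd2(ab, ab, M)  # loss now follows the device of its inputs
loss.backward()            # gradients land on `device` as well
print(loss.device, a.grad.device)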