pytorch/tensordict

[BUG] Wrong values in TensorDict with device="cpu" specified

JCBrouwer opened this issue · 9 comments

Describe the bug

When creating a TensorDict with a device specified, sometimes the wrong values end up in the resulting TensorDict.

To Reproduce

The following code produces varying values for the first two prints, while the third print is always correct. The incorrect results are usually random-looking values, but sometimes I've also received all zeros.

import torch
from tensordict import TensorDict

B = 4

data = {"a": torch.randn((B, 2), device="cuda"), "b": torch.randn((B, 3), device="cuda")}

td = TensorDict(data, batch_size=B, device="cpu")
print(td["a"])  # often wrong: random-looking values or zeros

td = TensorDict(data, batch_size=B, device="cpu")
print(td["a"])  # often wrong as well

td = TensorDict(data, batch_size=B).to("cpu")
print(td["a"])  # always correct
print(data)

If we move the print of the data to the top, all three prints are always correct.

import torch
from tensordict import TensorDict

B = 4

data = {"a": torch.randn((B, 2), device="cuda"), "b": torch.randn((B, 3), device="cuda")}

print(data)

td = TensorDict(data, batch_size=B, device="cpu")
print(td["a"])

td = TensorDict(data, batch_size=B, device="cpu")
print(td["a"])

td = TensorDict(data, batch_size=B).to("cpu")
print(td["a"])

Expected behavior

The resulting TensorDict should contain the correct values from the original data dict, regardless of whether the device is specified at construction time or afterwards via .to(). It should also not matter whether the data is printed first!?

System info

Describe the characteristic of your environment:

  • tensordict==0.3.2 (from pip)
  • torch==2.2.2 (from pip)
  • Python 3.10.13

Reason and Possible fixes

I'm guessing this has something to do with GPU operations being asynchronous while CPU operations are synchronous?

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

I believe it's because we use non_blocking=True but don't sync after the TensorDict creation; I can fix that!
It won't happen the other way around (cpu -> cuda) since cuda syncs itself.
Thanks for reporting!

In the meantime you can call

torch.cuda.synchronize()

after instantiating the td!
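
For example, applied to the reproducer above (a sketch of the workaround; only the synchronize call is new):

import torch
from tensordict import TensorDict

B = 4
data = {"a": torch.randn((B, 2), device="cuda"), "b": torch.randn((B, 3), device="cuda")}

td = TensorDict(data, batch_size=B, device="cpu")
# Wait for the pending async cuda -> cpu copies to finish before reading.
torch.cuda.synchronize()
print(td["a"])  # now matches data["a"]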

Looking at it, this has already been solved: non_blocking is now an argument of the TensorDict constructor.
If you set it to True, it is your responsibility to call synchronize when needed.
Hope that makes sense!
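
Roughly, usage would look like this (a sketch; the non_blocking keyword and the default synchronization behavior are as described above and may depend on your tensordict version):

import torch
from tensordict import TensorDict

B = 4
data = {"a": torch.randn((B, 2), device="cuda"), "b": torch.randn((B, 3), device="cuda")}

# Default (non_blocking=False): tensordict takes care of synchronization.
td = TensorDict(data, batch_size=B, device="cpu")
print(td["a"])

# Opt-in async transfers: synchronizing before reading is your responsibility.
td_async = TensorDict(data, batch_size=B, device="cpu", non_blocking=True)
torch.cuda.synchronize()
print(td_async["a"])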

Is the use of non_blocking=True also your own responsibility with regular tensors? I was under the impression that torch would synchronize for you if it was necessary, but maybe I'm wrong.

I'd say it's pretty dangerous to have hidden responsibilities enabled by default. If indeed this is a user's own responsibility then the default setting should be non_blocking=False in my opinion.

Is the use of non_blocking=True also your own responsibility with regular tensors? I was under the impression that torch would synchronize for you if it was necessary, but maybe I'm wrong.

For cpu -> cuda, torch takes care of it for you; for anything else, no.
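
Illustrated with plain tensors (a sketch of the general rule, not specific to tensordict):

import torch

x_cpu = torch.randn(4, 2).pin_memory()

# cpu -> cuda: safe even with non_blocking=True, since subsequent cuda work
# on x_gpu is queued on the same stream, after the copy.
x_gpu = x_cpu.to("cuda", non_blocking=True)
y = x_gpu * 2  # fine, no manual sync needed

# cuda -> cpu: the copy is only queued on the gpu stream, so the cpu may read
# stale or uninitialized memory unless you synchronize first.
x_back = x_gpu.to("cpu", non_blocking=True)
torch.cuda.synchronize()  # required before trusting x_back's values
print(x_back)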

I'd say it's pretty dangerous to have hidden responsibilities enabled by default. If indeed this is a user's own responsibility then the default setting should be non_blocking=False in my opinion.

I agree. I can put a synchronize in the TensorDict; the only thing blocking me right now is that we should do it only at the root tensordict, not at every node, and I'm not sure how to achieve that without adding yet another keyword argument to __init__.

After further discussion with @albanD we agreed on the following plan:

  • tensordict will always use non_blocking=True during creation
  • if non_blocking=True is passed by the user, we don't call synchronize at the end
  • if non_blocking=False (the default), we do call synchronize when the destination device is cpu / mps or any other non-cuda device

Your calls will be fixed, but you will have the same error if non_blocking is True.

Sounds like a well thought-out resolution. Thanks a lot!

After further discussion with @albanD we agreed on the following plan:

  • tensordict will always use non_blocking=True during creation
  • if non_blocking=True is passed by the user, we don't call synchronize at the end
  • if non_blocking=False (the default), we do call synchronize when the destination device is cpu / mps or any other non-cuda device

Your calls will be fixed, but you will have the same error if non_blocking is True.

Double checking: did you mean to say tensordict will always use non_blocking=False during creation?

I meant that under the hood we'll always use non_blocking=True.
From a UX perspective, the only difference when calling TensorDict(..., non_blocking=True) will be that a call to synchronize will not be made after sending the tensors asynchronously to the required device. If False, we call synchronize on cuda (thereby synchronizing all cuda devices) whenever cuda is available, and likewise for MPS.

Using non_blocking=True will be useful if you want to keep control over which stream is synchronized. For instance, you may have some tensors on cuda:0 that you want to bring to cpu while cuda:1 is busy doing something else. In that case, non_blocking=False will dispatch the cuda:0 -> cpu copies asynchronously but then call a synchronize over both cuda:0 and cuda:1, which can be annoying. With non_blocking=True, you can avoid that and synchronize only the first device yourself.
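
A sketch of that scenario (assuming two visible cuda devices and the non_blocking constructor argument discussed above):

import torch
from tensordict import TensorDict

B = 4
data = {"a": torch.randn((B, 2), device="cuda:0"), "b": torch.randn((B, 3), device="cuda:0")}

# ... long-running kernels have already been queued on cuda:1 ...

# Async cuda:0 -> cpu transfer; synchronize only the device that did the copy,
# so the work on cuda:1 is not waited on.
td = TensorDict(data, batch_size=B, device="cpu", non_blocking=True)
torch.cuda.synchronize(device="cuda:0")
print(td["a"])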