mlverse/torch

torch_device errors in latest builds

Closed this issue · 4 comments

I am encountering

runtime_error("y is not a torch_device")

which is the error message from your `==.torch_device` method, using the latest development build of torch (fe44f6b) in a script that creates a torch device and works with torch <= 0.11.0.

Moreover I get the following:

> device <- torch_device("cuda:0")
Error in `value[[3L]]()`:
! "length" is not support for objects with class <torch_device/R7>
Run `rlang::last_trace()` to see where the error occurred.

which seems to be a regression caused by #1111.

Breaking this down after further investigation:

The "y is not a torch_device" error actually has to do with the use of safetensors. If I set `torch.serialization_version` to 2, things work as before. This points to the error occurring during loading of serialised tensors (model and optimiser) saved previously. I haven't had time to figure out exactly where it happens, but it seems that this is a breaking change for existing code.
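For reference, the workaround mentioned above can be sketched as below. This is a minimal sketch assuming `torch.serialization_version` is an R option read by `torch_save()` (the option name is taken from the report itself; the filename is illustrative):

```r
library(torch)

# Hedged workaround sketch: opt back into serialisation format version 2
# (the pre-safetensors rds-based format) so that existing loading code
# keeps working as before.
options(torch.serialization_version = 2)

net <- nn_linear(10, 1)
torch_save(net, "model.pt")   # saved with the legacy format
net2 <- torch_load("model.pt")
```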

The second "length" error only happens within RStudio, I'm guessing because it automatically tries to determine the object's length for the environment pane or something similar. More of a cosmetic issue.

Thanks, I see you've addressed the second point already.

I've pinpointed the first issue. It seems that the `device` argument of safetensors::safe_load() accepts only a character string, not a torch device already created with torch_device().

Specifically, this caused an issue when loading an optimiser state dict to the correct device, in the context of code like the following:

device <- torch_device("cuda:0")

net <- torch_load(modelfile)
net$to(device = device)$train()
optimiser <- optim_adam(net$parameters)
optimiser$load_state_dict(torch_load(optimfile, device = device))

This worked before; now the second torch_load() fails, although it works if torch_load(optimfile, device = "cuda:0") is specified instead.
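Until the fix lands, one way to keep passing a device object around is to coerce it to the character form that the loader accepts. This is a hedged sketch, assuming the device object exposes `$type` and `$index` fields (as recent torch releases do); `optimfile` is the file from the snippet above:

```r
library(torch)

# Hypothetical helper: coerce a torch_device (or a character string)
# to the "type" / "type:index" character form.
device_to_string <- function(device) {
  if (is.character(device)) return(device)
  if (is.null(device$index)) device$type
  else paste0(device$type, ":", device$index)
}

device <- torch_device("cuda:0")
device_to_string(device)   # "cuda:0"

# Then the failing call can be written as:
# optimiser$load_state_dict(torch_load(optimfile, device = device_to_string(device)))
```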

Thanks for the investigation @shikokuchuo! This was really helpful! I think #1122 will fix it.

Thanks, can confirm this fixes the issue.