ndif-team/nnsight

AttributeError: '_thread._local' object has no attribute 'device'

atlaie opened this issue · 2 comments

Hi,

first of all thanks for the hard work of putting this library out!

I wanted to try it out and I'm struggling with the minimal example that I share below:

from collections import OrderedDict
import torch
from nnsight import NNsight

input_size = 5
hidden_dims = 10
output_size = 2

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

device = 'mps'
torch.set_default_device(device)
input = torch.rand((1, input_size))#.to(device)
print(input.device)

model = NNsight(net)


with model.trace(input):
    # Save the output before the edit to compare. Notice we apply .clone() before saving as the setting operation is in-place.
    l1_output_before = model.layer1.output.clone().save()
    print("Before:", l1_output_before)
    print(l1_output_before.device)
    
    # Access the last index of the hidden state dimension and set it to 0.
    model.layer1.output[:, hidden_dims - 1] = 0
    print('This line was read!')
    
    l1_output_after = model.layer1.output.save()

print("Before:", l1_output_before)
print("After:", l1_output_after)

It outputs an error that I'm guessing has to do with my system being macOS:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 26
     21 print(input.device)
     23 model = NNsight(net)
---> 26 with model.trace(input):
     27     # Save the output before the edit to compare. Notice we apply .clone() before saving as the setting operation is in-place.
     28     l1_output_before = model.layer1.output.clone().save()
     29     print("Before:", l1_output_before)

File ~/opt/anaconda3/envs/GLMHMM/lib/python3.10/site-packages/nnsight/contexts/Runner.py:41, in Runner.__exit__(self, exc_type, exc_val, exc_tb)
     39 """On exit, run and generate using the model whether locally or on the server."""
     40 if isinstance(exc_val, BaseException):
---> 41     raise exc_val
     43 if self.remote:
     44     self.run_server()

Cell In[1], line 28
     23 model = NNsight(net)
     26 with model.trace(input):
     27     # Save the output before the edit to compare. Notice we apply .clone() before saving as the setting operation is in-place.
---> 28     l1_output_before = model.layer1.output.clone().save()
     29     print("Before:", l1_output_before)
     30     print(l1_output_before.device)
...
     68     # That means the tensors as args and the model are different devices but we dont want to have to have the users move tensors to 'meta'
     69     # So only when theres a FakeTensor with device meta, we move other tensors also to meta.
     71     def get_device(tensor: torch.Tensor):

AttributeError: '_thread._local' object has no attribute 'device'

I've tried forcing everything to 'mps' but changes only work outside of the trace context (and, thus, are overwritten when I call this context again):

with model.trace(input):
    # Note: unlike the first example, .clone() is not applied before saving here.
    l1_output_before = model.layer1.output.save()
print("Before:", l1_output_before)
# Access the last index of the hidden state dimension and set it to 0.
model.layer1.output[:, hidden_dims - 1] = 0
print("After 1:", l1_output_before)
with model.trace(input):
    # Save the output after to see our edit.
    l1_output_after = model.layer1.output.save()
print("After 2:", l1_output_after)

Before: tensor([[ 0.5571,  0.2771,  0.4906,  0.1573,  0.3036, -0.2014, -0.5157, -0.3143,
         -0.5238, -0.0052]])
After 1: tensor([[ 0.5571,  0.2771,  0.4906,  0.1573,  0.3036, -0.2014, -0.5157, -0.3143,
         -0.5238,  0.0000]])
After 2: tensor([[ 0.5571,  0.2771,  0.4906,  0.1573,  0.3036, -0.2014, -0.5157, -0.3143,
         -0.5238, -0.0052]])

Any suggestions?

Thanks!

Hey @atlaie! My guess is that your torch version is 2.3.x? If you go back to torch 2.2, this should work. I was using the torch global device context for a silly reason, and that breaks in 2.3. I'll have it working with torch 2.3 in the next update.

Here's where I was checking the global device (which breaks in 2.3):

And here's where I was setting it.

torch.set_default_device(device)
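
For anyone hitting the same error, here is a minimal sketch for checking whether the installed torch is in the 2.3.x series mentioned above. The helper name `is_torch_2_3` is hypothetical and not part of nnsight or torch. It only tests for 2.3.x, since that is the only series this thread identifies as affected.

```python
def is_torch_2_3(version_string: str) -> bool:
    """Return True if a torch version string belongs to the 2.3.x series.

    Hypothetical helper for illustration; not part of nnsight or torch.
    """
    # Drop a local build suffix such as "+cu121" before splitting.
    release = version_string.split("+", 1)[0]
    major, minor = release.split(".")[:2]
    return (int(major), int(minor)) == (2, 3)


try:
    import torch  # only needed to read the installed version

    installed = torch.__version__
except ImportError:
    installed = None

if installed is not None and is_torch_2_3(installed):
    print(f"torch {installed} is in the 2.3.x series reported as breaking in this issue.")
```

If the check prints the message, going back to torch 2.2 as suggested above (for example with `pip install "torch<2.3"`) is the workaround described in this thread until the update lands.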

It worked perfectly, thanks so much for the quick reply too!