ndif-team/nnsight

Cannot load TransformerLens model on CPU

Closed this issue · 2 comments

The `device` argument is just ignored:

```python
self, model: str, device: str, *args, processing: bool = True, **kwargs
```

The easiest way to fix this would be to remove the explicit `device` argument so that it gets passed through to TransformerLens via `kwargs`.
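A minimal sketch of what that change would do (hypothetical function names, not nnsight's actual code): with an explicit `device` parameter, the value is captured by the wrapper and never reaches the downstream call; dropping the parameter lets it flow through `**kwargs` instead.

```python
def load_model_current(model: str, device: str, *args, processing: bool = True, **kwargs):
    # Current behaviour: `device` is bound to the explicit parameter here,
    # so it never appears in the kwargs forwarded to TransformerLens.
    return {"model": model, "forwarded_kwargs": kwargs}

def load_model_proposed(model: str, *args, processing: bool = True, **kwargs):
    # Proposed behaviour: with no explicit `device` parameter, a caller's
    # device=... lands in **kwargs and can be forwarded as-is.
    return {"model": model, "forwarded_kwargs": kwargs}

print(load_model_current("gpt2-small", device="cpu"))   # device is swallowed
print(load_model_proposed("gpt2-small", device="cpu"))  # device is forwarded
```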

EDIT: The issue below was fixed upstream in transformer_lens.

Also, you can't run the model once it has been moved to the CPU, but I think we can fix that upstream:

```python
from nnsight.models.UnifiedTransformer import UnifiedTransformer

nn_model = UnifiedTransformer("gpt2-small").to("cpu")
with nn_model.trace("a"):
    pass
```

```
RuntimeError: Unhandled FakeTensor Device Propagation for aten.bmm.default, found two different devices cpu:0, cpu
```