ndif-team/nnsight

Cannot use proxy values from previous runs with `remote=True`


Description

I noticed a discrepancy between `remote=False` and `remote=True` behavior when reusing a proxy from a previous run.

With `remote=False`, I can reuse a proxy that was saved in a previous tracer/generator context, but this is not possible with `remote=True`.

Expected Behavior

The behavior should be the same whether remote is True or False.

Reproduction Steps

from nnsight import CONFIG, LanguageModel

CONFIG.set_default_api_key("<your-api-key>")
model = LanguageModel('openai-community/gpt2-xl')

def run(remote: bool):    
    with model.generate("a dog is a dog, a cat is", max_new_tokens=4, remote=remote):
        embedding = model.transformer.wte.output.save()
        output1 = model.generator.output.save()
    
    print(embedding.shape)
    print("All token ids: ", output1)
    print("All prediction: ", model.tokenizer.batch_decode(output1))
    
    # Build a placeholder prompt with the same number of tokens, so the saved
    # embeddings from the first run can be patched in at wte.output.
    tokens_cnt = embedding.shape[1]
    stub_prompt = " ".join("_" * tokens_cnt)
    with model.generate(stub_prompt, max_new_tokens=4, remote=remote):
        model.transformer.wte.output = embedding
        output2 = model.generator.output.save()

    print("All token ids: ", output2)
    print("All prediction: ", model.tokenizer.batch_decode(output2))

`run(False)` output:

torch.Size([1, 9, 1600])
All token ids:  tensor([[  64, 3290,  318,  257, 3290,   11,  257, 3797,  318,  257, 3797,   11,
          290]])
All prediction:  ['a dog is a dog, a cat is a cat, and']

All token ids:  tensor([[  62, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808,  257, 3797,   11,
          290]])
All prediction:  ['_ _ _ _ _ _ _ _ _ a cat, and']

`run(True)` output:

torch.Size([1, 9, 1600])
All token ids:  tensor([[  64, 3290,  318,  257, 3290,   11,  257, 3797,  318,  257, 3797,   11,
          290]])
All prediction:  ['a dog is a dog, a cat is a cat, and']

All token ids:  tensor([[  62, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808, 4808,
         4808]])
All prediction:  ['_ _ _ _ _ _ _ _ _ _ _ _ _']
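
In the meantime, a possible workaround might be to materialize the saved embedding as a plain tensor before the second remote run, so the value is sent to the backend as data rather than as a reference to a proxy from an earlier job. This is only a minimal sketch reusing `model` from the snippet above; it assumes the saved proxy resolves to a local torch.Tensor once the first context exits (the `embedding.shape` access above relies on the same behavior), and I have not verified it against the remote backend:

# First remote run: save the embedding proxy as in the reproduction above.
with model.generate("a dog is a dog, a cat is", max_new_tokens=4, remote=True):
    embedding = model.transformer.wte.output.save()

# Assumed: after the context exits, the saved proxy behaves like a local
# torch.Tensor, so detach/clone it into a concrete tensor that can be
# serialized with the next request.
embedding_tensor = embedding.detach().clone()

tokens_cnt = embedding_tensor.shape[1]
stub_prompt = " ".join("_" * tokens_cnt)

# Second remote run: assign the concrete tensor instead of the old proxy.
with model.generate(stub_prompt, max_new_tokens=4, remote=True):
    model.transformer.wte.output = embedding_tensor
    output2 = model.generator.output.save()

print("All prediction: ", model.tokenizer.batch_decode(output2))

Whether a plain tensor assigned to the module output is serialized correctly for remote execution is itself an assumption; if it is not, the discrepancy above remains the underlying issue.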

I tried to fix this with #105. Please let me know if that seems reasonable.