ndif-team/nnsight

NNsight should raise an error when an unset proxy (from a previous trace, or from a computation that happens later in the same forward pass) is used in a patching experiment

Opened this issue · 0 comments

This kind of silent failure can make nnsight very hard to debug:

# Minimal reproduction: three attempts to patch the output of transformer
# block 0 of GPT-2 while tracing prompt "b". Two of them are reported to have
# no effect on the logits without raising any error.
import torch as th
from nnsight import LanguageModel
nn_model = LanguageModel("gpt2", device_map="cpu")

# Case 1 (silent failure): `hidden` is taken from the first trace without
# `.save()`, so it is not set by the time it is assigned in the second trace.
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits = nn_model.lm_head.output.save()

# Case 2 (works): same as case 1, except that `hidden` is saved with
# `.save()` in the first trace before being used in the second one.
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output.save()
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits2 = nn_model.lm_head.output.save()

# Case 3 (silent failure): within a single trace, h[10].output has not been
# computed yet at the point where h[0].output is computed, so the value
# assigned to h[0].output is not set.
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = nn_model.transformer.h[10].output
    corrupted_logits3 = nn_model.lm_head.output.save()

# Baseline: unpatched logits for prompt "b", used as the reference below.
with nn_model.trace("b"):
    clean_logits = nn_model.lm_head.output.save()
# A patch that took effect should make the logits differ from the baseline.
# As reported in the assert messages, only case 2 passes; cases 3 and 1 leave
# the logits equal to the clean ones, i.e. the patch was silently ignored.
assert not th.allclose(clean_logits, corrupted_logits2), "this assert pass"
assert not th.allclose(clean_logits, corrupted_logits3), "this assert fails"
assert not th.allclose(clean_logits, corrupted_logits), "this assert fails"