ndif-team/nnsight

Cannot pass a saved `past_key_values` cache to `trace` unless I use `.value`

Butanium opened this issue · 0 comments

# Minimal reproduction: reuse the `past_key_values` saved in one trace as the
# `past_key_values` keyword argument of a second trace.
context = "a b c"
end = " d"
# tokens = nn_model.to_tokens(context)
# First pass: trace the model on the context alone and save its past_key_values.
with nn_model.trace(context):
    cache = nn_model.output.past_key_values.save()
# Second pass: trace only the continuation, passing the saved cache as a kwarg.
with nn_model.trace(end, past_key_values=cache.value):  # Fails here unless `.value` is used; passing `cache` itself raises the ReferenceError in the stack trace
    output_cache = nn_model.output.logits.save()
# Reference pass: trace the full string without any cache.
with nn_model.trace(context+end):
    output_no_cache = nn_model.output.logits.save()
# The logits at the last position are expected to match with and without the cache.
assert th.allclose(output_cache[:, -1], output_no_cache[:, -1])
# Trailing expression: shows both shapes as the cell output (presumably run in a notebook).
output_no_cache.shape, output_cache.shape

Stack trace

---------------------------------------------------------------------------
ReferenceError                            Traceback (most recent call last)
[<ipython-input-49-262069780bf6>](https://localhost:8080/#) in <cell line: 6>()
      4 with nn_model.trace(context):
      5     cache = nn_model.output.past_key_values.save()
----> 6 with nn_model.trace(end, past_key_values=cache):
      7     output_cache = nn_model.output.logits.save()
      8 with nn_model.trace(context+end):

12 frames
[/usr/local/lib/python3.10/dist-packages/nnsight/models/NNsightModel.py](https://localhost:8080/#) in trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    198 
    199             # Otherwise open an invoker context with the give args.
--> 200             runner.invoke(*inputs, **invoker_args).__enter__()
    201 
    202         # If trace is False, you had to have provided an input.

[/usr/local/lib/python3.10/dist-packages/nnsight/contexts/Invoker.py](https://localhost:8080/#) in __enter__(self)
     69                     self.tracer._model._execute(
     70                         *copy.deepcopy(self.inputs),
---> 71                         **copy.deepcopy(self.tracer._kwargs),
     72                     )
     73 

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    144     copier = _deepcopy_dispatch.get(cls)
    145     if copier is not None:
--> 146         y = copier(x, memo)
    147     else:
    148         if issubclass(cls, type):

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in _deepcopy_dict(x, memo, deepcopy)
    229     memo[id(x)] = y
    230     for key, value in x.items():
--> 231         y[deepcopy(key, memo)] = deepcopy(value, memo)
    232     return y
    233 d[dict] = _deepcopy_dict

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    170                     y = x
    171                 else:
--> 172                     y = _reconstruct(x, memo, *rv)
    173 
    174     # If is its own copy, don't memoize.

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    269     if state is not None:
    270         if deep:
--> 271             state = deepcopy(state, memo)
    272         if hasattr(y, '__setstate__'):
    273             y.__setstate__(state)

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    144     copier = _deepcopy_dispatch.get(cls)
    145     if copier is not None:
--> 146         y = copier(x, memo)
    147     else:
    148         if issubclass(cls, type):

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in _deepcopy_dict(x, memo, deepcopy)
    229     memo[id(x)] = y
    230     for key, value in x.items():
--> 231         y[deepcopy(key, memo)] = deepcopy(value, memo)
    232     return y
    233 d[dict] = _deepcopy_dict

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    170                     y = x
    171                 else:
--> 172                     y = _reconstruct(x, memo, *rv)
    173 
    174     # If is its own copy, don't memoize.

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    269     if state is not None:
    270         if deep:
--> 271             state = deepcopy(state, memo)
    272         if hasattr(y, '__setstate__'):
    273             y.__setstate__(state)

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    144     copier = _deepcopy_dispatch.get(cls)
    145     if copier is not None:
--> 146         y = copier(x, memo)
    147     else:
    148         if issubclass(cls, type):

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in _deepcopy_dict(x, memo, deepcopy)
    229     memo[id(x)] = y
    230     for key, value in x.items():
--> 231         y[deepcopy(key, memo)] = deepcopy(value, memo)
    232     return y
    233 d[dict] = _deepcopy_dict

[/usr/lib/python3.10/copy.py](https://localhost:8080/#) in deepcopy(x, memo, _nil)
    149             y = _deepcopy_atomic(x, memo)
    150         else:
--> 151             copier = getattr(x, "__deepcopy__", None)
    152             if copier is not None:
    153                 y = copier(memo)

ReferenceError: weakly-referenced object no longer exists