ndif-team/nnsight

Error when using stopping_criteria in .generate if remote=True

Opened this issue · 5 comments

Remote execution does not support stopping_criteria right now:

from nnsight import LanguageModel
from transformers import StoppingCriteria

class Stopping(StoppingCriteria):
    def __init__(self):
        pass

    def __call__(self, input_ids, _scores, **_kwargs):
        return False  # Continue generation

    def __len__(self):
        return 1

    def __iter__(self):
        yield self

nn_model = LanguageModel("meta-llama/Llama-2-70b-hf")

stopping_criteria = Stopping()
with nn_model.generate("hello", remote=True, stopping_criteria=stopping_criteria) as tracer:
    out = nn_model.generator.output.save()
print(nn_model.tokenizer.decode(out[0]))

Error trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[47], line 20
     17 nn_model = LanguageModel("meta-llama/Llama-2-70b-hf")
     19 stopping_criteria = StoppingCriteria()
---> 20 with nn_model.generate("hello", remote=True, stopping_criteria=stopping_criteria) as tracer:
     21     out = nn_model.generator.output.save()
     22 print(nn_model.tokenizer.decode(out[0]))

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/mixins/Generation.py:11, in GenerationMixin.generate(self, *args, **kwargs)
      9 def generate(self, *args, **kwargs) -> Runner:
---> 11     return self.trace(*args, generate=True, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/NNsightModel.py:196, in NNsight.trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    193         return output.value
    195     # Otherwise open an invoker context with the give args.
--> 196     runner.invoke(*inputs, **invoker_args).__enter__()
    198 # If trace is False, you had to have provided an input.
    199 if not trace:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/contexts/Invoker.py:69, in Invoker.__enter__(self)
     64     with FakeTensorMode(
     65         allow_non_fake_inputs=True,
     66         shape_env=ShapeEnv(assume_static_by_default=True),
     67     ) as fake_mode:
     68         with FakeCopyMode(fake_mode):
---> 69             self.tracer._model._execute(
     70                 *copy.deepcopy(self.inputs),
     71                 **copy.deepcopy(self.tracer._kwargs),
     72             )
     74     self.scanning = False
     76 else:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/mixins/Generation.py:19, in GenerationMixin._execute(self, prepared_inputs, generate, *args, **kwargs)
     13 def _execute(
     14     self, prepared_inputs: Any, *args, generate: bool = False, **kwargs
     15 ) -> Any:
     17     if generate:
---> 19         return self._execute_generate(prepared_inputs, *args, **kwargs)
     21     return self._execute_forward(prepared_inputs, *args, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/nnsight/models/LanguageModel.py:293, in LanguageModel._execute_generate(self, prepared_inputs, max_new_tokens, *args, **kwargs)
    287 def _execute_generate(
    288     self, prepared_inputs: Any, *args, max_new_tokens=1, **kwargs
    289 ):
    291     device = next(self._model.parameters()).device
--> 293     output = self._model.generate(
    294         *args,
    295         **prepared_inputs.to(device),
    296         max_new_tokens=max_new_tokens,
    297         **kwargs,
    298     )
    300     self._model.generator(output)
    302     return output

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:1533, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1521 prepared_logits_processor = self._get_logits_processor(
   1522     generation_config=generation_config,
   1523     input_ids_seq_length=input_ids_length,
   (...)
   1529     negative_prompt_attention_mask=negative_prompt_attention_mask,
   1530 )
   1532 # 9. prepare stopping criteria
-> 1533 prepared_stopping_criteria = self._get_stopping_criteria(
   1534     generation_config=generation_config, stopping_criteria=stopping_criteria
   1535 )
   1536 # 10. go into different generation modes
   1537 if generation_mode == GenerationMode.ASSISTED_GENERATION:

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:903, in GenerationMixin._get_stopping_criteria(self, generation_config, stopping_criteria)
    901 if generation_config.eos_token_id is not None:
    902     criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
--> 903 criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
    904 return criteria

File /dlabscratch1/cdumas/.conda_envs/llmenglish/lib/python3.11/site-packages/transformers/generation/utils.py:911, in GenerationMixin._merge_criteria_processor_list(self, default_list, custom_list)
    906 def _merge_criteria_processor_list(
    907     self,
    908     default_list: Union[LogitsProcessorList, StoppingCriteriaList],
    909     custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
    910 ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
--> 911     if len(custom_list) == 0:
    912         return default_list
    913     for default in default_list:

TypeError: object of type 'StoppingCriteria' has no len()

@Butanium

I'm not getting the error you posted, but instead a maximum recursion error.

I need to add a better error for this, but the way things work right now, you can't send arbitrary objects to the server.

It has to be one of:
list
dict
tuple
int, string, None, float, bool
Tensor
slice
whitelisted function
nnsight Node

You can see this here: https://github.com/ndif-team/nnsight/blob/main/src/nnsight/pydantics/format/types.py

Hi, sorry — I posted the wrong trace; I got a recursion error too 😅
I'll look into it.

So if I understand correctly, since every stopping criterion is a different class inheriting from StoppingCriteria, it might not be possible for nnsight to support this argument on remote execution?

So if I understand correctly, since every stopping criterion is a different class inheriting from StoppingCriteria, it might not be possible for nnsight to support this argument on remote execution?

Yeah, ndif/nnsight works with a custom serialization format, so it only supports types we explicitly define. Otherwise, anyone could execute arbitrary code with arbitrary classes. If StoppingCriteria seems useful enough, maybe I could add it.

StoppingCriteria is an abstract class meant to be inherited. If I understand correctly, you'd need to manually whitelist a set of classes inheriting from StoppingCriteria, right?