GenerationMixin._get_logits_warper() missing 1 required positional argument: 'device'
Closed this issue · 3 comments
(Looks extremely interesting - really looking forward to trying it out :)
On an M1 Mac, no CUDA GPU, setting 'device'='cpu':
Python 3.12.4
datasets 2.20.0 pypi_0 pypi
fire 0.6.0 pypi_0 pypi
interegular 0.3.3 pypi_0 pypi
jsonschema 4.23.0 pyhd8ed1ab_0 conda-forge
jsonschema-specifications 2023.12.1 pyhd8ed1ab_0 conda-forge
jsonschema-with-format-nongpl 4.23.0 hd8ed1ab_0 conda-forge
python-fastjsonschema 2.20.0 pyhd8ed1ab_0 conda-forge
torch 2.3.1 pypi_0 pypi
tqdm 4.66.4 pypi_0 pypi
transformers 4.42.3 pypi_0 pypi
TypeError Traceback (most recent call last)
Cell In[2], line 2
1 prompt = "Give me the SQL query to select the name of the employee with the highest salary from the employee table. Given that the employee table has the following columns: name, salary.\n"
----> 2 output = llm.infer(prompt)[0]
3 print(f"LLM output:\n{output}\n")
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/infer.py:144, in Syncode.infer(self, prompt, task_id, stop_words)
142 output = FOLEval.run_eval(self, debug_task_id=task_id)
143 elif self.dataset.type == "input":
--> 144 output = self.user_input(prompt, stop_words=stop_words)
145 elif self.dataset.type == "json":
146 output = JSONEval.run_json_eval(self, debug_task_id=task_id, eval_type = self.json_eval_type)
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/infer.py:180, in Syncode.user_input(self, prompt, stop_words)
178 return self.model.generate_chat_completion_grammar(prompt)
179 else:
--> 180 return self.model.generate_batch_completion_grammar(prompt, self.num_samples, stop_words=stop_words)
182 else:
183 while True:
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/language_model.py:97, in HuggingFaceModel.generate_batch_completion_grammar(self, prompt, batch_size, stop_words)
95 # Generate completions
96 if (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
---> 97 generated_ids = self._generate(
98 inputs,
99 gen_config,
100 gen_mode,
101 grammar_decoder=self.grammar_decoder,
102 stop_criteria=stop_criteria
103 )
104 else:
105 # Use generate from transformers library for other modes
106 if stop_criteria is not None:
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/mambaforge/envs/syncode/lib/python3.12/site-packages/syncode/language_model.py:154, in HuggingFaceModel._generate(self, inputs, gen_config, gen_mode, grammar_decoder, stop_criteria)
150 """
151 We support greedy search and sampling for batch size 1 otherwise we use the generate function from transformers library.
152 """
153 token_ids, attention_mask, past_key_values = inputs['input_ids'], inputs['attention_mask'], None
--> 154 logit_warper = self.model._get_logits_warper(gen_config)
155 max_tokens = self.gen_args['max_new_tokens']+token_ids.size(1)
157 while True:
TypeError: GenerationMixin._get_logits_warper() missing 1 required positional argument: 'device'
Hi,
I haven't tried running on CPU since this new part of `language_model.py` was added. I think it should be easy to fix, but I'm a little busy for the next couple of days. Could you give it a try? I think we have to provide the `device` argument there explicitly.
(If not, you can also use SynCode as a logits processor as in the example here (https://github.com/uiuc-focal-lab/syncode/blob/main/notebooks/example_logits_processor.ipynb); that would avoid all of this and rely on the HuggingFace generation method.)
Sure.
So this is a method in the HF transformers generation utils. It wants `device` as a string, which in that context is available as `self.device`.
Strange (and pretty hacky) that it is not pulled from the config. Perhaps there is a good reason?
Changing line 154 in `language_model.py`:
logit_warper = self.model._get_logits_warper(gen_config, self.device)
Yes, it seems they changed the function argument `device` to be a required positional argument in this commit.
Can you create a short PR with this change?
If we pass the argument explicitly, it will probably crash on older versions of the transformers library. We will also need to update requirements.txt to require transformers >= a certain minimum version.
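Alternatively, the call site could be made tolerant of both signatures, avoiding a hard version pin. A minimal sketch of that idea, assuming only that `device` was added as a new parameter; the helper name `get_logits_warper_compat` and the `OldModel`/`NewModel` stand-ins are hypothetical, for illustration only:

```python
import inspect

def get_logits_warper_compat(model, gen_config, device):
    # Pass `device` only if the installed transformers version accepts it
    # (the argument became a required positional in newer releases).
    sig = inspect.signature(model._get_logits_warper)
    if "device" in sig.parameters:
        return model._get_logits_warper(gen_config, device)
    return model._get_logits_warper(gen_config)

# Hypothetical stand-ins mimicking the old and new transformers signatures:
class OldModel:
    def _get_logits_warper(self, gen_config):
        return ("warper", gen_config)

class NewModel:
    def _get_logits_warper(self, gen_config, device):
        return ("warper", gen_config, device)

print(get_logits_warper_compat(OldModel(), "cfg", "cpu"))
print(get_logits_warper_compat(NewModel(), "cfg", "cpu"))
```

This trades a one-line fix for a small shim, but keeps older transformers installs working without touching requirements.txt.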