Add warning if max_tokens is too short?
Opened this issue · 8 comments
import dspy
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch

llm = dspy.LM("gemini/gemini-1.5-pro", max_tokens=20)
dspy.configure(lm=llm)

class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")

# Pass the signature to the ChainOfThought module.
generate_answer = dspy.ChainOfThought(BasicQA)

# Call the predictor on a particular input.
question = 'What is the color of the sky?'
pred = generate_answer(question=question)
print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")
This gives me: ValueError: Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning'])
(I've configured Gemini's API keys correctly.)
Are you on latest DSPy? > 2.5.20?
Update: I see the issue. Why are you setting max_tokens=20? Can you remove that?
Ah, I see; my bad. I simply shifted from Predict to CoT without reconsidering the token budget. It's fixed now, thanks @okhat! Alternatively, can this error be handled better? My first intuition would be to add warnings to make sure some rules of thumb are satisfied, but I'm happy to contribute if there is a more elegant solution to this.
@vballoli Yes, indeed, a warning would be appreciated! If you could open a PR for that, I would be super grateful.
In clients/lm.py, when the "termination reason" is "ran out of tokens", we could issue a warning?
@okhat should I add the warning here instead? 20 tokens is maybe reasonable in a dspy.Predict setting, but not in CoT, so we could simply add the warning here: https://github.com/stanfordnlp/dspy/blob/b88caa3228512df3d56ba5a9320cd4476389c7ae/dspy/predict/chain_of_thought.py#L37C60-L37C66

import warnings

# Guard against a missing key: config.get('max_tokens') may return None.
if config.get('max_tokens') is not None and config['max_tokens'] < 100:
    warnings.warn("The program might run into an error due to insufficient tokens")  # alternatively, simply raise an error, but unsure what threshold to choose
self._predict = ....
No, this is not about CoT; it may be about how many output fields exist in Predict.
Yeah, that makes more sense. Let me know which option fits better and I'll send the PR:
I think the warning should be triggered during generation. One field can be extremely long (asking for 300 lines of code) and many fields can be super short (asking for one letter x 10 times).
The warning should happen in lm.py if the generation doesn't finish.
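A rough sketch of what that could look like. The helper name and the way the finish reason is threaded through are assumptions for illustration, not DSPy's actual API; OpenAI/LiteLLM-style responses report finish_reason == "length" when the completion was cut off by the token limit:

```python
import warnings

def warn_if_truncated(finish_reason: str, max_tokens: int) -> None:
    """Hypothetical lm.py helper: warn when a completion hit the token limit."""
    # "length" is the OpenAI/LiteLLM finish_reason for token exhaustion.
    if finish_reason == "length":
        warnings.warn(
            f"LM response was truncated because it hit max_tokens={max_tokens}. "
            "Adapters may fail to parse all output fields; consider raising max_tokens.",
            UserWarning,
        )
```

This keeps the check field-count-agnostic: it fires only when the generation actually did not finish, regardless of how many output fields the signature has.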
This is the trace of the error with max_tokens=20:
/usr/local/lib/python3.10/dist-packages/dspy/predict/predict.py in v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values)
262 adapter = dspy.settings.adapter or dspy.ChatAdapter()
263
--> 264 return adapter(
265 lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
266 )
/usr/local/lib/python3.10/dist-packages/dspy/adapters/base.py in __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values)
30 from .json_adapter import JsonAdapter
31 if _parse_values and not isinstance(self, JsonAdapter):
---> 32 return JsonAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
33 raise e
34
/usr/local/lib/python3.10/dist-packages/dspy/adapters/json_adapter.py in __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values)
43
44 for output in outputs:
---> 45 value = self.parse(signature, output, _parse_values=_parse_values)
46 assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
47 values.append(value)
/usr/local/lib/python3.10/dist-packages/dspy/utils/callback.py in wrapper(instance, *args, **kwargs)
200 # If no callbacks are provided, just call the function
201 if not callbacks:
--> 202 return fn(instance, *args, **kwargs)
203
204 # Generate call ID as the unique identifier for the call, this is useful for instrumentation.
/usr/local/lib/python3.10/dist-packages/dspy/adapters/json_adapter.py in parse(self, signature, completion, _parse_values)
84
85 if fields.keys() != signature.output_fields.keys():
---> 86 raise ValueError(f"Expected {signature.output_fields.keys()} but got {fields.keys()}")
87
88 return fields
ValueError: Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning'])
Based on my understanding, the LLM output is generated, but the adapter raises the error due to the key mismatch. Maybe an additional line could be added to the ValueError message noting that insufficient tokens may be the cause.
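As a rough sketch of that suggestion (the `finish_reason` parameter is an assumption here; the real JsonAdapter.parse does not receive it today, so it would need to be threaded through from the LM response):

```python
def check_output_fields(fields, expected_keys, finish_reason=None):
    """Hypothetical variant of the adapter's key check that appends a
    token-budget hint when the completion was cut off."""
    if set(fields.keys()) != set(expected_keys):
        msg = f"Expected {set(expected_keys)} but got {set(fields.keys())}"
        if finish_reason == "length":
            # The completion hit max_tokens, which likely explains the
            # missing output fields.
            msg += (
                ". The completion was truncated by max_tokens; "
                "try raising max_tokens."
            )
        raise ValueError(msg)
    return fields
```

With this, a user who sets max_tokens=20 would see the likely cause directly in the exception instead of an opaque key mismatch.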