LangChain Mistral notebook raising KeyError: 'choices'
mauricioarmani opened this issue · 1 comment
mauricioarmani commented
Operating System
Linux
Version Information
Python Version: 3.11.4
Steps to reproduce
I tried to call the Azure Mistral Large endpoint using the notebook provided in this repository.
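For context, a minimal reconstruction of the setup (the LLMChain/prompt wiring is an assumption; ChatMistralAI, the endpoint/key parameters, and the chat_llm_chain.predict call are taken from the traceback and the fix below):

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI

chat_model = ChatMistralAI(
    endpoint="https://<endpoint>.<region>.inference.ai.azure.com",  # bare endpoint, no "/v1"
    mistral_api_key="<key>",
)
prompt = PromptTemplate.from_template("{human_input}")
chat_llm_chain = LLMChain(llm=chat_model, prompt=prompt)

chat_llm_chain.predict(human_input="Hi there my friend")  # raises KeyError: 'choices'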
Expected behavior
The response should contain the keys described in the Mistral API documentation (in particular a top-level "choices" list).
Actual behavior
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 1
----> 1 chat_llm_chain.predict(human_input="Hi there my friend")

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:293, in LLMChain.predict(self, callbacks, **kwargs)
    278 def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
    279     """Format prompt with kwargs and pass to LLM.
    280
    281     Args:
    (...)
    291         completion = llm.predict(adjective="funny")
    292     """
--> 293     return self(kwargs, callbacks=callbacks)[self.output_key]

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:145, in deprecated.<locals>.deprecate.<locals>.warning_emitting_wrapper(*args, **kwargs)
    143     warned = True
    144     emit_warning()
--> 145 return wrapped(*args, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:378, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    346 """Execute the chain.
    347
    348 Args:
    (...)
    369     `Chain.output_keys`.
    370 """
    371 config = {
    372     "callbacks": callbacks,
    373     "tags": tags,
    374     "metadata": metadata,
    375     "run_name": run_name,
    376 }
--> 378 return self.invoke(
    379     inputs,
    380     cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
    381     return_only_outputs=return_only_outputs,
    382     include_run_info=include_run_info,
    383 )

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:163, in Chain.invoke(self, input, config, **kwargs)
    161 except BaseException as e:
    162     run_manager.on_chain_error(e)
--> 163     raise e
    164 run_manager.on_chain_end(outputs)
    166 if include_run_info:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:153, in Chain.invoke(self, input, config, **kwargs)
    150 try:
    151     self._validate_inputs(inputs)
    152     outputs = (
--> 153         self._call(inputs, run_manager=run_manager)
    154         if new_arg_supported
    155         else self._call(inputs)
    156     )
    158     final_outputs: Dict[str, Any] = self.prep_outputs(
    159         inputs, outputs, return_only_outputs
    160     )
    161 except BaseException as e:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:103, in LLMChain._call(self, inputs, run_manager)
     98 def _call(
     99     self,
    100     inputs: Dict[str, Any],
    101     run_manager: Optional[CallbackManagerForChainRun] = None,
    102 ) -> Dict[str, str]:
--> 103     response = self.generate([inputs], run_manager=run_manager)
    104     return self.create_outputs(response)[0]

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:115, in LLMChain.generate(self, input_list, run_manager)
    113 callbacks = run_manager.get_child() if run_manager else None
    114 if isinstance(self.llm, BaseLanguageModel):
--> 115     return self.llm.generate_prompt(
    116         prompts,
    117         stop,
    118         callbacks=callbacks,
    119         **self.llm_kwargs,
    120     )
    121 else:
    122     results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
    123         cast(List, prompts), {"callbacks": callbacks}
    124     )

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:550, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs)
    542 def generate_prompt(
    543     self,
    544     prompts: List[PromptValue],
    (...)
    547     **kwargs: Any,
    548 ) -> LLMResult:
    549     prompt_messages = [p.to_messages() for p in prompts]
--> 550     return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:411, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    409 if run_managers:
    410     run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
--> 411 raise e
    412 flattened_outputs = [
    413     LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
    414     for res in results
    415 ]
    416 llm_output = self._combine_llm_outputs([res.llm_output for res in results])

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:401, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    398 for i, m in enumerate(messages):
    399     try:
    400         results.append(
--> 401             self._generate_with_cache(
    402                 m,
    403                 stop=stop,
    404                 run_manager=run_managers[i] if run_managers else None,
    405                 **kwargs,
    406             )
    407         )
    408     except BaseException as e:
    409         if run_managers:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:618, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    616 else:
    617     if inspect.signature(self._generate).parameters.get("run_manager"):
--> 618         result = self._generate(
    619             messages, stop=stop, run_manager=run_manager, **kwargs
    620         )
    621     else:
    622         result = self._generate(messages, stop=stop, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_mistralai/chat_models.py:312, in ChatMistralAI._generate(self, messages, stop, run_manager, stream, **kwargs)
    308 params = {**params, **kwargs}
    309 response = self.completion_with_retry(
    310     messages=message_dicts, run_manager=run_manager, **params
    311 )
--> 312 return self._create_chat_result(response)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_mistralai/chat_models.py:316, in ChatMistralAI._create_chat_result(self, response)
    314 def _create_chat_result(self, response: Dict) -> ChatResult:
    315     generations = []
--> 316     for res in response["choices"]:
    317         finish_reason = res.get("finish_reason")
    318         gen = ChatGeneration(
    319             message=_convert_mistral_chat_message_to_message(res["message"]),
    320             generation_info={"finish_reason": finish_reason},
    321         )

KeyError: 'choices'
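The failing frame is ChatMistralAI._create_chat_result, which indexes response["choices"] directly, so any response body without that key raises. As the fix below notes, LangChain builds the request URL as endpoint + "/chat/completions"; a small debugging sketch to confirm what the bare Azure endpoint returns for that path (placeholders as elsewhere in this issue, exact error body not shown):

import requests

# Without "/v1", endpoint + "/chat/completions" points at a path the Azure
# deployment does not serve, so the body comes back without a "choices" key.
bad_url = "https://<endpoint>.<region>.inference.ai.azure.com" + "/chat/completions"
resp = requests.post(
    bad_url,
    headers={"Authorization": "Bearer <key>", "Content-type": "application/json"},
    json={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(resp.status_code)
print(resp.text)  # inspect the error body; no "choices" key, hence the KeyError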
Additional information
This code works:

import requests
import json

url = "https://<endpoint>.<region>.inference.ai.azure.com" + "/v1/chat/completions"

headers = {
    "Authorization": "Bearer " + "<key>",
    "Content-type": "application/json",
}

system_content = "You are a helpful assistant."
user_content = """Hello!"""
temperature = 0.7
max_tokens = 500

data = {
    "messages": [
        {
            "role": "system",
            "content": system_content,
        },
        {"role": "user", "content": user_content},
    ],
    "temperature": temperature,
    "max_tokens": max_tokens,
}

response = requests.post(url, headers=headers, json=data)
print("Request data:", json.dumps(data, indent=4))

# Parse the response content
json_response = response.json()
print("Response: ", json_response)  # ["choices"][0]["message"]["content"]
Output:

Request data: {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello!"
        }
    ],
    "temperature": 0.7,
    "max_tokens": 500
}
Response:  {'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': " Hello! I'm here to help answer your ...
mauricioarmani commented
Solution:

chat_model = ChatMistralAI(
    endpoint="https://<endpoint>.<region>.inference.ai.azure.com" + "/v1",
    mistral_api_key="<key>",
)

I just appended "/v1" to the endpoint, because LangChain only appends "/chat/completions" when building the request URL. Without it, the request goes to a path the Azure endpoint does not serve, so the response body has no "choices" key.
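As a quick sanity check with the corrected endpoint (mirroring the call from the traceback):

# endpoint + "/chat/completions" now resolves to .../v1/chat/completions,
# matching the working requests example above.
response = chat_model.invoke("Hi there my friend")
print(response.content)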