Azure/azureml-examples

Langchain mistral notebook raising KeyError: 'choices'

mauricioarmani opened this issue · 1 comment

Operating System

Linux

Version Information

Python Version: 3.11.4

Steps to reproduce

I tried to query an Azure-hosted Mistral Large endpoint using the LangChain notebook provided in this repository.
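Roughly, the notebook wraps ChatMistralAI (pointed at the Azure endpoint) in an LLMChain and calls predict. A minimal sketch of the failing setup, with illustrative placeholders for the prompt and credentials (the actual notebook uses its own template and memory):

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI

# Azure-hosted Mistral Large endpoint (placeholders).
chat_model = ChatMistralAI(
    endpoint="https://<endpoint>.<region>.inference.ai.azure.com",
    mistral_api_key="<key>",
)

prompt = PromptTemplate.from_template("{human_input}")
chat_llm_chain = LLMChain(llm=chat_model, prompt=prompt)

# This call raises KeyError: 'choices' (full traceback below).
chat_llm_chain.predict(human_input="Hi there my friend")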

Expected behavior

The response should contain the keys described in the Mistral API documentation (in particular "choices"), so the chain can return the model's reply.
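For reference, langchain_mistralai's _create_chat_result (see the traceback below) expects the raw response to look roughly like the following; the values here are illustrative only:

# Rough shape expected by ChatMistralAI._create_chat_result, inferred from
# the traceback below; field values are illustrative.
expected_response = {
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "Hello! ..."},
        }
    ]
}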

Actual behavior

{
	"name": "KeyError",
	"message": "'choices'",
	"stack": "---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 1
----> 1 chat_llm_chain.predict(human_input=\"Hi there my friend\")

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:293, in LLMChain.predict(self, callbacks, **kwargs)
    278 def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
    279     \"\"\"Format prompt with kwargs and pass to LLM.
    280 
    281     Args:
   (...)
    291             completion = llm.predict(adjective=\"funny\")
    292     \"\"\"
--> 293     return self(kwargs, callbacks=callbacks)[self.output_key]

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:145, in deprecated.<locals>.deprecate.<locals>.warning_emitting_wrapper(*args, **kwargs)
    143     warned = True
    144     emit_warning()
--> 145 return wrapped(*args, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:378, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    346 \"\"\"Execute the chain.
    347 
    348 Args:
   (...)
    369         `Chain.output_keys`.
    370 \"\"\"
    371 config = {
    372     \"callbacks\": callbacks,
    373     \"tags\": tags,
    374     \"metadata\": metadata,
    375     \"run_name\": run_name,
    376 }
--> 378 return self.invoke(
    379     inputs,
    380     cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
    381     return_only_outputs=return_only_outputs,
    382     include_run_info=include_run_info,
    383 )

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:163, in Chain.invoke(self, input, config, **kwargs)
    161 except BaseException as e:
    162     run_manager.on_chain_error(e)
--> 163     raise e
    164 run_manager.on_chain_end(outputs)
    166 if include_run_info:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/base.py:153, in Chain.invoke(self, input, config, **kwargs)
    150 try:
    151     self._validate_inputs(inputs)
    152     outputs = (
--> 153         self._call(inputs, run_manager=run_manager)
    154         if new_arg_supported
    155         else self._call(inputs)
    156     )
    158     final_outputs: Dict[str, Any] = self.prep_outputs(
    159         inputs, outputs, return_only_outputs
    160     )
    161 except BaseException as e:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:103, in LLMChain._call(self, inputs, run_manager)
     98 def _call(
     99     self,
    100     inputs: Dict[str, Any],
    101     run_manager: Optional[CallbackManagerForChainRun] = None,
    102 ) -> Dict[str, str]:
--> 103     response = self.generate([inputs], run_manager=run_manager)
    104     return self.create_outputs(response)[0]

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain/chains/llm.py:115, in LLMChain.generate(self, input_list, run_manager)
    113 callbacks = run_manager.get_child() if run_manager else None
    114 if isinstance(self.llm, BaseLanguageModel):
--> 115     return self.llm.generate_prompt(
    116         prompts,
    117         stop,
    118         callbacks=callbacks,
    119         **self.llm_kwargs,
    120     )
    121 else:
    122     results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
    123         cast(List, prompts), {\"callbacks\": callbacks}
    124     )

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:550, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs)
    542 def generate_prompt(
    543     self,
    544     prompts: List[PromptValue],
   (...)
    547     **kwargs: Any,
    548 ) -> LLMResult:
    549     prompt_messages = [p.to_messages() for p in prompts]
--> 550     return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:411, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    409         if run_managers:
    410             run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
--> 411         raise e
    412 flattened_outputs = [
    413     LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
    414     for res in results
    415 ]
    416 llm_output = self._combine_llm_outputs([res.llm_output for res in results])

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:401, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    398 for i, m in enumerate(messages):
    399     try:
    400         results.append(
--> 401             self._generate_with_cache(
    402                 m,
    403                 stop=stop,
    404                 run_manager=run_managers[i] if run_managers else None,
    405                 **kwargs,
    406             )
    407         )
    408     except BaseException as e:
    409         if run_managers:

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:618, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    616 else:
    617     if inspect.signature(self._generate).parameters.get(\"run_manager\"):
--> 618         result = self._generate(
    619             messages, stop=stop, run_manager=run_manager, **kwargs
    620         )
    621     else:
    622         result = self._generate(messages, stop=stop, **kwargs)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_mistralai/chat_models.py:312, in ChatMistralAI._generate(self, messages, stop, run_manager, stream, **kwargs)
    308 params = {**params, **kwargs}
    309 response = self.completion_with_retry(
    310     messages=message_dicts, run_manager=run_manager, **params
    311 )
--> 312 return self._create_chat_result(response)

File ~/anaconda3/envs/pes/lib/python3.11/site-packages/langchain_mistralai/chat_models.py:316, in ChatMistralAI._create_chat_result(self, response)
    314 def _create_chat_result(self, response: Dict) -> ChatResult:
    315     generations = []
--> 316     for res in response[\"choices\"]:
    317         finish_reason = res.get(\"finish_reason\")
    318         gen = ChatGeneration(
    319             message=_convert_mistral_chat_message_to_message(res[\"message\"]),
    320             generation_info={\"finish_reason\": finish_reason},
    321         )

KeyError: 'choices'"
}

Additional information

This code works:

import requests
import json


url = "https://<endpoint>.<region>.inference.ai.azure.com" + "/v1/chat/completions"

headers = {
    "Authorization": "Bearer " + "<key>",
    "Content-type": "application/json",
}

system_content = "You are a helpful assistant."
user_content = """Hello!"""
temperature = 0.7
max_tokens = 500


data = {
    "messages": [
        {
            "role": "system",
            "content": system_content,
        },
        {"role": "user", "content": user_content},
    ],
    "temperature": temperature,
    "max_tokens": max_tokens,
}

response = requests.post(url, headers=headers, json=data)
print("Request data:", json.dumps(data, indent=4))

# Parse the response content
json_response = response.json()
print("Response: ", json_response)#["choices"][0]["message"]["content"])

Output:

Request data: {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": "Hello!"
        }
    ],
    "temperature": 0.7,
    "max_tokens": 500
}
Response:  {'choices': [{'finish_reason': 'stop', 'index': 0, 'message': {'content': " Hello! I'm here to help answer your ...

Solution:

from langchain_mistralai.chat_models import ChatMistralAI

chat_model = ChatMistralAI(
    endpoint="https://<endpoint>.<region>.inference.ai.azure.com" + "/v1",
    mistral_api_key="<key>",
)

I just appended "/v1" to the endpoint URL, because LangChain only appends "/chat/completions" when building the request URL.
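A rough sketch of the resulting request URL, assuming the client simply concatenates the path (the fixed URL matches the one used in the working requests example above):

# Illustrative only: how the final URL differs with and without "/v1".
base = "https://<endpoint>.<region>.inference.ai.azure.com"

print(base + "/chat/completions")          # notebook default; the response here evidently lacks 'choices'
print(base + "/v1" + "/chat/completions")  # with the fix; same URL as the working requests call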