logit_bias parameter in the chat completions API seems to be ignored.
I deployed a mistral-7b-instruct-v0.1 model as a SageMaker endpoint following this tutorial.
In my particular use case, I want the LLM to output only a single token: "0" or "1". I am therefore using the logit_bias and max_tokens=1 parameters of the chat completions API.
For logit_bias, I looked up the token IDs of "0" and "1" in the tokenizer.json of mistral-7b-instruct-v0.1, so I am passing logit_bias={"28734": 100, "28740": 100}.
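For reference, my understanding of the OpenAI-style logit_bias semantics (an assumption about what the backend is supposed to do, not a description of this backend's actual code) is that each bias value, clamped to [-100, 100], is added to the token's raw logit before the next token is selected. A toy sketch with made-up logits:

```python
# Toy illustration of (assumed) OpenAI-style logit_bias semantics:
# the bias is added to each token's raw logit, then the next token is
# chosen. pick_token is a hypothetical helper using greedy decoding.
def pick_token(logits, logit_bias):
    biased = {tok: logit + logit_bias.get(tok, 0.0) for tok, logit in logits.items()}
    return max(biased, key=biased.get)

# Fake logits where the whitespace token ("29000") would normally win.
logits = {"28734": -2.0, "28740": -3.0, "29000": 5.0}

print(pick_token(logits, {}))                            # "29000"
print(pick_token(logits, {"28734": 100, "28740": 100}))  # "28734"
```

Under these semantics, a +100 bias on "0" and "1" should make one of them all but guaranteed, which is why the observed output surprises me.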
Then I am trying to send requests to this endpoint:
import json
from pprint import pprint

import boto3
import sagemaker

# access_key, secret_key, and session_token hold my AWS credentials.
sagemaker_session = sagemaker.Session(
    boto_session=boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        aws_session_token=session_token,
    )
)
sagemaker_client = sagemaker_session.sagemaker_runtime_client

messages = [
    {"role": "system", "content": "Reply only by 0 or 1."},
    {"role": "user", "content": "what is 0+0?"},
    {"role": "assistant", "content": "0"},
    {"role": "user", "content": "what is 1/1?"},
]

content_type = "application/json"
request_body = {
    "messages": messages,
    "model": "mistralai/Mistral-7B-v0.1",
    "logit_bias": {"28734": 100, "28740": 100},  # bias of +100 for "0" and "1"
    "max_tokens": 1,
}
payload = json.dumps(request_body).encode("utf-8")

endpoint_name = "lmi-mistral-7b-instruct-v01-xxxx-xx-xx-xx-xx-xx-xxx"
response = sagemaker_client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType=content_type, Body=payload
)
result = response["Body"].read()
ans = json.loads(result)
pprint(ans)
The output is:
{'choices': [{'finish_reason': 'length',
'index': 0,
'logprobs': None,
'message': {'content': ' ', 'role': 'assistant'}}],
'created': 1724127226,
'id': 'chatcmpl-<built-in function id>',
'object': 'chat.completion',
'usage': {'completion_tokens': 1, 'prompt_tokens': 45, 'total_tokens': 46}}
So the output is " ", which suggests that the logit_bias parameter is being ignored.
Note that even if I additionally penalize the whitespace token with logit_bias={"28734": 100, "28740": 100, "29000": -100}, the output token is still " ".
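To rule out randomness, one way to confirm the parameter is ignored is to send the same request with and without an extreme bias and compare the completions. A minimal sketch, where `send` is a hypothetical wrapper around sagemaker_client.invoke_endpoint that returns the parsed JSON response (the stub below only mimics the behavior I am seeing):

```python
# Hypothetical sanity check: if logit_bias is honored at all, an extreme
# bias should change at least one of n completions.
def bias_changes_output(send, request_body, bias, n=3):
    def contents(body):
        return [send(body)["choices"][0]["message"]["content"] for _ in range(n)]
    return contents(request_body) != contents({**request_body, "logit_bias": bias})

# Stub backend that ignores logit_bias, mimicking what I observe.
def stub_send(body):
    return {"choices": [{"message": {"content": " ", "role": "assistant"}}]}

print(bias_changes_output(stub_send, {"messages": []}, {"28734": 100}))  # False
```

Against my real endpoint, this check also comes back False for every bias I have tried.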
Am I doing anything wrong, or is the logit_bias parameter not actually supported?
Thanks for any help.