deepjavalibrary/djl-demo

logit_bias parameter in chat completion api seems to be ignored.


I deployed a mistral-7b-instruct-v0.1 model as an endpoint on SageMaker following this tutorial.

In my particular use case, I want the LLM to output only one token: "0" or "1". Therefore, I am using the logit_bias and max_tokens=1 parameters of the chat completion API.

For logit_bias, I looked up the token ids of "0" and "1" in the tokenizer.json of mistral-7b-instruct-v0.1. I am therefore using logit_bias={"28734": 100, "28740": 100}.
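For reference, this is roughly how I verified the ids. The helper below reads a tokenizer.json-style vocab (a dict mapping token strings to ids); the inline `vocab` is a tiny stand-in for illustration, using the ids reported above — with the real file you would load `json.load(open("tokenizer.json"))["model"]["vocab"]` instead:

```python
import json

def token_ids(vocab: dict, pieces: list) -> dict:
    """Map each token string to its id in a tokenizer.json-style vocab."""
    return {piece: vocab[piece] for piece in pieces}

# Stand-in vocab for illustration; the real one comes from tokenizer.json:
# vocab = json.load(open("tokenizer.json"))["model"]["vocab"]
vocab = {"0": 28734, "1": 28740}

print(token_ids(vocab, ["0", "1"]))  # → {'0': 28734, '1': 28740}
```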

Then I am trying to send requests to this endpoint:

import json
from pprint import pprint

import boto3
import sagemaker

sagemaker_session = sagemaker.Session(
    boto_session=boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        aws_session_token=session_token,
    )
)

sagemaker_client = sagemaker_session.sagemaker_runtime_client

messages = [
    {"role": "system", "content": "Reply only by 0 or 1."},
    {"role": "user", "content": "what is 0+0?"},
    {"role": "assistant", "content": "0"},
    {"role": "user", "content": "what is 1/1?"},
]

content_type = "application/json"
request_body = {
    "messages": messages,
    "model": "mistralai/Mistral-7B-v0.1",
    "logit_bias": {"28734": 100, "28740": 100},  # bias of +100 for "0" and "1" 
    "max_tokens": 1,
}
payload = json.dumps(request_body).encode("utf-8")

endpoint_name = "lmi-mistral-7b-instruct-v01-xxxx-xx-xx-xx-xx-xx-xxx"
response = sagemaker_client.invoke_endpoint(EndpointName=endpoint_name, ContentType=content_type, Body=payload)
result = response["Body"].read()
ans = json.loads(result)
pprint(ans)

The output is:

{'choices': [{'finish_reason': 'length',
              'index': 0,
              'logprobs': None,
              'message': {'content': ' ', 'role': 'assistant'}}],
 'created': 1724127226,
 'id': 'chatcmpl-<built-in function id>',
 'object': 'chat.completion',
 'usage': {'completion_tokens': 1, 'prompt_tokens': 45, 'total_tokens': 46}}

So the output is " ", which suggests that the logit_bias parameter is being ignored.
Note that even if I penalize the whitespace token with logit_bias={"28734": 100, "28740": 100, "29000": -100}, the output token is still " ".
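For comparison, here is a minimal sketch of what an endpoint honoring logit_bias should do before picking the next token: add the per-token bias to the raw logits, then sample (greedily here). The logit values are made up for illustration, with the whitespace token starting out most likely to mirror the behavior above:

```python
def apply_logit_bias(logits: dict, logit_bias: dict) -> dict:
    """Add the per-token-id bias to raw logits before sampling."""
    return {tid: logit + logit_bias.get(tid, 0.0) for tid, logit in logits.items()}

# Toy logits: the whitespace token (29000, per the ids above) dominates.
logits = {"28734": 1.0, "28740": 0.5, "29000": 2.0}
biased = apply_logit_bias(logits, {"28734": 100, "28740": 100, "29000": -100})

# Greedy decoding over the biased logits should now pick "0" (id 28734),
# never the whitespace token — unlike what the endpoint returns.
print(max(biased, key=biased.get))  # → 28734
```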

Am I doing anything wrong, or is the logit_bias parameter not actually supported?
Thanks for any help.