aws/sagemaker-huggingface-inference-toolkit

Custom Inference Code - model_fn() takes more positional arguments than documented


Hello Everyone,

I have been testing this toolkit, trying to do some custom inference, when I got this error message.

Issue

1717449617113,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"

I was following the README and overrode model_fn(model_dir). I had no idea that this function could receive multiple arguments; looking into the original handler implementation, I figured that the handler might be passing an extra argument to model_fn.

My inference code

from transformers import AutoTokenizer, MistralForCausalLM, BitsAndBytesConfig
import torch

def model_fn(model_dir):
    # Load the model and its tokenizer from the model directory.
    model = MistralForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer

def predict_fn(data, model_and_tokenizer):

    model, tokenizer = model_and_tokenizer
    # Pull the prompt and optional generation parameters out of the payload.
    sentences = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    inputs = tokenizer(sentences, return_tensors="pt")

    if parameters is not None:
        outputs = model.generate(**inputs, **parameters)
    else:
        outputs = model.generate(**inputs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
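(Not the cause of the error, but for reference: the payload unpacking in predict_fn behaves like this on a typical request body. A minimal sketch with stand-in dicts, no model needed; the key names "inputs" and "parameters" follow the code above.)

```python
# Stand-in for the deserialized JSON body that predict_fn receives.
data = {"inputs": "Hello, world", "parameters": {"max_new_tokens": 32}}

sentences = data.pop("inputs", data)        # the prompt string
parameters = data.pop("parameters", None)   # optional generate() kwargs

# With a payload that lacks an "inputs" key, the whole dict is used as-is:
bare = {"text": "hi"}
fallback = bare.pop("inputs", bare)
```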

Logs

1717449608604,"2024-06-03T21:20:08,521   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:55126 ""POST   /invocations HTTP/1.1"" 400 3"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
1717449612107,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"
1717449614111,"2024-06-03T21:20:12,046   [INFO ] pool-2-thread-6 ACCESS_LOG - /169.254.178.2:58518 ""GET   /ping HTTP/1.1"" 200 0"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response   time: 1"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:58518 ""POST   /invocations HTTP/1.1"" 400 2"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
**1717449617113,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"**

Solution

I was able to fix it by adding a second parameter to the function: def model_fn(model_dir, temp=None):.
It was a bit confusing, as the documentation reads as if there were only one argument.
Is it generally the case that model_fn takes two arguments, or did this just happen in my particular case?
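For anyone hitting the same TypeError: the traceback's self.load(*([self.model_dir] + self.load_extra_arg)) line shows the handler can pass a second positional argument after model_dir. A minimal sketch of a signature that works either way — context is just my name for that extra argument, and the body only echoes its inputs instead of loading a real model:

```python
import inspect

def model_fn(model_dir, context=None):
    # Accepting an optional second positional argument avoids
    # "model_fn() takes 1 positional argument but 2 were given",
    # while staying compatible with handlers that pass only model_dir.
    # A real implementation would load the model and tokenizer from
    # model_dir here (e.g. MistralForCausalLM.from_pretrained(model_dir));
    # this sketch just returns what it was given.
    return model_dir, context

# Callable with one argument or two:
print(model_fn("/opt/ml/model"))
print(model_fn("/opt/ml/model", object()))
```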

Thanks!