Custom Inference Code - model_fn() takes more positional arguments than documented
Opened this issue · 0 comments
Hello Everyone,
I have been testing this toolkit with some custom inference code when I got this error message.
Issue
1717449617113,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"
I was following the README and overrode model_fn(model_dir). I had no idea that this function could receive multiple arguments; looking into the original handler implementation, I figured that there might be an extra positional argument being passed to it.
My inference code
from transformers import AutoTokenizer, MistralForCausalLM, BitsAndBytesConfig
import torch

def model_fn(model_dir):
    model = MistralForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    sentences = data.pop("inputs", data)
    parameters = data.pop("parameters", None)
    inputs = tokenizer(sentences, return_tensors="pt")
    if parameters is not None:
        outputs = model.generate(**inputs, **parameters)
    else:
        outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
Logs
1717449608604,"2024-06-03T21:20:08,521 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Prediction error"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 243, in handle"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - self.initialize(context)"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 83, in initialize"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - self.model = self.load(*([self.model_dir] + self.load_extra_arg))"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During handling of the above exception, another exception occurred:"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:55126 ""POST /invocations HTTP/1.1"" 400 3"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/mms/service.py"", line 108, in predict"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - ret = self._entry_point(input_batch, self.context)"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 267, in handle"
1717449608604,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - raise PredictionException(str(e), 400)"
1717449612107,"2024-06-03T21:20:08,522 [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"
1717449614111,"2024-06-03T21:20:12,046 [INFO ] pool-2-thread-6 ACCESS_LOG - /169.254.178.2:58518 ""GET /ping HTTP/1.1"" 200 0"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response time: 1"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Prediction error"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 243, in handle"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - self.initialize(context)"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 83, in initialize"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - self.model = self.load(*([self.model_dir] + self.load_extra_arg))"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:58518 ""POST /invocations HTTP/1.1"" 400 2"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During handling of the above exception, another exception occurred:"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/mms/service.py"", line 108, in predict"
1717449614111,"2024-06-03T21:20:13,923 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - ret = self._entry_point(input_batch, self.context)"
1717449614111,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - File ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"", line 267, in handle"
1717449614111,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - raise PredictionException(str(e), 400)"
**1717449617113,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"**
Solution
I was able to fix it by adding a second parameter to the function: def model_fn(model_dir, temp=None).
It was a bit confusing, as the documentation reads as if there were only one argument.
Is it general that model_fn takes two arguments, or did it just happen in my particular case?
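For reference, the traceback shows the handler calling self.load(*([self.model_dir] + self.load_extra_arg)), so the second positional argument comes from that extra-args list. A minimal sketch of a signature that works either way (the parameter name context and the dict return value are illustrative, not from the toolkit docs):

```python
# A model_fn signature that tolerates the handler passing either one
# positional argument (model_dir) or two (model_dir plus an extra
# argument, as seen in self.load(*([self.model_dir] + self.load_extra_arg))).
def model_fn(model_dir, context=None):
    # Load the model/tokenizer from model_dir here as usual; `context`
    # can simply be ignored if it is not needed.
    return {"model_dir": model_dir, "context": context}

# Both call styles the handler may use now succeed:
model_fn("/opt/ml/model")            # one positional argument
model_fn("/opt/ml/model", object())  # two positional arguments
```

This avoids the TypeError regardless of which call style the installed toolkit version uses.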
Thanks!