aws/sagemaker-huggingface-inference-toolkit

How to pass device_id to overridden functions?

ririya opened this issue · 9 comments

I want to select the GPU that my custom model is loaded onto. I can't find a way to do this: when we do self.load_fn = load_fn, we lose the reference to self inside load_fn, as illustrated below.
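
For illustration, here is a minimal sketch of the pattern I mean (simplified stand-ins, not the toolkit's actual code):

class HandlerService:
    """Simplified stand-in for the toolkit's handler service."""

    def __init__(self, load_fn):
        # assigned per worker, e.g. round-robin over the available GPUs
        self.device = "cuda:0"
        # the user function is stored as a plain attribute, so it is never
        # bound to the instance and has no way to read self.device
        self.load_fn = load_fn

def load_fn(model_dir):
    # no `self` here: this worker cannot tell which GPU it was assigned
    ...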

Hello @ririya,

Could you explain a bit more what you mean by selecting the GPU? E.g., do you have multiple GPUs and want to select which one?

Exactly. I'm trying to deploy a model on a g5.12xlarge instance with 4 GPUs, but all workers are being allocated to the first GPU. I know the handler's self.device uses TorchServe's round-robin assignment to select a different GPU for every worker, but I can't find a way to pass that information forward to my load_fn.

Got it! Sadly, I have to share that this is currently not possible with the Inference Toolkit. Since the Inference Toolkit is built on top of MMS and the sagemaker-inference-toolkit, there was no way to respect their API design (support for custom methods like model_fn, etc.) and also expose the "workers" or "GPU" devices.
I tried to look into that in the past as well, e.g. here.

To unblock you quickly, you could go with a g5.xlarge and scale the number of instances to 4 instead, e.g. along the lines of the sketch below.
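
Something like this (a sketch, not tested code; the model path, role, and version strings are placeholders):

from sagemaker.huggingface import HuggingFaceModel

huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/model.tar.gz",  # placeholder artifact path
    role=role,  # assumes an existing SageMaker execution role
    transformers_version="4.26",  # placeholder container versions
    pytorch_version="1.13",
    py_version="py39",
)

# scale out single-GPU instances instead of packing four workers
# onto one multi-GPU machine
predictor = huggingface_model.deploy(
    initial_instance_count=4,
    instance_type="ml.g5.xlarge",  # one GPU per instance
)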

@philschmid thanks for the reply. It's strange that SM built this interface instead of using class inheritance. At the very least they should have made the device id an argument to model_fn… I'll keep investigating, but scaling up the number of instances looks like the only solution right now.

Hi @philschmid @ririya, we recently made changes to pass the context from the handler service to the custom handler functions in the base sagemaker-inference-toolkit. Multi-GPU usage was tested with sagemaker-pytorch-inference-toolkit, since it extends the Transformer class from the base toolkit. As for the huggingface toolkit, I did not see that it extends from the base toolkit. @philschmid, could you take a look at how we might achieve this?

@waytrue17 this is awesome! I’ll give it a try

@philschmid: The toolkit indeed supports access to the context.
The code below was used as part of inference.py and tested via Inference Recommender on multi-GPU instances.

import logging

import torch
from diffusers import EulerAncestralDiscreteScheduler

logger = logging.getLogger(__name__)

def model_fn(model_dir, context):
    logger.debug("model_fn: Creating model")
    # XXPipeline stands in for the actual diffusers pipeline class
    pipe = XXPipeline.from_pretrained(
        pretrained_model_name_or_path=model_dir,
        torch_dtype=torch.float16,
    )
    # distribute instantiated models among the available GPUs;
    # the toolkit assigns each worker a gpu_id via the context
    gpu_id = str(context.system_properties.get("gpu_id"))
    logger.debug("gpu_id:" + gpu_id)
    pipe.to("cuda:" + gpu_id)

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    logger.debug("model_fn: Model created and served via GPU: " + gpu_id)
    return pipe

It works well with:

pytorch_model = PyTorchModel(
    model_data=pretrained_model_data,
    role=role,
    framework_version="1.12",
    py_version="py38",
    source_dir="....../code",
    entry_point="inference.py",
)

but it does not work for HuggingFaceModel (see the sketch below for the setup I mean). Can you update or get in touch?
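
For reference, the equivalent HuggingFaceModel setup I would expect to work looks roughly like this (a sketch; the version strings are assumptions and the source_dir path is elided as above):

from sagemaker.huggingface import HuggingFaceModel

huggingface_model = HuggingFaceModel(
    model_data=pretrained_model_data,
    role=role,
    transformers_version="4.26",  # assumed container versions
    pytorch_version="1.13",
    py_version="py39",
    source_dir="....../code",
    entry_point="inference.py",
)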

Will look into it!

Using the PyTorchModel class, I get the following libraries inside the SageMaker container as of 2023-04-27:

accelerate==0.18.0
aniso8601==9.0.1
ansi2html==1.8.0
arrow==1.2.3
asttokens==2.2.1
awscli==1.27.110
backcall==0.2.0
boto3==1.26.110
botocore==1.29.110
captum==0.6.0
certifi==2022.12.7
click==8.1.3
colorama==0.4.4
conda==4.12.0
contourpy==1.0.7
cryptography==40.0.1
cycler==0.11.0
Cython==0.29.34
decorator==5.1.1
diffusers==0.16.0
docutils==0.16
enum-compat==0.0.3
executing==1.2.0
filelock==3.12.0
Flask==2.2.3
Flask-RESTful==0.3.9
fonttools==4.39.3
fsspec==2023.4.0
future==0.18.3
huggingface-hub==0.14.1
importlib-metadata==6.3.0
importlib-resources==5.12.0
ipython==8.12.0
itsdangerous==2.1.2
jedi==0.18.2
Jinja2==3.1.2
jmespath==1.0.1
kiwisolver==1.4.4
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
numpy==1.22.2
nvgpu==0.10.0
opencv-python==4.7.0.72
pandas==2.0.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.5.0
pip==23.0.1
portalocker==2.7.0
prompt-toolkit==3.0.38
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
Pygments==2.15.0
pynvml==11.5.0
pyparsing==3.0.9
PyYAML==5.4.1
regex==2023.3.23
retrying==1.3.4
rsa==4.7.2
s3transfer==0.6.0
sagemaker-inference==1.9.2
sagemaker-pytorch-inference==2.0.11
scipy==1.10.1
setuptools==67.6.1
stack-data==0.6.2
tabulate==0.9.0
termcolor==2.2.0
tokenizers==0.13.3
torch-model-archiver==0.6.1
torchserve==0.6.1
tqdm==4.65.0
traitlets==5.9.0
transformers==4.28.1
wcwidth==0.2.6
Werkzeug==2.2.3
wheel==0.40.0
zipp==3.15.0
zstandard==0.19.0