aws/sagemaker-huggingface-inference-toolkit

Zero-shot multi-label text classification

OneManArmy93 opened this issue

Greetings,

I have developed a script on my computer that does zero-shot multi-label text classification using XLM-RoBERTa.
I want to reproduce this on SageMaker using the Hugging Face Inference Toolkit, and I am having some trouble doing so.

Locally, I run the classification like this:

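```python
from transformers import pipeline

classifier = pipeline(model="joeddav/xlm-roberta-large-xnli", task="zero-shot-classification")

# sequence_to_classify: str; candidate_labels: list of label strings
predictions = classifier(sequence_to_classify, candidate_labels, multi_label=True)
```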
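To make concrete what `multi_label=True` changes, here is a pure-Python sketch of how, as I understand it, the zero-shot pipeline turns per-label NLI logits into scores. The logits and label names are made up, and the `[contradiction, neutral, entailment]` index order is an assumption; the real order comes from the model's config.

```python
import math

# Assumed index order of the NLI logits: [contradiction, neutral, entailment].
CONTRADICTION, ENTAILMENT = 0, 2


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_scores(logits_per_label, multi_label):
    """Score each candidate label from its (hypothetical) NLI logits."""
    if multi_label:
        # Each label is judged independently: softmax over
        # [contradiction, entailment] for that label, keep P(entailment).
        return [
            softmax([row[CONTRADICTION], row[ENTAILMENT]])[1]
            for row in logits_per_label
        ]
    # Single-label: softmax over the entailment logits of all labels,
    # so the scores compete with each other and sum to 1.
    return softmax([row[ENTAILMENT] for row in logits_per_label])


# Hypothetical logits for one sequence, one row per candidate label.
logits = [
    [-2.0, 0.1, 3.0],  # "politics"
    [-1.5, 0.3, 2.5],  # "economy"
    [3.0, 0.2, -2.0],  # "sports"
]
print(zero_shot_scores(logits, multi_label=False))
print(zero_shot_scores(logits, multi_label=True))
```

With `multi_label=False` the three scores sum to 1, while with `multi_label=True` both "politics" and "economy" can score close to 1 at the same time, which is why the flag matters for multi-label work.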

On SageMaker, I configure the model from the Hub and launch a batch transform job for inference, but I can't find where to pass the multi_label parameter in the following:

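```python
from sagemaker.huggingface import HuggingFaceModel

# `hub` (model/task env vars) and `event` (job settings) are defined
# elsewhere, e.g. in the surrounding Lambda handler.
huggingface_model = HuggingFaceModel(
    transformers_version="4.17.0",
    pytorch_version="1.10.2",
    py_version="py38",
    env=hub,
    role=event["role"],
)

# S3 prefix where the batch transform job writes its output
bt_output_key = f"s3://{event['bucket']}/{event['output_prefix']}/{event['execution_id']}"

hf_transformer = huggingface_model.transformer(
    instance_count=event["instance_count"],
    instance_type=event["instance_type"],
    output_path=bt_output_key,
    strategy="SingleRecord",  # send one record per request
    max_concurrent_transforms=event["concurrent_transforms"],
)

hf_transformer.transform(
    data=event["input_s3_path"],
    content_type="application/json",
    split_type="Line",  # each line of the input file is one JSON record
    wait=False,
)
```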
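For reference, `hub` in the snippet above is presumably a dict of the toolkit's model-loading environment variables. A minimal sketch, assuming the toolkit's `HF_MODEL_ID` and `HF_TASK` variables and the model from the local example:

```python
# Hypothetical reconstruction of the `hub` dict passed as env= above.
# HF_MODEL_ID selects the Hub model to load; HF_TASK selects the pipeline task.
# SageMaker container environment values must be strings.
hub = {
    "HF_MODEL_ID": "joeddav/xlm-roberta-large-xnli",
    "HF_TASK": "zero-shot-classification",
}
```

As far as I can tell, these variables only choose the model and the pipeline task when the container starts; none of them sets per-call arguments such as `multi_label`.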

I looked through the list of environment variables, but I think I'm missing something.
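One place `multi_label` can go is the request body itself. With `content_type="application/json"` and `split_type="Line"`, each line of the input file is sent as a separate JSON request, and as far as I understand the toolkit's default handler, a `"parameters"` object in that JSON is passed as keyword arguments to the pipeline call. A sketch under that assumption; the texts, labels, and the `to_jsonl` helper are hypothetical:

```python
import json

# Hypothetical example data; replace with your own texts and labels.
candidate_labels = ["politics", "economy", "sports"]
texts = [
    "The central bank raised interest rates again.",
    "The striker scored twice in the final.",
]


def to_jsonl(texts, labels, multi_label=True):
    """Build one JSON request per line, matching split_type="Line"."""
    lines = []
    for text in texts:
        record = {
            "inputs": text,
            # Assumed to be forwarded as keyword arguments, i.e.
            # pipeline(inputs, candidate_labels=..., multi_label=...)
            "parameters": {"candidate_labels": labels, "multi_label": multi_label},
        }
        lines.append(json.dumps(record))
    return "\n".join(lines) + "\n"


payload = to_jsonl(texts, candidate_labels)
print(payload)
```

The resulting string would then be uploaded as a `.jsonl` object under `event['input_s3_path']` before starting the transform job.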
Thank you for your help.