aws-neuron/transformers-neuronx

Improve Neuron model loading time


This is not a bug, but rather a feature request: even when pre-compiled artifacts are available, loading a model onto Neuron cores can take a very long time.

This seems especially true when loading a model for the first time after an instance has been started, which happens when deploying models through SageMaker.

For instance, it can take up to 10 minutes to load a Llama 7b model when deploying through SageMaker (regardless of the instance type).

Hello,

We have recently made some improvements to weight load times by directly supporting safetensors checkpoints.

When loading Llama 7b (with a pre-populated compilation cache, on trn1.32xlarge), I measure a load time of ~40 seconds using a safetensors checkpoint:

import time
from transformers_neuronx import NeuronAutoModelForCausalLM

begin = time.time()
# Create the model from the safetensors checkpoint.
model = NeuronAutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b', tp_degree=32)
# Compile (or reuse cached artifacts) and load the weights onto the Neuron cores.
model.to_neuron()
end = time.time()

print('Duration:', end - begin)

Can you check if using a safetensors checkpoint improves your load duration? If you still observe slow load times, could you provide a reproduction so we can determine exactly which part of the model load is taking so long? Does this occur only on a specific instance type?
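
If your checkpoint is still in the legacy PyTorch .bin format, one way to obtain a safetensors checkpoint is to re-serialize it with the standard transformers API. A minimal sketch, assuming the plain transformers model fits in host memory and using placeholder model/output paths:

from transformers import AutoModelForCausalLM

# Load the original checkpoint on CPU, then re-save it with safetensors serialization.
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b', low_cpu_mem_usage=True)
model.save_pretrained('llama-2-7b-safetensors', safe_serialization=True)

The resulting directory can then be passed to NeuronAutoModelForCausalLM.from_pretrained in place of the original checkpoint.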

I just tested this change on meta-llama/Llama-2-7b-chat-hf, loading the pre-compiled model either from the legacy split files or directly from safetensors weights.

Export parameters (a sketch of the corresponding load call follows this list):

  • batch_size 4
  • tp_degree 2
  • sequence_length 4096
  • auto_cast_type fp16
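
For reference, a rough sketch of how these export parameters would map onto a direct transformers-neuronx load call (assuming the NeuronAutoModelForCausalLM API shown above; the argument names n_positions and amp are my best guess at the equivalents of sequence_length and auto_cast_type, so double-check them against the library):

from transformers_neuronx import NeuronAutoModelForCausalLM

# Assumed mapping of the export parameters above onto transformers-neuronx arguments.
model = NeuronAutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-chat-hf',
    batch_size=4,
    tp_degree=2,
    n_positions=4096,  # sequence_length
    amp='f16',         # auto_cast_type fp16
)
model.to_neuron()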

On an ml.inf2.8xlarge:

split files: model loaded in 43.75 s.
safetensors: model loaded in 43.75 s.

So I cannot say there is a benefit to loading safetensors files.

Same test immediately after a reboot, still on an ml.inf2.8xlarge:

split files: Neuron model loaded in 134.06 s.
safetensors: model loaded in 133.50 s.

I ran the same test twice after a reboot and got consistent results: the model takes much longer to load.
Note also that even without rebooting, I occasionally see the same long loading times after several attempts.