Please provide support for Llama 3 or provide an example of how to serve it using EasyDeL
Closed this issue · 4 comments
jchauhan commented
With a configuration similar to the one we use for Llama 2, we are getting garbage responses from Llama 3.
s[/S]
The above code will output:
Hello
The `INST` tag is used to indicate that the text should be displayed in an instance of the `Hello` class. The `S` tag is used to indicate that the text should be displayed in a sentence.
[INST] [INST] [INST] [INST] will be the [INST] [INST] and [INST] and [INST] and [INST] and [INST] C
```[INST] and [INST] and the and [INST
``
[INST] and the and the and the and the and the `s
``
``
[S] [INST]s[/S] and the [S][[INST]s[INST[/S][
``s[INST[/S
``
INST[/S]s[INST[/S][S][S][S][s[/S][s][s[/s[/INST[/S][s[/S][[INST[INST[s`<``s[INST[/S][S][S][S[INST[INST[/S][S][S][S][S
S
jchauhan commented
@erfanzar Can you suggest a solution? It's a bit urgent. Thanks.
The real issue is that, in order to run Llama 3, `eos_token_id` needs to be a list of integers:
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
As per the Hugging Face example:
terminators = [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
outputs = model.generate(
input_ids,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
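For Llama 3 instruct models, the `<|eot_id|>` terminator maps to token id 128009, which is the value used in the JAXServer config below. A quick way to verify this (a small sketch, assuming the Llama 3 tokenizer can be loaded from the Hub):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
terminators = [
    tokenizer.eos_token_id,                         # <|end_of_text|> or <|eot_id|>, depending on tokenizer version
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # 128009
]
print(terminators)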
erfanzar commented
@jchauhan hi, actually you can use JAXServer, since that's a faster and safer option and it supports dynamic prompt templates.
erfanzar commented
Here's an example of using a Llama 3 model on a Kaggle GPU (T4 x2):
import jax.numpy as jnp
import torch
# The exact import path may differ between EasyDeL versions.
from EasyDel import JAXServer, JAXServerConfig

server_config = JAXServerConfig(
    max_sequence_length=3072,
    max_new_tokens=2048,
    max_compile_tokens=512,
    pre_compile=False,
    eos_token_id=128009,  # <|eot_id|> terminator for Llama 3 instruct
    temperature=0.3,
    top_p=0.95,
    top_k=10
)
server = JAXServer.from_torch_pretrained(
    pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    server_config=server_config,
    sharding_axis_dims=(1, 1, 1, -1),
    model_config_kwargs=dict(
        gradient_checkpointing="",
        use_scan_mlp=False,
        shard_attention_computation=False,
        use_sharded_kv_caching=True,
        attn_mechanism="local_ring"
    ),
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    auto_shard_params=True,
    load_in_8bit=True,
    input_shape=(1, 2048),
    torch_dtype=torch.float16,
    device_map="cpu"  # passed through to transformers.AutoModelForCausalLM
)
prompt = server.format_chat(
    prompt="write a poem about stars",
    history=[],
    system="",
)
pl = 0  # length of the response already printed
for response, used_tokens in server.sample(prompt):
    # stream only the newly generated portion of the response
    print(response[pl:], end="")
    pl = len(response)
# Here's a poem about stars:
# The stars shine bright in the midnight sky,
# A celestial show, beyond the eye,
# Their twinkling light, a beacon in the night,
# Guiding us through the darkness, a guiding light.
# Their beauty is a wonder, a celestial display,
# A reminder of the magic, that's always in.swing,
# A reminder of the mystery, of the stars up high,
# A reminder of the wonder, of the stars that twinkle bright,
# A reminder of the magic, of the stars that shine with a gentle light.
# So let us cherish the stars, and the magic that they bring to our lives.
jchauhan commented
I wanted to run it on a TPU v4 instance, and I finally got it working; I had to change the prompter.
Thanks, it was easy.
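The modified prompter isn't shown in the thread. A minimal sketch of what a Llama 3 chat formatter might look like, assuming the standard Llama 3 instruct header/turn special tokens (the function name and the way it would be wired into JAXServer are illustrative, not taken from EasyDeL):
def format_llama3_chat(prompt, history=None, system=""):
    """Build a Llama 3 instruct prompt from the official special tokens."""
    text = "<|begin_of_text|>"
    if system:
        text += f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    for user_msg, assistant_msg in (history or []):
        text += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"
    text += f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
    # Leave the assistant header open so the model writes the reply;
    # generation then stops at <|eot_id|> (id 128009, matching eos_token_id above).
    text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text
A prompt built this way could be passed to server.sample(...) in place of the server.format_chat(...) output shown above.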