NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
flozi00 opened this issue · 33 comments
Feature request
Longer context, up to 8k tokens; the linked discussion and notebook show promising results.
Motivation
Discussion: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
Colab Notebook: https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=d2ceb547
Your contribution
As it's only 3 lines of code to change, it should be pretty easy to add (a sketch of the change is below).
I will start training a model and provide an example demo.
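For reference, the change from the linked discussion boils down to rescaling the RoPE base before computing the inverse frequencies; a minimal sketch of the idea (illustrative, not the actual tgi code):

```python
import torch

def ntk_scaled_inv_freq(dim: int, base: float = 10000.0, alpha: float = 4.0) -> torch.Tensor:
    # NTK-aware scaling: stretch the base so the effective context grows
    # by roughly `alpha` without retraining. This is the "3 lines" change.
    base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
```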
Oh nice. And if you want to write a PR, that would be awesome too.
Please be mindful that the tgi code doesn't do batching the same way as transformers, meaning the change will most likely be slightly more complex.
Also, lots of models actually define this buffer directly in their weights instead of instantiating it, which unfortunately leads to some downstream differences in generation.
@Narsil Just wanted to chime in here and say I'm working on an implementation and PR for this
@iantbutler01 let me know if you need support at any point
At the moment I am focused on training such models rather than on the integration into tgi.
I've opened a draft PR, #529
I've tested the fixed NTK Aware scaling and it seems to work, I still need to test dynamic scaling and clean up the PR to comply with contributor guidelines, but I wanted to at least start the discussion.
@iantbutler01 does this method only support LLaMA models? If so, why did you add the support in flash_rw_modeling.py?
@GemsFord The method should work for any model using rotary embeddings, it's agnostic. My main use is Falcon 40B, which I've been running locally and testing these changes with.
@iantbutler01 Thanks for adding the support for Falcon, I use that too that's why I asked. I am waiting for your PR to get merged.
Yup, I plan to clean this up and make it ready for review this weekend. I was on vacation and am now catching back up on my work, but I will have time this weekend. As far as I can tell the implementation works, so it's just a matter of cleaning up and then going through review feedback.
RoPE scaling got merged into the transformers repo.
Nice, I don't think that affects this work unless they implemented it in a flash-attention-enabled module. I'll definitely check it out to make sure my implementation here is correct, though.
Most interesting is the dynamic NTK-aware RoPE being added.
Maybe adding the dynamic version would be an option for tgi too?
That's already in my PR :D
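For anyone following along, dynamic NTK only rescales the base once the current sequence length exceeds the trained context, so short prompts behave exactly like vanilla RoPE. A rough sketch of the idea (illustrative, not the exact transformers or tgi code):

```python
import torch

def dynamic_ntk_inv_freq(
    dim: int,
    seq_len: int,
    max_position_embeddings: int = 2048,
    base: float = 10000.0,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    # Only adjust the base when we are past the trained context length.
    if seq_len > max_position_embeddings:
        base = base * (
            scaling_factor * seq_len / max_position_embeddings - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
```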
Great!
Hi, any instructions on how to use this after the PR is merged? Also, I was wondering why there would be a desync between the transformers lib and this repo, since running LLMs by instantiating them with the transformers lib alone, without an inference server, would be too expensive.
@Narsil I've updated the PR to remove draft status, I think I'm ready for review, just pinging you because you were the earliest responder from HF on this thread.
The associated PR #529 seems to add post-hoc RoPE scaling (for models trained without scaling). Now that linear & dynamic RoPE scaling got merged into transformers (huggingface/transformers#24653), more models will be fine-tuned with scaled RoPE. For example, we (open-assistant) today uploaded a first experiment, llama2-13b-orca-8k-3319, which was fine-tuned with an 8k context using simple linear scaling. It has the following in its config.json and can be used out of the box with transformers 4.31.0:
"rope_scaling": {
"factor": 2.0,
"type": "linear"
},
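For example, with transformers 4.31.0 the scaling is picked up from the config automatically when loading (a sketch, assuming the Hub id is OpenAssistant/llama2-13b-orca-8k-3319):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "OpenAssistant/llama2-13b-orca-8k-3319"  # assumed Hub id
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The rope_scaling entry from config.json ({"type": "linear", "factor": 2.0})
# is applied by the LLaMA implementation in transformers >= 4.31.0,
# no extra arguments needed.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
```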
Will support for these kinds of fine-tuned models also be added to TGI? Will a separate PR be required for this?
Currently those models can simply be loaded with TGI, but since the rope scaling is not applied the output is gibberish. Until rope-scaled models are supported, it might be good to emit an error or warning when rope_scaling is not None in the model's configuration (a sketch of such a check is below).
Or will the rope scaling of the HF transformers LLaMA implementation automatically be used once the TGI transformers dependency in requirements.txt is updated (currently it is still transformers==4.29.2)?
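Something as small as the following check at load time would already avoid silent gibberish (illustrative sketch, assuming a config object that exposes a rope_scaling attribute):

```python
import logging

logger = logging.getLogger(__name__)

def warn_on_unsupported_rope_scaling(config) -> None:
    # Fine-tuned checkpoints such as llama2-13b-orca-8k-3319 carry a
    # rope_scaling entry; if the server does not apply it, generations degrade.
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None:
        logger.warning(
            "Config requests rope_scaling=%s but scaled RoPE is not applied; "
            "outputs will likely be gibberish.",
            rope_scaling,
        )
```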
Two separate things, but we'll align with that, yes.
@andreaskoepf Can you provide an example where the rope scaling fails?
I'm trying a few dummy examples, but I'm not sure whether what I'm doing is correct, as the model output doesn't seem particularly bad either way (I'm guessing I'm not entering large enough prompts).
@andreaskoepf the PR linked should fix it.
@Narsil So we can now use models like the llama2 orca 8k mentioned by @andreaskoepf?
You should be able to!
I was able to get coherent results on prompts of 6k tokens with that model.
I'm still waiting for confirmation from someone who knows what to expect from that particular model (my test references are non-fine-tuned llama v1-7b, which I'm sure works; for the fine-tuned model the output looks OK, but without any reference points to compare against it's hard to tell).
I tried to test using GPTQ weights; on v1.0 everything is fine, but with the latest container:
File "/opt/conda/bin/text-generation-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 78, in serve
server.serve(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 184, in serve
asyncio.run(
File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
return future.result()
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 136, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 185, in get_model
return FlashLlama(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_llama.py", line 67, in __init__
model = FlashLlamaForCausalLM(config, weights)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 456, in __init__
self.model = FlashLlamaModel(config, weights)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 394, in __init__
[
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 395, in <listcomp>
FlashLlamaLayer(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 331, in __init__
self.self_attn = FlashLlamaAttention(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 206, in __init__
self.query_key_value = TensorParallelColumnLinear.load_multi(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/layers.py", line 264, in load_multi
weight = weights.get_multi_weights_col(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 134, in get_multi_weights_col
bits, groupsize = self._get_gptq_params()
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 220, in _get_gptq_params
raise e
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 213, in _get_gptq_params
bits = self.get_tensor("gptq_bits").item()
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 66, in get_tensor
filename, tensor_name = self.get_filename(tensor_name)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 53, in get_filename
raise RuntimeError(f"weight {tensor_name} does not exist")
RuntimeError: weight gptq_bits does not exist
rank=0
Error: ShardCannotStart
What model is that?
flozi00/Llama-2-7b-german-assistant-v2-4bit-autogptq
The only commit that touched that part of the code after the 1.0 release is #738.
It should be OK after this; could you confirm?
Another issue found.
The function is defined here:
https://github.com/huggingface/text-generation-inference/blob/15fc64668f8d3dd407768286e5a0536aeb78c2e1/server/text_generation_server/utils/layers.py#L486C24-L486C39
and used here, where it is not accessible from the other class,
so dynamic scaling is not working and raises a "function not defined" error; linear scaling with a quantized model is working. I can see that it has problems with the stop tokens now, so the model generates whole conversations, but I think that can be solved by some configuration.
Shoot, I just merged my PR which does the same :)
Edit: accepted yours, so you'll end up in the contributors!
Thanks.
Thanks a lot :)
I love that; the core team is so fast on most huggingface projects.
I can confirm, dynamic is working now too
What should the rope scaling factor be for a 32k context, 0.125?
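Going by the convention in the config above (factor 2.0 stretches LLaMA-2's 4k trained context to 8k), the factor is the target length divided by the trained length; a quick check:

```python
trained_ctx = 4096   # LLaMA-2 native context length
target_ctx = 32768   # desired 32k context
factor = target_ctx / trained_ctx
print(factor)  # 8.0 (0.125 would be the reciprocal)
```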