huggingface/text-generation-inference

NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.

flozi00 opened this issue · 33 comments

Feature request

Support longer contexts, up to 8k tokens. The linked discussion and notebook show promising results.

Motivation

Discussion: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

Colab Notebook: https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=d2ceb547

Your contribution

Since it's only a three-line change, it should be pretty easy to implement.
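For context, the three-line change in the notebook enlarges the rotary base rather than touching the positions. Below is a minimal pure-Python sketch. It assumes the base-change formula from the linked Reddit post, base' = base · α^(d/(d−2)), where d is the head dimension and α is the scale. The function names are ours, not TGI's or the notebook's:

```python
from __future__ import annotations


def rope_inv_freq(dim: int, base: float = 10000.0) -> list[float]:
    """Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2)."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def ntk_scaled_base(base: float, dim: int, alpha: float) -> float:
    """NTK-aware base change (assumed formula): base' = base * alpha^(dim / (dim - 2))."""
    return base * alpha ** (dim / (dim - 2))


if __name__ == "__main__":
    dim, alpha = 128, 8.0  # LLaMA-7B head dim; alpha chosen for illustration
    plain = rope_inv_freq(dim)
    scaled = rope_inv_freq(dim, ntk_scaled_base(10000.0, dim, alpha))
    # The highest frequency (i = 0) is unchanged...
    print(plain[0], scaled[0])
    # ...while the lowest frequency is divided by alpha (up to float rounding).
    print(plain[-1] / scaled[-1])
```

The exponent d/(d−2) shrinks the lowest frequency by exactly α, which interpolates the longest wavelengths, and leaves the highest frequency at 1. That is the intuition behind keeping short-range behaviour intact without fine-tuning.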

I will start training a model and provide an example demo.

Oh nice. And if you want to write a PR that would be awesome too.

Please be mindful that TGI doesn't batch the same way transformers does, so the change will most likely be slightly more complex.
Also, many models define the rotary embedding buffer directly in their weights instead of computing it at instantiation, which unfortunately leads to some downstream differences in generation.
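Here is why batching matters. Assume TGI's flash models concatenate the tokens of all requests into one unpadded stream. That is our reading of the comment above, not a quote of TGI's code. Under that assumption, each token carries its own position id, and cos/sin values are looked up per token from a cache sized to the longest sequence in the batch. A minimal sketch, with hypothetical helper names:

```python
from __future__ import annotations

import math


def build_cos_sin_cache(inv_freq: list[float], max_len: int):
    """Precompute cos/sin of position * inv_freq for positions 0..max_len-1."""
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_len)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_len)]
    return cos, sin


def gather_for_flat_batch(position_ids: list[int], cache):
    """Look up the cos/sin rows for every token of a flattened (unpadded) batch."""
    cos, sin = cache
    return [cos[p] for p in position_ids], [sin[p] for p in position_ids]


# Two requests of lengths 3 and 2, concatenated without padding.
position_ids = [0, 1, 2, 0, 1]
inv_freq = [1.0, 0.01]
cache = build_cos_sin_cache(inv_freq, max_len=max(position_ids) + 1)
cos_rows, sin_rows = gather_for_flat_batch(position_ids, cache)
```

With dynamic scaling, the cache itself depends on that batch-wide maximum length, so it has to be rebuilt when a longer request joins. This is one plausible way the change becomes more involved than in transformers.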

[image]

The purple one was trained with the three-line fix from the Colab notebook.

@Narsil Just wanted to chime in and say I'm working on an implementation and a PR for this.

@iantbutler01 Let me know if you need support at any point.
At the moment I am focused on training such models rather than integrating them into TGI.

I've opened a draft PR, #529

I've tested the fixed NTK-aware scaling and it seems to work. I still need to test dynamic scaling and clean up the PR to comply with the contributor guidelines, but I wanted to at least start the discussion.

@iantbutler01 Does this method only support LLaMA models? If so, why did you add support in flash_rw_modeling.py?

@GemsFord The method should work for any model that uses rotary embeddings; it's model-agnostic. My main use case is Falcon-40B, which I've been running locally to test these changes.
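The scaling is model-agnostic because it only changes the frequencies (or the positions) that feed the rotation. The rotation itself is the same wherever RoPE is used. Below is a minimal sketch of applying RoPE to one head vector. It assumes the half-split ("rotate half") layout of the Hugging Face LLaMA code; other models may pair dimensions differently:

```python
from __future__ import annotations

import math


def apply_rotary(x: list[float], position: float, inv_freq: list[float]) -> list[float]:
    """Rotate one query/key head vector with RoPE ("rotate half" layout assumed).

    x is split into halves (x1, x2); each pair (x1[i], x2[i]) is rotated by
    the angle position * inv_freq[i]. Nothing here is specific to a model.
    """
    half = len(x) // 2
    if len(inv_freq) != half:
        raise ValueError("need one frequency per rotated pair")
    x1, x2 = x[:half], x[half:]
    out1, out2 = [], []
    for a, b, f in zip(x1, x2, inv_freq):
        theta = position * f
        c, s = math.cos(theta), math.sin(theta)
        out1.append(a * c - b * s)
        out2.append(b * c + a * s)
    return out1 + out2
```

Any of the scaled `inv_freq` variants discussed in this thread can be passed in unchanged. The rotation preserves vector norms and makes the query·key product depend only on the position difference.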

@iantbutler01 Thanks for adding support for Falcon; I use it too, which is why I asked. I am waiting for your PR to get merged.

Yup, I plan to clean this up and make it ready for review this weekend. I was on vacation and am now catching up on work, but I will have time this weekend. As far as I can tell the implementation works, so it's just a matter of cleaning it up and then addressing review feedback.

Nice, I don't think that affects this work unless they implemented it in a flash-attention-enabled module. I'll definitely check it out to make sure my implementation here is correct, though.

The most interesting part is the dynamic NTK-aware RoPE being added.
Maybe TGI could offer the dynamic version as an option too?

That's already in my PR :D
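Dynamic NTK scaling, as merged in transformers (`LlamaDynamicNTKScalingRotaryEmbedding`, 4.31), recomputes the base from the current sequence length. It does this only once the sequence exceeds the trained length. Below is a minimal sketch of that formula. We have not checked that the PR uses exactly the same expression:

```python
def dynamic_ntk_base(
    base: float,
    dim: int,
    seq_len: int,
    max_position_embeddings: int,
    factor: float,
) -> float:
    """Dynamic NTK: grow the RoPE base only once seq_len exceeds the trained length."""
    if seq_len <= max_position_embeddings:
        return base  # within the trained context: plain RoPE
    # Effective scale grows with the sequence length; it equals 1 at the boundary.
    alpha = factor * seq_len / max_position_embeddings - (factor - 1)
    return base * alpha ** (dim / (dim - 2))
```

When seq_len equals the trained length, the scale term is 1. The base therefore changes continuously rather than jumping when the context grows past that length.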

Great 😀👍

Hi, are there any instructions on how to use this once the PR is merged? Also, I was wondering why there would be a desync between the transformers library and this repo. Running LLMs by instantiating them with the transformers library alone, without an inference server, would be too expensive.

@Narsil I've updated the PR to remove draft status, I think I'm ready for review, just pinging you because you were the earliest responder from HF on this thread.

The associated PR #529 seems to add post-hoc RoPE scaling (for models trained without scaling). Now that linear and dynamic RoPE scaling have been merged into transformers (huggingface/transformers#24653), more models will be fine-tuned with scaled RoPE. For example, today we (open-assistant) uploaded a first experiment, llama2-13b-orca-8k-3319, which was fine-tuned with 8k context using simple linear scaling. Its config.json contains the following, and it can be used out of the box with transformers 4.31.0:

  "rope_scaling": {
    "factor": 2.0,
    "type": "linear"
  },
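In transformers, `"type": "linear"` divides the position indices by `factor` before computing the rotary angles. With factor 2.0, position 8000 therefore receives the angles that position 4000 had during pre-training. Below is a minimal sketch of that behaviour. The helper name is hypothetical:

```python
from __future__ import annotations

import json

CONFIG = json.loads('{"rope_scaling": {"factor": 2.0, "type": "linear"}}')


def linear_scaled_angles(position: int, inv_freq: list[float], rope_scaling: dict | None) -> list[float]:
    """Rotary angles with optional linear scaling (positions divided by factor).

    Other types (e.g. "dynamic") change the base instead and are not handled here.
    """
    factor = 1.0
    if rope_scaling is not None and rope_scaling.get("type") == "linear":
        factor = float(rope_scaling["factor"])
    return [position / factor * f for f in inv_freq]
```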

Will support for these kinds of fine-tuned models also be added to TGI? Will a separate PR be required for this?

Currently, those models can be loaded with TGI, but since RoPE scaling is not applied, the output is gibberish. Until RoPE-scaled models are supported, it might be good to emit an error or warning when rope_scaling is not None in the model's configuration.

Or will the RoPE scaling of the HF transformers LLaMA implementation be used automatically once the TGI transformers dependency in requirements.txt is updated (currently it is still transformers==4.29.2)?
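The error-or-warning idea could look like the minimal sketch below. `check_rope_scaling` and the `supported` set are hypothetical names, not TGI code. The empty set mimics a server that implements no RoPE scaling yet:

```python
from __future__ import annotations

import warnings


def check_rope_scaling(config: dict, supported: set[str]) -> bool:
    """Return True if the config's rope_scaling (if any) can be honored.

    Warns instead of silently ignoring a scaling the server does not implement,
    which is what leads to gibberish output.
    """
    rope_scaling = config.get("rope_scaling")
    if rope_scaling is None:
        return True
    kind = rope_scaling.get("type")
    if kind not in supported:
        warnings.warn(f"rope_scaling type {kind!r} is not supported; outputs may be degraded")
        return False
    return True
```

Warning rather than raising keeps un-scaled models loading as before; raising would be the stricter choice.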

Narsil commented

These are two separate things, but yes, we'll align with that.

Narsil commented

@andreaskoepf Can you provide an example where the rope scaling fails ?

I'm trying a few dummy examples, but I'm not sure whether what I'm doing is correct, because the model output doesn't seem particularly bad either way (I'm guessing my prompts aren't long enough).

Narsil commented

@andreaskoepf the PR linked should fix it.

@Narsil So we can now use models like the llama2 orca 8k mentioned by @andreaskoepf?

Narsil commented

You should be able to!

I was able to get coherent results on prompts of 6k on that model.
I'm still waiting on confirmation from someone who knows what to expect from that particular model. My test references are on non-fine-tuned LLaMA v1-7B, which I'm sure works. For the fine-tuned model the output looks OK, but without reference points to compare against, it's hard to tell.

I tried testing with GPTQ weights. On v1.0 everything is fine, but with the latest container I get:

  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 78, in serve
    server.serve(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 184, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
    return future.result()
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 136, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 185, in get_model
    return FlashLlama(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/flash_llama.py", line 67, in __init__
    model = FlashLlamaForCausalLM(config, weights)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 456, in __init__
    self.model = FlashLlamaModel(config, weights)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 394, in __init__
    [
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 395, in <listcomp>
    FlashLlamaLayer(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 331, in __init__
    self.self_attn = FlashLlamaAttention(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/flash_llama_modeling.py", line 206, in __init__
    self.query_key_value = TensorParallelColumnLinear.load_multi(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/layers.py", line 264, in load_multi
    weight = weights.get_multi_weights_col(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 134, in get_multi_weights_col
    bits, groupsize = self._get_gptq_params()
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 220, in _get_gptq_params
    raise e
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 213, in _get_gptq_params
    bits = self.get_tensor("gptq_bits").item()
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 66, in get_tensor
    filename, tensor_name = self.get_filename(tensor_name)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/utils/weights.py", line 53, in get_filename
    raise RuntimeError(f"weight {tensor_name} does not exist")
RuntimeError: weight gptq_bits does not exist
 rank=0
Error: ShardCannotStart

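Judging by its name, the model was quantized with AutoGPTQ. As far as we know, AutoGPTQ stores `bits` and `group_size` in a `quantize_config.json` file rather than in `gptq_bits`/`gptq_groupsize` tensors. A hypothetical fallback along those lines is sketched here. It is not the fix that was merged, and the function names are ours:

```python
from __future__ import annotations

import json
import os
from typing import Callable


def load_quantize_config(model_dir: str) -> dict | None:
    """Read AutoGPTQ's quantize_config.json if the model directory has one."""
    path = os.path.join(model_dir, "quantize_config.json")
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)


def resolve_gptq_params(get_tensor: Callable[[str], int], quantize_config: dict | None) -> tuple[int, int]:
    """Prefer gptq_bits/gptq_groupsize tensors; fall back to quantize_config.json values."""
    try:
        return int(get_tensor("gptq_bits")), int(get_tensor("gptq_groupsize"))
    except RuntimeError:
        if quantize_config is None:
            raise  # nothing to fall back on: keep the original "does not exist" error
        return int(quantize_config["bits"]), int(quantize_config["group_size"])
```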
Narsil commented

What model is that ?

flozi00/Llama-2-7b-german-assistant-v2-4bit-autogptq

After the 1.0 release, the only commit that touched that part of the code is #738.

Narsil commented

Should be OK after this. Could you confirm?

Found another issue:

def _create_inv_freq(dim, base, device):

is defined here

https://github.com/huggingface/text-generation-inference/blob/15fc64668f8d3dd407768286e5a0536aeb78c2e1/server/text_generation_server/utils/layers.py#L486C24-L486C39
and used here, but it is not accessible from the other class.

So dynamic scaling does not work and raises a "function not defined" error. Linear scaling with a quantized model works. I can see that it now has trouble with stop tokens, so the model generates whole conversations, but I think that can be solved with some configuration.
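The following is a minimal illustration of the scoping problem. It assumes the helper ended up inside one class's scope; we have not checked the exact code at that commit. Names defined in a class body are not visible by bare name from methods. Moving the helper to module level makes it callable from both the static and the dynamic class. All class names here are hypothetical:

```python
from __future__ import annotations


def _create_inv_freq(dim: int, base: float) -> list[float]:
    # Module level: every class in this module can call it by bare name.
    return [1.0 / base ** (2 * i / dim) for i in range(dim // 2)]


class StaticRotary:
    def __init__(self, dim: int, base: float = 10000.0):
        self.inv_freq = _create_inv_freq(dim, base)


class DynamicRotary:
    def __init__(self, dim: int, base: float = 10000.0):
        self.dim = dim
        self.inv_freq = _create_inv_freq(dim, base)

    def rebase(self, new_base: float) -> None:
        # Works because the helper lives at module scope.
        self.inv_freq = _create_inv_freq(self.dim, new_base)


class Broken:
    def _helper():  # a plain function living in the class body
        return 1.0

    def use(self):
        return _helper()  # class-body names are not in scope here: NameError
```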

Fixing that typo here:

#745

Narsil commented

Shoot, I just merged my PR, which does the same thing :)

Edit: I accepted yours, so you'll end up among the contributors!
Thanks.

Thanks a lot :)
I love how fast the core team is on most Hugging Face projects 🚀

I can confirm that dynamic scaling works now too.

What should the RoPE scaling factor be for a 32k context? 0.125?
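Here is a quick arithmetic check. It assumes linear scaling with the transformers convention from the config earlier in the thread (factor ≥ 1, positions divided by it) and takes 32k to mean 32768 tokens. Under those assumptions, the factor is the target length divided by the trained length. 0.125 is its reciprocal, i.e. the multiplier applied to positions. The result depends on the base model's trained context, which is 4096 for Llama 2 and 2048 for LLaMA 1:

```python
def linear_rope_factor(target_ctx: int, trained_ctx: int) -> float:
    """transformers-style linear factor: target / trained (>= 1 to extend context)."""
    return target_ctx / trained_ctx


# Llama 2 was pre-trained with 4096 tokens of context, LLaMA 1 with 2048.
for name, trained in [("Llama 2", 4096), ("LLaMA 1", 2048)]:
    factor = linear_rope_factor(32768, trained)
    print(name, "factor:", factor, "position multiplier:", 1 / factor)
```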