Possible Bug in QuaRot Implementation with remove_mean_from_embed()
A-suozhang opened this issue · 4 comments
First of all, thanks to the authors for the awesome work. It really benefits research on quantization.
When reviewing the code for LLMC and the original QuaRot repository, I noticed that the LLMC implementation subtracts the mean of the embedding-layer weights in the line mentioned above. However, I did not find any corresponding operation that compensates for this change. Could this discrepancy cause the results to deviate from inference with the original, non-QuaRot model?
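For context, my understanding is that the operation in question is roughly equivalent to the following; this is only an illustrative sketch with made-up shapes, not the exact LLMC code:

```python
import torch
import torch.nn as nn

def remove_mean_from_embed(embedding: nn.Embedding) -> None:
    # Subtract each token embedding's mean (taken over the hidden dimension)
    # from the embedding weights, in place.
    W = embedding.weight.data  # shape: (vocab_size, hidden_size)
    W -= W.mean(dim=-1, keepdim=True)

embed = nn.Embedding(1000, 64)
remove_mean_from_embed(embed)
# Every embedding vector is now zero-mean along the hidden dimension.
print(embed.weight.data.mean(dim=-1).abs().max())  # ~0
```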
In most newer large language models (LLMs), such as the LLaMA series, the embedding layer is followed by an RMSNorm layer, which, unlike LayerNorm, does not perform a zero-mean operation. I assume the authors might be splitting LayerNorm into a "subtract mean" step followed by RMSNorm, particularly for models that use LayerNorm after the embedding layer (e.g., OPT).
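To make that decomposition concrete, here is a small self-contained check of the SliceGPT-style identity (my own sketch, not code from either repository): a LayerNorm without bias equals an RMSNorm applied to the mean-subtracted input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 64
x = torch.randn(8, d)

# LayerNorm without affine parameters: (x - mean(x)) / sqrt(var(x) + eps)
ln = nn.LayerNorm(d, elementwise_affine=False, eps=1e-6)

# RMSNorm without affine parameters: y / sqrt(mean(y**2) + eps)
def rms_norm(y, eps=1e-6):
    return y / torch.sqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)

# Applying RMSNorm to the mean-subtracted input reproduces LayerNorm exactly.
x_centered = x - x.mean(dim=-1, keepdim=True)
print(torch.allclose(ln(x), rms_norm(x_centered), atol=1e-5))  # True
```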
I am curious about the rationale behind introducing the remove_mean_from_embed function, as I do not recall seeing it elsewhere. This function is called specifically for the QuaRot method. Clarification from the authors would be greatly appreciated. Thanks a lot!
Thanks for appreciating our work.
If the mean of the embedding is removed, a LayerNorm can be rewritten as an RMSNorm. The prerequisite remove_mean_from_embed derives from SliceGPT, and QuaRot inherits it.
Thank you for your response! However, I'm still unclear on when the subtraction of the mean is compensated to ensure the computation matches the original implementation.
From my understanding, the "SliceGPT" paper suggests that subtracting the mean is used to convert LayerNorm to RMSNorm. However, in the LLaMA model series, RMSNorm is applied after the embedding layers.
I agree with you that remove_mean_from_embed should be called before the LayerNorms. However, for models that already have RMSNorm following the embedding layer, I'm not sure whether this operation is necessary. Also, I did not find the corresponding implementation for Llama-2 quantization in the original QuaRot codebase.
For LayerNorm, no compensation is needed. You can find the relevant code here:
https://github.com/spcl/QuaRot/blob/main/fake_quant/rotation_utils.py#L45
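For what it's worth, this can be verified with a quick standalone check (a sketch, not the actual QuaRot code): LayerNorm re-centers its input anyway, so removing the embedding mean beforehand changes nothing, while RMSNorm is not invariant to that shift in general.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 64
x = torch.randn(8, d)                    # stand-in for embedding outputs
x_zm = x - x.mean(dim=-1, keepdim=True)  # same embeddings with their mean removed

# LayerNorm subtracts the input mean itself, so the pre-removed mean is irrelevant:
ln = nn.LayerNorm(d, eps=1e-6)
print(torch.allclose(ln(x), ln(x_zm), atol=1e-5))  # True -> no compensation needed

# RMSNorm does not re-center, so the two inputs normalize differently:
def rms_norm(y, eps=1e-6):
    return y / torch.sqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)

print(torch.allclose(rms_norm(x), rms_norm(x_zm), atol=1e-5))  # generally False
```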
However, for models such as Llama, I am also curious why this function needs to be used. I tried removing it and got the same performance for Llama. If you have the answer, please let me know.
Llama does not need remove_mean_from_embed, but performing this operation has little effect on the result, because Llama's embedding mean is small. See spcl/QuaRot#7 (comment).
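A quick way to check this claim for a specific checkpoint (a sketch; the Hugging Face model name below is just an example and requires access to the weights) is to compare the per-row embedding mean against the typical weight magnitude:

```python
import torch
from transformers import AutoModelForCausalLM

# Example checkpoint; any Llama model could be substituted here.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float32
)
W = model.get_input_embeddings().weight.data  # (vocab_size, hidden_size)

per_token_mean = W.mean(dim=-1)               # mean over the hidden dimension
print("average |row mean|:", per_token_mean.abs().mean().item())
print("average |weight|  :", W.abs().mean().item())
# If the first number is orders of magnitude smaller than the second,
# removing the mean barely changes the embeddings, which matches the
# observation that the final results are essentially unaffected.
```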