NolanoOrg/cformers

Benchmark the effect of merging the query and key matrices in transformers

Ayushk4 opened this issue.

For certain architectures (such as GPT-J and LLaMA), it may be possible to replace the query matrix $Q$ and the key matrix $K$ with a single merged matrix, saving one of the seven or eight matrix multiplications in a transformer block. I don't see an obvious way to do this for GPT-NeoX and OPT.
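Whether dropping one matrix multiplication also reduces arithmetic depends on the matrix shapes, so a back-of-envelope count of multiply-accumulates (MACs) is a useful companion to the benchmark. This is a sketch under stated assumptions: one attention head whose $Q$ and $K$ are `d_head × d_model`, no biases, no rotary embeddings, and `QKMerge` precomputed once. The function names are hypothetical and not part of cformers.

```python
def separate_qk_macs(n: int, d_model: int, d_head: int) -> int:
    """MACs for one head's raw scores using separate Q and K.

    Assumes Q and K are d_head x d_model and X holds n tokens of width d_model.
    """
    project = 2 * n * d_model * d_head  # q_i = Q x_i and k_j = K x_j for all tokens
    scores = n * n * d_head             # q_i^T k_j for every (i, j) pair
    return project + scores


def merged_qk_macs(n: int, d_model: int) -> int:
    """MACs for one head's raw scores using QKMerge = Q^T K.

    QKMerge is d_model x d_model and is assumed precomputed, so it is not counted.
    """
    project = n * d_model * d_model     # QKMerge x_j for all tokens
    scores = n * n * d_model            # x_i^T (QKMerge x_j) for every (i, j) pair
    return project + scores


if __name__ == "__main__":
    # Single head with d_head == d_model: the merge removes one projection.
    print(separate_qk_macs(512, 4096, 4096), merged_qk_macs(512, 4096))
    # Narrower multi-head shapes: the per-head merged matrix is much wider than Q or K.
    print(separate_qk_macs(512, 4096, 128), merged_qk_macs(512, 4096))
```

Under these assumptions the merge saves arithmetic when `d_head == d_model`, but with typical multi-head widths (for example `d_head = 128`, `d_model = 4096`) each head's merged matrix is `d_model × d_model` and costs more MACs, which is one reason to measure rather than assume.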

Run a standard benchmark on the model before and after merging the query and key matrices, and compare the results.
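A minimal before/after harness could look like the sketch below. It assumes each model variant is a Python callable that maps one benchmark input to a flat list of floats (for example, logits); the `benchmark` function and the `Variant` alias are hypothetical, not an existing cformers API. It reports the fastest wall-clock time of each variant over a few repeats and the largest absolute disagreement between their outputs.

```python
import time
from typing import Callable, List, Sequence

# Hypothetical interface: a variant maps one input to a flat list of floats.
Variant = Callable[[Sequence[float]], List[float]]


def benchmark(before: Variant, after: Variant,
              inputs: Sequence[Sequence[float]], repeats: int = 3) -> dict:
    """Time two model variants on the same inputs and compare their outputs."""
    results = {}
    outputs = {}
    for name, fn in (("before", before), ("after", after)):
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            outs = [fn(x) for x in inputs]
            best = min(best, time.perf_counter() - start)  # keep the fastest run
        results[f"{name}_seconds"] = best
        outputs[name] = outs
    # Largest absolute disagreement across all inputs and output positions.
    results["max_abs_diff"] = max(
        (abs(a - b)
         for out_a, out_b in zip(outputs["before"], outputs["after"])
         for a, b in zip(out_a, out_b)),
        default=0.0,
    )
    return results
```

A call would look like `benchmark(run_unmerged, run_merged, eval_inputs)`, where all three arguments are hypothetical stand-ins for the two model variants and an evaluation set. A small non-zero `max_abs_diff` is expected even when the merge is correct, because the merged path sums products in a different order.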

---------- Details ----------
`.T()` (written $^\top$ in LaTeX) denotes the transpose.

Consider the input representation $X = \{x_1, \dots, x_i, \dots, x_j, \dots, x_n\}$, where each $x_i$ is a column vector. The query and key vectors are

$$q_i = Q x_i, \qquad k_j = K x_j$$

$$
\begin{aligned}
\text{score}_{i,j} &= q_i^\top k_j \\
&= (Q x_i)^\top (K x_j) \\
&= (x_i^\top Q^\top)(K x_j) \\
&= x_i^\top Q^\top K x_j \quad \text{(a matrix chain product)}
\end{aligned}
$$
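The chain of equalities can be checked numerically for a single pair $(i, j)$. This sketch uses small random matrices with arbitrary sizes (`d_model = 6`, `d_head = 3`) and plain-Python helpers; none of it reflects the shapes or code of a real model.

```python
import random


def matvec(m, v):
    """Return the matrix-vector product m v, with m given as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(u, v):
    """Return the dot product u^T v."""
    return sum(a * b for a, b in zip(u, v))


def transpose(m):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


random.seed(0)
d_model, d_head = 6, 3
Q = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
K = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
x_i = [random.gauss(0, 1) for _ in range(d_model)]
x_j = [random.gauss(0, 1) for _ in range(d_model)]

# Left-hand side: project first, then take the dot product q_i^T k_j.
lhs = dot(matvec(Q, x_i), matvec(K, x_j))
# Right-hand side: evaluate the chain x_i^T Q^T K x_j from the right.
rhs = dot(x_i, matvec(transpose(Q), matvec(K, x_j)))
print(lhs, rhs, abs(lhs - rhs) < 1e-9)
```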

Let $\text{QKMerge} = Q^\top K$. Then

$$\text{score}_{i,j} = x_i^\top \, \text{QKMerge} \, x_j$$
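Applied to every pair at once, the merge replaces the two projections with a single precomputed matrix. The sketch below, again with arbitrary toy sizes and plain-Python helpers (no particular model's layout is assumed), stores one token per row of `X` and checks that the full score matrix is unchanged.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


random.seed(1)
n, d_model, d_head = 4, 6, 3
X = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(n)]  # row i is x_i^T
Q = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
K = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]

# Before merging: q_i and k_j for all tokens, then every pairwise dot product.
queries = matmul(X, transpose(Q))  # n x d_head, row i is q_i^T
keys = matmul(X, transpose(K))     # n x d_head, row j is k_j^T
scores_separate = matmul(queries, transpose(keys))

# After merging: precompute QKMerge = Q^T K once (d_model x d_model).
qk_merge = matmul(transpose(Q), K)
scores_merged = matmul(matmul(X, qk_merge), transpose(X))

max_diff = max(abs(a - b)
               for row_a, row_b in zip(scores_separate, scores_merged)
               for a, b in zip(row_a, row_b))
print(max_diff < 1e-9)
```

Note that `qk_merge` is `d_model × d_model`, which is wider than $Q$ or $K$ whenever `d_head < d_model`.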

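The formula above must be modified for rotary position embeddings (RoPE). RoPE rotates $q_i$ and $k_j$ by position-dependent rotation matrices $R_i$ and $R_j$, so the score becomes $x_i^\top Q^\top R_i^\top R_j K x_j = x_i^\top Q^\top R_{j-i} K x_j$. The merged matrix then depends on the relative offset $j - i$ instead of being a single fixed matrix.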
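A toy check of the rotary case: rotating the query and key by their absolute positions gives the same score as a merged matrix built with the rotation for the offset $j - i$, while the plain `QKMerge` generally does not match. This sketch assumes the interleaved-pair rotation with base 10000 from the RoFormer formulation and rotates the whole head; real models differ in pairing convention and in how many head dimensions they rotate, and `rope_matrix` is a hypothetical helper.

```python
import math
import random


def rope_matrix(pos, d_head, base=10000.0):
    """Block-diagonal rotation R_pos over interleaved dimension pairs (assumed layout)."""
    r = [[0.0] * d_head for _ in range(d_head)]
    for k in range(d_head // 2):
        angle = pos * base ** (-2 * k / d_head)
        c, s = math.cos(angle), math.sin(angle)
        r[2 * k][2 * k], r[2 * k][2 * k + 1] = c, -s
        r[2 * k + 1][2 * k], r[2 * k + 1][2 * k + 1] = s, c
    return r


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def matvec(m, v):
    """Return the matrix-vector product m v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def bilinear(u, m, v):
    """Return u^T m v."""
    return sum(u[a] * m[a][b] * v[b] for a in range(len(u)) for b in range(len(v)))


random.seed(2)
d_model, d_head = 6, 4
Q = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
K = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_head)]
x_i = [random.gauss(0, 1) for _ in range(d_model)]
x_j = [random.gauss(0, 1) for _ in range(d_model)]
i, j = 3, 7

# Rotary score: rotate q_i by R_i and k_j by R_j, then take the dot product.
q_rot = matvec(rope_matrix(i, d_head), matvec(Q, x_i))
k_rot = matvec(rope_matrix(j, d_head), matvec(K, x_j))
score_rope = sum(a * b for a, b in zip(q_rot, k_rot))

# Merged form with the offset-dependent middle matrix Q^T R_{j-i} K.
merge_rel = matmul(matmul(transpose(Q), rope_matrix(j - i, d_head)), K)
score_merged = bilinear(x_i, merge_rel, x_j)

# The position-free QKMerge = Q^T K ignores the rotation, so it generally differs.
score_plain = bilinear(x_i, matmul(transpose(Q), K), x_j)
print(score_rope, score_merged, score_plain)
```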