Benchmark the effect of merging query and key matrices in transformers
Ayushk4 opened this issue.
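For certain architectures (like GPT-J and LLaMA), it may be possible to replace the separate Query and Key matrices with a single merged Query-Key matrix.
Take a standard benchmark and run the model on it before and after merging the Query and Key matrices, then compare the results.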
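As a rough, self-contained sketch of the before/after comparison, the NumPy code below builds a toy single-head attention layer (no rotary embeddings, no causal mask) and runs it once with separate Q and K and once with a merged matrix QKMerge = Q.T @ K, checking that the outputs agree and timing both. The sizes, the helper names (`attn_unmerged`, `attn_merged`, `bench`) and the random weights are hypothetical stand-ins, not GPT-J or LLaMA weights; a real benchmark would load an actual checkpoint and score it on the chosen task.

```python
import time

import numpy as np


def softmax(s):
    # Numerically stable row-wise softmax.
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def attn_unmerged(X, Q, K, V):
    # X: (n_tokens, d_model); Q, K, V: (d_head, d_model), applied as Q @ x per token.
    q = X @ Q.T  # (n, d_head)
    k = X @ K.T  # (n, d_head)
    scores = q @ k.T / np.sqrt(Q.shape[0])
    return softmax(scores) @ (X @ V.T)


def attn_merged(X, QKMerge, V, d_head):
    # QKMerge = Q.T @ K, shape (d_model, d_model); scores[i, j] = x_i.T @ QKMerge @ x_j.
    scores = (X @ QKMerge) @ X.T / np.sqrt(d_head)
    return softmax(scores) @ (X @ V.T)


def bench(fn, *args, repeats=50):
    # Average wall-clock time per call; returns the last output as well.
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    return out, (time.perf_counter() - start) / repeats


rng = np.random.default_rng(0)
n, d_model, d_head = 128, 256, 64  # hypothetical sizes
X = rng.standard_normal((n, d_model))
Q, K, V = (rng.standard_normal((d_head, d_model)) / np.sqrt(d_model) for _ in range(3))
QKMerge = Q.T @ K  # merged once, ahead of time

out_a, t_a = bench(attn_unmerged, X, Q, K, V)
out_b, t_b = bench(attn_merged, X, QKMerge, V, d_head)
print(f"max abs diff: {np.abs(out_a - out_b).max():.2e}")
print(f"unmerged: {t_a * 1e3:.3f} ms, merged: {t_b * 1e3:.3f} ms")
```

Timings from a toy NumPy loop are only indicative: matrix sizes, the BLAS backend and the hardware all affect them, so the numbers should not be read as a prediction for a full model.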
---------- Details ----------
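.T() denotes transpose, MatMul(A, B) denotes matrix multiplication, and MatrixChainMul(A, B, ...) denotes the product of its arguments in order.
Consider the input representations xi and xj of tokens i and j, and the Query and Key matrices Q and K: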
qi = MatMul(Q, xi)
kj = MatMul(K, xj)
score_i,j = MatMul(qi.T(), kj)
= MatMul( MatMul(Q, xi).T(), MatMul(K, xj) )
= MatMul( MatMul(xi.T(), Q.T()), MatMul(K, xj) )
= MatrixChainMul(xi.T(), Q.T(), K, xj)
let QKMerge = MatMul(Q.T(), K)
score_i,j = MatrixChainMul(xi.T(), QKMerge, xj)
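A quick numerical check of the derivation above, as a minimal NumPy sketch. It assumes a single head whose Q and K are d_head x d_model matrices acting on column vectors, matching MatMul(Q, xi); the sizes are hypothetical.

```python
import numpy as np

# Hypothetical sizes for illustration: model width d_model, head width d_head.
d_model, d_head = 16, 4
rng = np.random.default_rng(0)

Q = rng.standard_normal((d_head, d_model))  # query matrix (d_head x d_model)
K = rng.standard_normal((d_head, d_model))  # key matrix (d_head x d_model)
x_i = rng.standard_normal(d_model)          # input representation of token i
x_j = rng.standard_normal(d_model)          # input representation of token j

# Unmerged: project both tokens, then take the dot product qi.T() kj.
q_i = Q @ x_i
k_j = K @ x_j
score_unmerged = q_i @ k_j

# Merged: precompute QKMerge = Q.T() K once, then xi.T() QKMerge xj.
QKMerge = Q.T @ K  # shape (d_model, d_model)
score_merged = x_i @ QKMerge @ x_j

print(QKMerge.shape, np.isclose(score_unmerged, score_merged))
```

One thing the benchmark may want to track: QKMerge is d_model x d_model per head (256 entries here) but has rank at most d_head, while the unmerged Q and K together hold 2 * d_head * d_model entries (128 here). So whenever d_model > 2 * d_head, the merged form stores more parameters per head than it replaces.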
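The above formula will have to be modified for rotary embeddings. Rotary embeddings rotate qi and kj by position-dependent rotations R_i and R_j, so score_i,j = MatrixChainMul(xi.T(), Q.T(), R_i.T(), R_j, K, xj). Since MatMul(R_i.T(), R_j) = R_(j-i), a rotation that depends on the relative position sits between Q.T() and K, and they can no longer be merged into a single position-independent QKMerge.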
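To illustrate the rotary case, here is a minimal NumPy sketch assuming the block-diagonal rotary formulation, where consecutive dimension pairs (2k, 2k+1) of each query and key are rotated by pos * base^(-2k/d). The helper `rope_matrix`, the pairing layout and the sizes are illustrative assumptions; real implementations differ in which dimensions they pair and in how many head dimensions they rotate, but the identity R_i.T @ R_j == R_(j-i) holds for any such block rotation. The check shows that a merged matrix still works, but only one built per relative offset j - i.

```python
import numpy as np


def rope_matrix(pos, d, base=10000.0):
    # Hypothetical helper: block-diagonal rotation acting on pairs (2k, 2k+1)
    # with angle pos * base**(-2k/d).
    R = np.zeros((d, d))
    for k in range(d // 2):
        theta = pos * base ** (-2 * k / d)
        c, s = np.cos(theta), np.sin(theta)
        R[2 * k:2 * k + 2, 2 * k:2 * k + 2] = [[c, -s], [s, c]]
    return R


d_model, d_head = 16, 4  # hypothetical sizes
rng = np.random.default_rng(0)
Q = rng.standard_normal((d_head, d_model))
K = rng.standard_normal((d_head, d_model))
x_i, x_j = rng.standard_normal(d_model), rng.standard_normal(d_model)
i, j = 3, 7  # token positions

# With rotary embeddings, q and k are each rotated by their own position.
q_i = rope_matrix(i, d_head) @ Q @ x_i
k_j = rope_matrix(j, d_head) @ K @ x_j
score_rope = q_i @ k_j

# R_i.T @ R_j == R_(j - i), so the merged matrix depends on the offset j - i.
QKMerge_rel = Q.T @ rope_matrix(j - i, d_head) @ K
score_rel = x_i @ QKMerge_rel @ x_j

# The position-independent QKMerge from the derivation no longer applies.
score_naive = x_i @ (Q.T @ K) @ x_j
print(score_rope, score_rel, score_naive)
```

Merging under rotary embeddings would therefore mean either one QKMerge per relative offset or keeping the rotation between Q.T() and K at inference time, which is part of what the benchmark would need to account for.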