about Medusa mask details
dhcode-cpp opened this issue · 0 comments
dhcode-cpp commented
Hi, Medusa is a very efficient LLM inference algorithm!
I have some questions about the Medusa implementation.
Q1: In the first step, the multiple decoding heads predict the next N tokens, and step 2 is the verification. Would the input attention mask look like the one below? Is that right?
# original mask
1, 0, 0, 0,
1, 1, 0, 0,
1, 1, 1, 0,
1, 1, 1, 1,
# if Medusa's first step decodes the next 2 tokens,
# the verification-step attention mask would look like:
1, 0, 0, 0, | 0, 0
1, 1, 0, 0, | 0, 0
1, 1, 1, 0, | 0, 0
1, 1, 1, 1, | 0, 0
------------------
1, 1, 1, 1, | 1, 0
1, 1, 1, 1, | 1, 1
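To make Q1 concrete, here is a minimal PyTorch sketch of my assumption: after appending the speculated tokens, the mask is simply a larger causal mask (new rows attend to the whole prefix and causally among themselves). `expanded_causal_mask` is a hypothetical helper name, not from the Medusa code.

```python
import torch

def expanded_causal_mask(orig_len: int, num_new: int) -> torch.Tensor:
    # Causal mask after appending `num_new` speculated tokens to an
    # `orig_len`-token prefix: the appended rows see the full prefix
    # plus the earlier appended tokens.
    total = orig_len + num_new
    return torch.tril(torch.ones(total, total, dtype=torch.long))

print(expanded_causal_mask(4, 2))  # reproduces the 6x6 matrix above
```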
Q2: If Q1 is correct, is the tree attention mask also expanded like the above?
# denote the above matrix as
A | 0
-----
1 | M
# M is the Medusa mask
If the above is true, is the M block replaced by the tree attention mask?
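If my reading of Q2 is right, the combined mask would be assembled like the sketch below, with the Medusa tree mask placed in the lower-right M block. `verification_mask` and the toy tree mask are my own hypothetical examples, not the repo's implementation.

```python
import torch

def verification_mask(orig_len: int, medusa_mask: torch.Tensor) -> torch.Tensor:
    # Block layout [[A, 0], [1, M]]: A is the causal mask over the prefix,
    # the lower-left block is all ones (every candidate attends to the full
    # prefix), and M is the Medusa tree mask among the candidate tokens.
    k = medusa_mask.size(0)
    total = orig_len + k
    mask = torch.zeros(total, total, dtype=torch.long)
    mask[:orig_len, :orig_len] = torch.tril(
        torch.ones(orig_len, orig_len, dtype=torch.long))  # A
    mask[orig_len:, :orig_len] = 1                # candidates see the whole prefix
    mask[orig_len:, orig_len:] = medusa_mask      # M: tree structure among candidates
    return mask

# toy tree mask: candidates 0 and 1 are siblings from head 1,
# candidates 2 and 3 are children of 0 and 1 from head 2
M = torch.tensor([[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [1, 0, 1, 0],
                  [0, 1, 0, 1]])
print(verification_mask(4, M))
```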
Q3: In Medusa-2, are the LLM backbone parameters also trained (not only the heads)?
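My understanding of Medusa-2 (please correct me) is that the backbone LLM is unfrozen and trained jointly with the heads, unlike Medusa-1 which trains only the heads. A minimal sketch with hypothetical module names (`TinyBackbone`, `MedusaHeads` are stand-ins, not classes from the Medusa repo):

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    def __init__(self, d=16, vocab=32):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.proj = nn.Linear(d, d)

class MedusaHeads(nn.Module):
    def __init__(self, d=16, vocab=32, num_heads=4):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(d, vocab) for _ in range(num_heads))

backbone, heads = TinyBackbone(), MedusaHeads()

# Medusa-1 style: freeze the backbone and train only the heads
for p in backbone.parameters():
    p.requires_grad_(False)

# Medusa-2 style (my assumption): unfreeze the backbone and train it jointly
# with the heads, typically with a smaller learning rate for the backbone
for p in backbone.parameters():
    p.requires_grad_(True)
optimizer = torch.optim.AdamW([
    {"params": heads.parameters(), "lr": 1e-3},
    {"params": backbone.parameters(), "lr": 1e-5},
])
```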
Thanks for your reply!