FasterDecoding/Medusa

about Medusa mask details

dhcode-cpp opened this issue · 0 comments

Hi, Medusa is a very efficient LLM inference algorithm!

I have some questions about the Medusa implementation.

Q1: In the first step, the multiple decoding heads predict the next N tokens, and the second step verifies them. Would the input attention mask look like the one below?

# original mask
1, 0, 0, 0,
1, 1, 0, 0,
1, 1, 1, 0,
1, 1, 1, 1,

# if Medusa's first step decodes the next 2 tokens,
# the attention mask at the verification step looks like:
1, 0, 0, 0, | 0, 0
1, 1, 0, 0, | 0, 0
1, 1, 1, 0, | 0, 0
1, 1, 1, 1, | 0, 0
------------------
1, 1, 1, 1, | 1, 0
1, 1, 1, 1, | 1, 1
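If I read the matrix above correctly, for a single chain of two speculated tokens the extended mask is simply a larger lower-triangular (causal) mask over prefix + candidates. A minimal NumPy sketch of that reading (the function name `extended_mask` is mine, not from the repo):

```python
import numpy as np

def extended_mask(prefix_len: int, medusa_len: int) -> np.ndarray:
    """Causal mask over the prefix plus one chain of Medusa candidate tokens.

    Each candidate token attends to the full prefix and, causally,
    to the candidates before it -- matching the 6x6 matrix above.
    """
    total = prefix_len + medusa_len
    return np.tril(np.ones((total, total), dtype=int))

print(extended_mask(4, 2))
```

For a single candidate chain this reduces to an ordinary causal mask of size `prefix_len + medusa_len`; the tree case (Q2) is where the block M stops being triangular.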

Q2: If Question 1 is correct, is the tree attention mask also expanded in the same way?

# denote the matrix above as
A | 0
-----
1 | M
# where M is the Medusa mask

If the above is true, is M replaced by the tree attention mask?
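My understanding of the block M in the tree case, sketched below: candidates form a speculation tree, and each tree token attends only to itself and its ancestors, never to siblings. The toy tree here (two head-1 candidates, each followed by one head-2 candidate) is my own illustrative assumption, not the repo's actual tree layout:

```python
import numpy as np

# Hypothetical 4-node speculation tree, flattened in order 0..3.
# Each tuple is a root-to-node path; the last element is the node itself.
paths = [
    (0,),     # head-1, top candidate
    (1,),     # head-1, second candidate
    (0, 2),   # head-2 candidate continuing node 0
    (1, 3),   # head-2 candidate continuing node 1
]

n = 4
M = np.zeros((n, n), dtype=int)
for path in paths:
    node = path[-1]
    M[node, node] = 1        # a token attends to itself
    for anc in path[:-1]:
        M[node, anc] = 1     # ...and to its ancestors only

print(M)
```

Note M is not lower-triangular: node 2 attends to {0, 2} but not to its sibling node 1, which is exactly what distinguishes the tree mask from the plain causal expansion in Q1.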

Q3: In Medusa-2, are the LLM's parameters trained as well (not only the heads)?

Thanks for your reply!