Labrouste
note · March 2, 2026 · 6 min

Medusa

ai inference

I remember when the Medusa paper (MedusaTogether2024) was published. That morning, a couple of my colleagues at the office had the PDF open on their screens (not that I spy on people's screens). If you don't do the same for every new Tri Dao paper, you have some soul-searching to do.

Medusa proposes augmenting the language modeling head of any Transformer model with $n$ additional heads that predict $n$ future tokens in parallel. This differs from traditional speculative decoding in that the speculator is not a separate model. Instead, it is an extension that reuses the original model's hidden states. In addition, because the predictions happen in parallel, latency is greatly reduced.

A Medusa head is defined as a stack of feed-forward layers with residual connections, followed by an LM head. In practice, a single layer suffices.
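Written out with the row-vector convention of the code below, a single-layer head $i$ turns the last hidden state $h$ into a distribution over the token it predicts. This mirrors the single-layer head described in the Medusa paper. $W_1^{(i)}$ and $W_2^{(i)}$ denote head $i$'s own weights, and they correspond to `w1[i]` and `w2[i]` in the code.

```latex
p^{(i)} = \mathrm{softmax}\left( \left( h + \mathrm{SiLU}\left( h\, W_1^{(i)} \right) \right) W_2^{(i)} \right),
\qquad W_1^{(i)} \in \mathbb{R}^{H \times H}, \quad W_2^{(i)} \in \mathbb{R}^{H \times V}
```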

from tinygrad import Tensor

# Hidden dimension, vocabulary size, number of Medusa heads.
H, V, N = 2048, 8192, 2

class Medusa:
    def __init__(self, model):
        # `model` is assumed to return (last hidden states, logits, cache).
        self.model = model
        # The N heads are batched: one (H, H) and one (H, V) matrix per head.
        self.w1 = Tensor.kaiming_uniform(N, H, H)
        self.w2 = Tensor.kaiming_uniform(N, H, V)

    def __call__(self, x, positions, cache, mask):
        h, logits, cache = self.model(x, positions, cache, mask)  # h: (T, H)
        h = h.expand(N, -1, -1)           # (N, T, H): one copy per head
        h = h + (h @ self.w1).silu()      # residual feed-forward block
        medusa_logits = h @ self.w2       # (N, T, V)
        return logits, medusa_logits, cache

The module above produces logits for the $n+1$ future tokens $\{x_{t+1}, \dots, x_{t+n+1}\}$: the original LM head predicts $x_{t+1}$, and Medusa head $i$ predicts $x_{t+i+1}$. A softmax turns these logits into probability distributions. For each head $i$, we keep the $s_i$ most probable tokens as candidates.
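Selecting the $s_i$ most probable tokens per head can be sketched in plain Python. The function names `softmax` and `select_candidates` and the toy logits are hypothetical. The sketch assumes each head's logits at the last position are given as a list of floats.

```python
import heapq
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def select_candidates(head_logits, sizes):
    """For each head i, return the ids of its sizes[i] most probable tokens.

    head_logits[i] holds head i's vocabulary-sized logits at the last
    position; sizes[i] plays the role of s_i. Ids come out most probable first.
    """
    candidates = []
    for logits, s_i in zip(head_logits, sizes):
        probs = softmax(logits)
        top = heapq.nlargest(s_i, range(len(probs)), key=probs.__getitem__)
        candidates.append(top)
    return candidates

# Toy example: a vocabulary of 5 tokens, two heads, s = (2, 3).
head_logits = [[0.1, 2.0, -1.0, 1.5, 0.0], [3.0, 0.2, 0.1, 2.5, 2.9]]
print(select_candidates(head_logits, [2, 3]))
```

Because softmax is monotonic, ranking the raw logits would pick the same ids. The softmax is only needed when the probabilities themselves are used.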

# Model, tokenize and Cache are placeholders for an actual implementation.
model = Model()
medusa = Medusa(model)

# A dummy, single-token sequence.
x = Tensor(tokenize("[BOS]"))

# Run the model once: logits from the original LM head plus the Medusa heads.
logits, medusa_logits, cache = medusa(
    x,
    positions=Tensor([0]),  # The [BOS] token has position zero.
    cache=Cache(),          # Empty, imaginary KV cache.
    mask=Tensor([[0]]),     # Additive mask: the [BOS] token can attend to itself.
)
# Sample x_{t+1} from the original LM head at the last position.
choice = logits[-1].softmax().multinomial()

This gives us a tree of $k = \prod_{i=1}^{n} s_i$ candidate sequences of length $n$.

# We select the top-2 tokens from the first head, top-3 from second.
S = Tensor([2, 3])
# This leads to 2×3 possible sequences.
K = S.prod().item()
Example tree of possible generations
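Each candidate sequence picks one token per head, so the tree's root-to-leaf paths are the Cartesian product of the per-head candidate lists. The plain-Python sketch below enumerates them with `itertools.product`. The token ids and the name `candidate_sequences` are hypothetical.

```python
import itertools
import math

def candidate_sequences(candidates):
    """Enumerate every root-to-leaf path of the candidate tree.

    candidates[i] holds the s_i token ids kept for head i; each sequence
    takes one token per head, so there are prod(s_i) sequences of length n.
    """
    return [list(seq) for seq in itertools.product(*candidates)]

# Hypothetical token ids: top-2 of head 0, top-3 of head 1.
candidates = [[11, 12], [21, 22, 23]]
seqs = candidate_sequences(candidates)
k = math.prod(len(c) for c in candidates)
print(k, seqs)
```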

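The next step is to score each branch using the target model. We construct an input $\tilde{x}$ that contains the model's prediction followed by the flattened tree. Branches share prefixes, so depth $i$ of the tree holds $\prod_{j=1}^{i} s_j$ nodes, and $\tilde{x}$ has length $1 + \sum_{i=1}^{n} \prod_{j=1}^{i} s_j$ (here $1 + 2 + 6 = 9$). All nodes at the same depth share the same position.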

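s = 1                    # number of tree nodes at the previous depth
x = choice               # root of the tree: the token sampled from the LM head
positions = Tensor([1])  # the root sits right after [BOS]
for i, s_i in enumerate(S.tolist()):
    # Top-s_i candidates of head i, read at the last position.
    _, ids = medusa_logits[i, -1].topk(s_i)
    # Every node of the previous depth gets all s_i candidates as children.
    x = x.cat(ids.repeat(s))
    s *= s_i
    # All s nodes at depth i + 1 share position 2 + i.
    positions = positions.cat(Tensor.full((s,), 2 + i))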
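The tinygrad loop tracks only tokens and positions. The plain-Python sketch below builds the same breadth-first layout and also records each node's parent, which is the information the attention mask needs. The function name `flatten_tree` and the token ids are hypothetical.

```python
def flatten_tree(root, candidates, root_position=1):
    """Lay out the candidate tree breadth-first, children grouped by parent.

    root:       token sampled from the original LM head.
    candidates: candidates[i] are the s_i token ids kept for head i.
    Returns (tokens, positions, parents), where parents[j] is the index of
    node j's parent in the flattened list (-1 for the root).
    """
    tokens, positions, parents = [root], [root_position], [-1]
    level_start, level_size = 0, 1  # depth 0 holds only the root
    for depth, ids in enumerate(candidates, start=1):
        next_start = len(tokens)
        for p in range(level_size):  # each node at the previous depth...
            for tok in ids:          # ...gets all s_i children, in order
                tokens.append(tok)
                positions.append(root_position + depth)
                parents.append(level_start + p)
        level_start, level_size = next_start, level_size * len(ids)
    return tokens, positions, parents

tokens, positions, parents = flatten_tree(7, [[11, 12], [21, 22, 23]])
print(tokens, positions, parents)
```

Grouping children by parent matches the tiling done by `ids.repeat(s)`, so node $j$ at depth 2 has parent index `1 + j // 3` in this example.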

Finally, we construct an attention mask that prevents tokens in different branches from attending to one another: each node may attend only to the cached prefix, its ancestors, and itself. Funnily enough, this closely resembles paged attention with prefix matching from the world of LLM inference.
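One way to build such a mask is sketched below in plain Python, given each node's parent index in the breadth-first layout. It assumes an additive mask like the one passed to the model earlier (0 means "attend", `-inf` means "blocked"). It also assumes the cached prefix occupies the first `prefix_len` key columns. The names `tree_attention_mask` and `prefix_len` are hypothetical.

```python
NEG_INF = float("-inf")

def tree_attention_mask(parents, prefix_len=0):
    """Additive attention mask for the flattened tree (0 = attend, -inf = blocked).

    Row j is the query for tree node j. The first prefix_len columns are the
    cached prefix, visible to every node; the remaining columns are the tree
    nodes, and node j may only see itself and its ancestors.
    """
    n = len(parents)
    mask = []
    for j in range(n):
        visible = set()
        node = j
        while node != -1:  # walk up to the root
            visible.add(node)
            node = parents[node]
        row = [0.0] * prefix_len
        row += [0.0 if c in visible else NEG_INF for c in range(n)]
        mask.append(row)
    return mask

# Parents of the S = [2, 3] tree, with [BOS] as a one-token cached prefix.
parents = [-1, 0, 0, 1, 1, 1, 2, 2, 2]
for row in tree_attention_mask(parents, prefix_len=1):
    print("".join("x" if v == 0.0 else "." for v in row))
```

In a breadth-first layout, ancestors always precede their descendants, so this mask is a sparser version of the usual causal mask.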

{todo}
Example tree of possible generations