Labrouste
note · March 2, 2026 · 6 min

Speculative Sampling

ai inference

Speculative Sampling (DeepMind, 2023) and Speculative Decoding (Google, 2022) are among the first papers to introduce speculative decoding. Both papers rely on a parallel model (or a faster auto-regressive model sampled $n$ times) to produce a sequence of draft tokens $\{\tilde{x}_{t+1}, \ldots, \tilde{x}_{t+n}\}$ and corresponding probability distributions $\{p_{t+1}, \ldots, p_{t+n}\}$. Notably, the draft model is trained to match the distribution of the target model, usually through distillation.

from tinygrad import Tensor

# Vocabulary size, speculation depth, sequence.
V, N, X = 32, 4, Tensor([26, 3, 10, 15])

# Simulate target and draft model with slightly different distributions.
dist = Tensor.rand(V).sub(0.5).mul(3).unsqueeze(0)
model = lambda x: dist.expand(x.size(0), V).softmax(-1)
drafter = lambda x: dist.expand(x.size(0), V).div(0.5).softmax(-1)
sample = lambda d: d.multinomial(1).item()

# Generate `N` tokens from the draft model.
# Record draft probabilities (dps) and sampled draft tokens (dts).
dps, dts = [], []
for _ in range(N):
    probs = drafter(X.cat(Tensor(dts, dtype=X.dtype)))[-1]
    dps.append(probs)
    dts.append(sample(probs))

A verification forward pass through the target model yields probability distributions $\{q_{t+1}, \ldots, q_{t+n+1}\}$, which are used to score the candidates generated by the draft model.

# Single forward pass through the target model over the draft tokens.
# Returns target probabilities (tps):
# * 1x per draft position
# * 1x bonus
tps = model(X.cat(Tensor(dts)))[-(N+1):]

Specifically, the ratio $q_i/p_i$ is used to compute the score $s_i$. If $s_i > r \sim U(0, 1)$, the token is accepted; otherwise, it is rejected.

# Acceptance loop.
accepted = 0
for t, p, q in zip(dts, dps, tps):
    if Tensor.rand().item() < (q[t] / p[t]).item():
        # Target model agrees strongly enough, accept the draft token.
        X = X.cat(Tensor([t]))
        accepted += 1
    else:
        # Target model disagrees, resample from the residual distribution.
        residual = (q - p).clamp(0)
        X = X.cat(Tensor([sample(residual / residual.sum())]))
        break

Upon disagreement between the draft and target model, sampling from the residual distribution $(q_i - p_i)_+$ gives us a prediction which is guaranteed to be grounded in the target model’s true distribution. Theorem 1 in DeepMind’s paper provides the proof.
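We can also check this empirically. The standalone sketch below (plain numpy, with made-up draft and target distributions `p` and `q` over a tiny vocabulary, independent of the snippet above) runs the accept-or-resample step many times and confirms that the resulting samples are distributed according to the target distribution `q`:

```python
import numpy as np

rng = np.random.default_rng(0)
V = 4
# Hypothetical draft (p) and target (q) distributions over a tiny vocabulary.
p = np.array([0.5, 0.2, 0.2, 0.1])
q = np.array([0.2, 0.4, 0.3, 0.1])

# Residual distribution (q - p)_+, renormalized.
residual = np.clip(q - p, 0, None)
residual /= residual.sum()

def speculative_step():
    t = rng.choice(V, p=p)                 # draft proposes t ~ p
    if rng.random() < min(1.0, q[t] / p[t]):
        return t                           # accepted
    return rng.choice(V, p=residual)       # rejected: resample from (q - p)_+

counts = np.bincount([speculative_step() for _ in range(100_000)], minlength=V)
empirical = counts / counts.sum()
# The empirical distribution matches q up to Monte Carlo error.
assert np.allclose(empirical, q, atol=0.01)
```

The identity behind the check: $\min(p, q) + (q - p)_+ = q$ pointwise, which is exactly the argument in Theorem 1.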

Now, notice that the verification pass also produced $q_{t+n+1}$ (here, $q_{t+5}$). In the case where all candidates from the draft model are accepted, we can sample from $q_{t+n+1}$ and accept the result directly, since it comes from the target model.

# Bonus token: if all N drafts accepted, sample the (N+1)th directly from q.
if accepted == N:
    X = X.cat(Tensor([sample(tps[N])]))

And voilà! We generated up to $n + 1 = 5$ tokens from a single forward pass through the target model. Given a sufficiently fast draft model and a high enough acceptance rate, your decoding speed has magically increased.