Speculative Sampling
Speculative Sampling (DeepMind, 2023) and Speculative Decoding (Google, 2022) are among the first papers to introduce speculative decoding. Both rely on a draft model (a parallel model, or a faster auto-regressive model sampled N times) to produce a sequence of N draft tokens t_1, ..., t_N and corresponding probability distributions p_1, ..., p_N. Notably, the draft model is trained to match the distribution of the target model, usually through distillation.
from tinygrad import Tensor
# Vocabulary size, speculation depth, sequence.
V, N, X = 32, 4, Tensor([26, 3, 10, 15])
# Simulate target and draft model with slightly different distributions.
dist = Tensor.rand(V).sub(0.5).mul(3).unsqueeze(0)
model = lambda x: dist.expand(x.size(0), V).softmax(-1)
drafter = lambda x: dist.expand(x.size(0), V).div(0.5).softmax(-1)
sample = lambda d: d.multinomial(1).item()
# Generate `N` tokens from the draft model.
# Record draft probabilities (dps) and sampled draft tokens (dts).
dps, dts = [], []
for _ in range(N):
    # Avoid concatenating an empty Tensor on the first iteration.
    seq = X if not dts else X.cat(Tensor(dts))
    probs = drafter(seq)[-1]
    dps.append(probs)
    dts.append(sample(probs))
A verification forward pass through the target model yields probability distributions q_1, ..., q_{N+1}, which are used to score the candidates generated by the draft model.
# Single forward pass through the target model over the draft tokens.
# Returns target probabilities (tps):
# * 1x per draft position
# * 1x bonus
tps = model(X.cat(Tensor(dts)))[-(N+1):]
Specifically, the ratio of target to draft probability is used to compute the score r_i = min(1, q_i(t_i) / p_i(t_i)). Drawing u ~ Uniform(0, 1), the token is accepted if u < r_i; otherwise, it is rejected.
# Acceptance loop.
accepted = 0
for t, p, q in zip(dts, dps, tps):
    if Tensor.rand().item() < (q[t] / p[t]).item():
        # Target model agrees strongly enough, accept the draft token.
        X = X.cat(Tensor([t]))
        accepted += 1
    else:
        # Target model disagrees, resample from the residual distribution.
        residual = (q - p).clamp(0)
        X = X.cat(Tensor([sample(residual / residual.sum())]))
        break
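Averaged over draft samples, the loop above accepts a given position with overall probability ∑_t min(p(t), q(t)), since P(accept) = ∑_t p(t) · min(1, q(t)/p(t)). A quick sketch with hypothetical toy distributions (the numbers are mine, not from the snippets above):

```python
# Per-position acceptance probability: P(accept) = sum_t min(p(t), q(t)).
# Toy draft/target distributions over a 3-token vocabulary.
p = [0.5, 0.3, 0.2]  # draft
q = [0.2, 0.5, 0.3]  # target
alpha = sum(min(pi, qi) for pi, qi in zip(p, q))
print(alpha)  # 0.7
```

The closer the draft distribution tracks the target, the closer this acceptance probability gets to 1.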
Upon disagreement between the draft and target models, sampling from the residual distribution max(q − p, 0) (renormalized) gives us a token that is guaranteed to be distributed according to the target model’s true distribution. Theorem 1 in DeepMind’s paper provides the proof.
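The guarantee can also be checked numerically. A minimal Monte Carlo sketch in plain Python (toy 3-token distributions of my choosing, independent of the tinygrad code): the accept-or-resample outcome should match the target distribution q, not the draft distribution p.

```python
import random

# Toy draft (p) and target (q) distributions over a 3-token vocabulary.
p = [0.5, 0.3, 0.2]  # draft
q = [0.2, 0.5, 0.3]  # target

def accept_or_resample(rng):
    # Draw a draft token from p.
    t = rng.choices(range(3), weights=p)[0]
    # Accept with probability min(1, q[t] / p[t]).
    if rng.random() < min(1.0, q[t] / p[t]):
        return t
    # Otherwise resample from the residual max(q - p, 0), renormalized.
    residual = [max(qi - pi, 0.0) for pi, qi in zip(p, q)]
    return rng.choices(range(3), weights=residual)[0]

rng = random.Random(0)
counts = [0, 0, 0]
trials = 100_000
for _ in range(trials):
    counts[accept_or_resample(rng)] += 1
freqs = [c / trials for c in counts]
print(freqs)  # close to q = [0.2, 0.5, 0.3]
```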
Now, notice that the verification pass also produced q_{N+1}, the target distribution for the position after the last draft token. In the case where all N candidates from the draft model are accepted, we can sample a bonus token from q_{N+1} and accept it directly, since it comes straight from the target model.
# Bonus token: if all N drafts accepted, sample the (N+1)th directly from q.
if accepted == N:
    X = X.cat(Tensor([sample(tps[N])]))
And voilà! We generated up to N + 1 = 5 tokens from a single forward pass through the target model (and at least one, even if the first draft token is rejected). Given a sufficiently fast draft model and a high enough acceptance rate, your decoding speed has magically increased.
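To estimate the gain: if each draft token is accepted independently with probability alpha (an i.i.d. simplification; the papers' speedup analyses use the same truncated geometric series), the expected number of tokens per target forward pass is (1 − alpha^(N+1)) / (1 − alpha). A sketch:

```python
def expected_tokens(alpha: float, n: int) -> float:
    # accepted drafts + 1 (resample or bonus) => truncated geometric series:
    # E[tokens] = sum_{k=0}^{n} alpha^k = (1 - alpha^(n+1)) / (1 - alpha)
    return sum(alpha ** k for k in range(n + 1))

print(expected_tokens(0.8, 4))  # ~3.36 tokens per target pass at 80% acceptance
```

With perfect acceptance (alpha = 1) this saturates at N + 1 tokens per pass; with alpha = 0 it degrades gracefully to ordinary one-token-per-pass decoding.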