note
·
March 2, 2026
·
3 min
Speculative Sampling
ai
inference
Speculative SamplingDeepmind2023 和 Speculative DecodingGoogle2022 是最早提出speculative decoding的论文。两篇论文都依赖一个并行模型(或者一个更快的自回归模型采样 次)来生成一系列草稿token 及其对应的概率分布 。值得注意的是,草稿模型通过蒸馏等方式训练,以匹配目标模型的分布。
from tinygrad import Tensor
# 词表大小、推测深度、序列。
V, N, X = 32, 4, Tensor([26, 3, 10, 15])
# 用略有不同的分布模拟目标模型和草稿模型。
dist = Tensor.rand(V).sub(0.5).mul(3).unsqueeze(0)
model = lambda x: dist.expand(x.size(0), V).softmax(-1)
drafter = lambda x: dist.expand(x.size(0), V).div(0.5).softmax(-1)
sample = lambda d: d.multinomial(1).item()
# 从草稿模型生成 `N` 个token。
# 记录草稿概率 (dps) 和采样的草稿token (dts)。
dps, dts = [], []
for _ in range(N):
probs = drafter(X.cat(Tensor(dts)))[-1]
dps.append(probs)
dts.append(sample(probs))
目标模型的一次验证前向传播得到概率分布 ,用于对草稿模型生成的候选token进行打分。
# 对草稿token进行一次目标模型前向传播。
# 返回目标概率 (tps):
# * 每个草稿位置一个
# * 额外一个bonus
tps = model(X.cat(Tensor(dts)))[-(N+1):]
具体来说,比值 用于计算分数 。如果 ,则接受该token;否则拒绝。
# 接受循环。
accepted = 0
for t, p, q in zip(dts, dps, tps):
if Tensor.rand().item() < (q[t] / p[t]).item():
# 目标模型足够认可,接受草稿token。
X = X.cat(Tensor([t]))
accepted += 1
else:
# 目标模型不认可,从残差分布中重新采样。
residual = (q - p).clamp(0)
X = X.cat(Tensor([sample(residual / residual.sum())]))
break
当草稿模型与目标模型产生分歧时,从残差分布 中采样,可以保证预测结果符合目标模型的真实分布。DeepMind论文中的定理1给出了证明。
注意,在验证传播中我们已经采样了 。如果草稿模型的所有候选token都被接受,我们可以直接从 中采样并接受,因为它来自目标模型。
# Bonus token:如果所有 N 个草稿都被接受,直接从 q 采样第 (N+1) 个。
if accepted == N:
X = X.cat(Tensor([sample(tps[N])]))
就这样!我们通过目标模型的一次前向传播就生成了三个token。只要草稿模型足够快、接受率足够高,解码速度就能大幅提升。