传统语言模型采用简单的训练目标:给定前t个令牌,预测第t+1个令牌。这种自回归分解方式虽然优雅且有效,但存在根本性局限:模型仅接收即时下一个令牌的预测训练信号,从未明确学习提前规划多个步骤。
考虑生成句子:"猫坐在垫子上,因为它很舒服。" 当预测"因为"时,模型应该已经在考虑句子将如何完成——包括从句、代词指代和结论。但仅依靠单令牌预测,没有明确的梯度信号鼓励这种前向规划。
这种局限在需要长期连贯性的任务中尤为明显,如故事生成、多段落推理或代码生成,模型容易生成局部流畅但全局自相矛盾的内容。
多令牌预测通过添加辅助预测头来解决此问题,这些预测头可以同时预测未来的多个令牌。除了标准的位置t+1预测外,还同时预测位置t+2、t+3等后续令牌。
完整的训练目标函数为:
L = Lmain + Σ{k=1}^{n} λ_k * L_k
其中n为预测的未来令牌数,λ_k为加权系数(通常随距离增大而递减)。
实现多令牌预测需要架构上的补充。不能直接复用主语言建模头进行未来预测,需要依赖中间令牌的信息。
预测头结构: 对于预测k个令牌后的位置,需要组合两个信息源:
组合方式为:combined = Wcombine([norm(h_i), norm(e{i+k-1})]) + b
然后通过轻量级Transformer(注意力层和前馈层)处理后再投影到词汇表,生成预测logits。
从优化角度看,MTP提供更丰富的梯度信号。标准训练中只有隐藏表示hi接收来自预测x{i+1}的梯度。而使用MTP后,hi还会接收来自预测x{i+k}的梯度。这些额外梯度鼓励h_i编码不仅与下一个令牌相关,而且与多个未来令牌相关的信息。
这相当于添加了一个隐式正则化器,约束学习到的表示更加结构化、更具前瞻性和全局连贯性。
训练阶段: 所有预测并行计算,使用真实令牌信息,无误差累积。
推理阶段: MTP头通常不用于自回归生成,其作用是在训练阶段改善学习到的表示,推理时仍使用标准单令牌预测方式,保证部署时的计算效率。
对于预测深度k,权重通常采用指数衰减:λ_k = γ^(k-1),其中γ∈(0,1)。例如γ=0.5时,深度1权重1.0,深度2权重0.5,深度3权重0.25。
class MultiTokenPredictionHead(nn.Module):
"""
多令牌预测头
每个头预测特定未来位置的令牌
组合前一个隐藏状态与未来令牌嵌入
"""
def __init__(self, config: DeepSeekConfig, depth: int):
super().__init__()
self.depth = depth
self.n_embd = config.n_embd
# 组合前一个隐藏状态与未来令牌嵌入
self.combine_proj = nn.Linear(2 * config.n_embd, config.n_embd, bias=config.bias)
# 归一化层
self.norm1 = RMSNorm(config.n_embd)
self.norm2 = RMSNorm(config.n_embd)
# Transformer组件(每个头的轻量级Transformer)
self.attn = MultiheadLatentAttention(config)
self.mlp = MixtureOfExperts(config)
self.attn_norm = RMSNorm(config.n_embd)
self.mlp_norm = RMSNorm(config.n_embd)
def forward(self, prev_hidden, future_token_embed):
"""
参数:
prev_hidden: [B, T, D] - 前一层的隐藏状态
future_token_embed: [B, T, D] - 未来令牌的嵌入
返回:
hidden: [B, T, D] - 处理后的隐藏状态
"""
# 归一化输入
prev_norm = self.norm1(prev_hidden)
future_norm = self.norm2(future_token_embed)
# 组合表示
combined = torch.cat([prev_norm, future_norm], dim=-1)
hidden = self.combine_proj(combined)
# 通过轻量级Transformer处理
hidden = hidden + self.attn(self.attn_norm(hidden))
moe_out, _ = self.mlp(self.mlp_norm(hidden))
hidden = hidden + moe_out
return hidden训练时将MTP头集成到主模型中,操作流程如下:
关键洞察是头之间存在链式依赖关系,形成层次化结构。
研究表明MTP带来多个实证收益:
传统自回归模型依赖单令牌预测,这种策略虽然有效但可能短视。MTP通过使模型能够同时预测多个令牌来解决这一局限,加速训练和推理,同时丰富上下文理解能力。这项创新不仅提高了效率,还强化了模型的理论和经验基础。FINISHED
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。