本文对这篇论文进行复现:Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes 目前已发表在2023ACL上
大规模语言模型(LLMs)由于其内存低效和计算密集,部署起来非常具有挑战性。为此,研究人员通常通过微调(finetuning)或蒸馏(distillation)训练更小的任务特定模型,但这两种方法都需要大量的训练数据。 本文提出了一种新的方法——逐步蒸馏(Distilling Step-by-Step),它通过提取LLM生成的推理过程作为监督信号,训练小模型并显著减少数据需求。该机制的核心是换一种角度,将 LLM 看作是可以推理的 agent,而不是噪声标签的来源。LLM 可以产生自然语言的理由(rationale),这些理由可以用来解释和支持模型所预测的标签。 例如,当被问及“一位先生携带着打高尔夫球的设备,他可能有什么?(a) 球杆,(b) 礼堂,© 冥想中心,(d) 会议,(e) 教堂”,LLM 可以通过思维链(CoT)推理回答出「(a)球杆」,并通过说明「答案一定是用来打高尔夫球的东西」来合理化这个标签。在上述选择中,只有球杆是用来打高尔夫的。研究者使用这些理由作为额外更丰富的信息在多任务训练设置中训练较小的模型,并进行标签预测和理由预测。
本篇工作基于T5-efficient-mini模型复现了该方法,不仅提高了训练速度,还在wandb平台上实现了训练过程的可视化。通过这种优化,展示了如何在实践中加速模型训练。以上内容均为原创。
逐步蒸馏(Distilling Step-by-Step),其核心思想是利用大规模语言模型(LLMs)推理预测的能力,通过生成带有理由的标签数据来辅助训练更小的下游模型。该方法包含两个主要步骤:
论文中使用了4个流行的基准数据集,涵盖3种不同的自然语言处理(NLP)任务,具体数据集和任务如下:
CQA (Commonsense Question Answering):一个基于常识知识的多项选择问答数据集,要求模型结合外部常识知识来回答问题。
SVAMP (Single-Variable Arithmetic Math Problems)专注于单变量算术数学问题,设计更加多样化,意在测试模型在数学文字题上的推理能力。
conda create --name distill python=3.10.6 -y
conda activate distill
安装必要的依赖库:
conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install git+https://github.com/huggingface/transformers@v4.24.0 datasets sentencepiece protobuf==3.20.* tensorboardX
pip install sentencepiece
pip install protobuf==3.20 wandb
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type gt --batch_size 64
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefix --label_type gt --llm palm --alpha 0.5 --batch_size 64
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type llm --batch_size 64
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefi
3.3实验结果 在wandb可以看到实验结果
class TaskPrefixTrainer(Seq2SeqTrainer):
def __init__(self, alpha, output_rationale,**kwargs):
super().__init__(**kwargs)
self.alpha = alpha
self.output_rationale = output_rationale
def compute_loss(self, model, inputs, return_outputs=False):
pred_outputs = model(**inputs['pred'])
expl_outputs = model(**inputs['expl'])
loss = self.alpha * pred_outputs.loss + (1. - self.alpha) * expl_outputs.loss
# For Eval Loss/expl_loss, Eval Loss/pred_loss, Eval Loss/total_loss
wandb.log({
"Eval Loss/expl_loss": expl_outputs[0].item(),
"Eval Loss/pred_loss": pred_outputs[0].item(),
"Eval Loss/total_loss": loss.item()
}, step=self.state.global_step)
return (loss, {'pred': pred_outputs, 'expl': expl_outputs}) if return_outputs else loss
def __del__(self):
# 确保在训练结束后关闭SummaryWriter
self.writer.close()
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
pred_outputs = super().prediction_step(model, inputs['pred'], prediction_loss_only=False, ignore_keys=ignore_keys)
if self.output_rationale:
expl_outputs = super().prediction_step(model, inputs['expl'], prediction_loss_only=False, ignore_keys=ignore_keys)
else:
expl_outputs = pred_outputs # placeholder only
loss = self.alpha * pred_outputs[0] + (1 - self.alpha) * expl_outputs[0]
# 记录损失到 TensorBoard
wandb.log({
"Eval Loss/expl_loss": expl_outputs[0].item(),
"Eval Loss/pred_loss": pred_outputs[0].item(),
"Eval Loss/total_loss": loss.item()
}, step=self.state.global_step)
return (
loss,
[pred_outputs[1], expl_outputs[1]],
[pred_outputs[2], expl_outputs[2]],
)