前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >逐步蒸馏论文复现

逐步蒸馏论文复现

作者头像
Srlua
发布2025-01-02 08:54:41
发布2025-01-02 08:54:41
16900
代码可运行
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运
运行总次数:0
代码可运行

本文对这篇论文进行复现:Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes 目前已发表在2023ACL上

1.论文概述

大规模语言模型(LLMs)由于其内存低效和计算密集,部署起来非常具有挑战性。为此,研究人员通常通过微调(finetuning)或蒸馏(distillation)训练更小的任务特定模型,但这两种方法都需要大量的训练数据。 本文提出了一种新的方法——逐步蒸馏(Distilling Step-by-Step),它通过提取LLM生成的推理过程作为监督信号,训练小模型并显著减少数据需求。该机制的核心是换一种角度,将 LLM 看作是可以推理的 agent,而不是噪声标签的来源。LLM 可以产生自然语言的理由(rationale),这些理由可以用来解释和支持模型所预测的标签。 例如,当被问及“一位先生携带着打高尔夫球的设备,他可能有什么?(a) 球杆,(b) 礼堂,© 冥想中心,(d) 会议,(e) 教堂”,LLM 可以通过思维链(CoT)推理回答出「(a)球杆」,并通过说明「答案一定是用来打高尔夫球的东西」来合理化这个标签。在上述选择中,只有球杆是用来打高尔夫的。研究者使用这些理由作为额外更丰富的信息在多任务训练设置中训练较小的模型,并进行标签预测和理由预测。

本篇工作基于T5-efficient-mini模型复现了该方法,不仅提高了训练速度,还在wandb平台上实现了训练过程的可视化。通过这种优化,展示了如何在实践中加速模型训练。以上内容均为原创。

2.论文方法

逐步蒸馏(Distilling Step-by-Step),其核心思想是利用大规模语言模型(LLMs)推理预测的能力,通过生成带有理由的标签数据来辅助训练更小的下游模型。该方法包含两个主要步骤:

  • 生成合理性解释(Rationales):通过提示(prompting)引导LLMs为无标签数据生成预测标签以及相应的自然语言理由(Rationales)。这些理由解释了为什么给定输入会被映射到某一特定输出。
  • 结合理由进行模型训练:利用生成的理由和预测标签,以多任务学习的方式训练小型模型,使其不仅能预测任务标签,还能学习生成对应的推理过程,从而提升模型的预测能力。
2.1 提取理由
  • 链式推理提示(Chain-of-Thought Prompting):设计包含输入、标签和理由的提示模板,通过少量示例指导LLMs生成新的标签和对应理由。
  • 生成过程:利用提示模板为无标签数据集生成预测标签和理由,形成带有解释的伪标注数据
2.2 结合理由训练小模型
  • 传统方法:直接微调预训练模型或利用LLMs生成的伪标签训练下游模型。
  • 逐步蒸馏方法:采用多任务学习方式,将标签预测和理由生成结合起来,训练小模型同时具备预测能力和推理能力。 通过在输入中添加任务前缀(如“[label]”和“[rationale]”),指导模型在不同场景下生成标签或理由。

3.实验部分

3.1数据集

论文中使用了4个流行的基准数据集,涵盖3种不同的自然语言处理(NLP)任务,具体数据集和任务如下:

3.1.1自然语言推理(Natural Language Inference, NLI)
  • e-SNLI (Explainable SNLI):基于SNLI(Stanford Natural Language Inference)的扩展版本,增加了每个推理对的解释(rationale)。任务是判断两个句子之间的逻辑关系(蕴含、矛盾、中立)。
  • ANLI (Adversarial Natural Language Inference):一个更具挑战性的自然语言推理数据集,包含三轮对抗样本生成的数据。任务同样是预测句子之间的逻辑关系。
3.1.2. 常识问答(Commonsense Question Answering, CQA)

CQA (Commonsense Question Answering):一个基于常识知识的多项选择问答数据集,要求模型结合外部常识知识来回答问题。

3.1.3 数学文字题(Arithmetic Math Word Problems, AMWP)

SVAMP (Single-Variable Arithmetic Math Problems)专注于单变量算术数学问题,设计更加多样化,意在测试模型在数学文字题上的推理能力。

3.2 实验步骤
  • step1:安装环境依赖 实验环境搭建 创建并激活 Conda 环境:
代码语言:javascript
代码运行次数:0
复制
conda create --name distill python=3.10.6 -y
conda activate distill
安装必要的依赖库:

conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install git+https://github.com/huggingface/transformers@v4.24.0 datasets sentencepiece protobuf==3.20.* tensorboardX
pip install sentencepiece
pip install protobuf==3.20 wandb
  • 标准微调(Standard Fine-tuning) 使用真实标签(GT)对模型进行标准微调:
代码语言:javascript
代码运行次数:0
复制
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type gt --batch_size 64
  • 逐步蒸馏(Distilling Step-by-Step) 使用真实标签(GT label)和PaLM生成的推理(PaLM rationale):
代码语言:javascript
代码运行次数:0
复制
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefix --label_type gt --llm palm --alpha 0.5 --batch_size 64
  • 标准蒸馏(Standard Distillation) 使用LLM生成的标签(PaLM label)对模型进行蒸馏:
代码语言:javascript
代码运行次数:0
复制
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type llm --batch_size 64
  • 结合标签与推理的逐步蒸馏 使用PaLM生成的标签(PaLM label)和推理(PaLM rationale):
代码语言:javascript
代码运行次数:0
复制
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefi

3.3实验结果 在wandb可以看到实验结果

4.核心代码

代码语言:javascript
代码运行次数:0
复制
class TaskPrefixTrainer(Seq2SeqTrainer):
    def __init__(self, alpha, output_rationale,**kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.output_rationale = output_rationale



    def compute_loss(self, model, inputs, return_outputs=False):
        pred_outputs = model(**inputs['pred'])
        expl_outputs = model(**inputs['expl'])

        loss = self.alpha * pred_outputs.loss + (1. - self.alpha) * expl_outputs.loss

        # For Eval Loss/expl_loss, Eval Loss/pred_loss, Eval Loss/total_loss
        wandb.log({
            "Eval Loss/expl_loss": expl_outputs[0].item(),
            "Eval Loss/pred_loss": pred_outputs[0].item(),
            "Eval Loss/total_loss": loss.item()
        }, step=self.state.global_step)



        return (loss, {'pred': pred_outputs, 'expl': expl_outputs}) if return_outputs else loss

    def __del__(self):
        # 确保在训练结束后关闭SummaryWriter
        self.writer.close()

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        
        pred_outputs = super().prediction_step(model, inputs['pred'], prediction_loss_only=False, ignore_keys=ignore_keys)
        if self.output_rationale:
            expl_outputs = super().prediction_step(model, inputs['expl'], prediction_loss_only=False, ignore_keys=ignore_keys)
        else:
            expl_outputs = pred_outputs # placeholder only


        loss = self.alpha * pred_outputs[0]  + (1 - self.alpha) * expl_outputs[0]

        # 记录损失到 TensorBoard
        wandb.log({
            "Eval Loss/expl_loss": expl_outputs[0].item(),
            "Eval Loss/pred_loss": pred_outputs[0].item(),
            "Eval Loss/total_loss": loss.item()
        }, step=self.state.global_step)

        return (
            loss,
            [pred_outputs[1], expl_outputs[1]],
            [pred_outputs[2], expl_outputs[2]],
        )

​​

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2025-01-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.论文概述
  • 2.论文方法
    • 2.1 提取理由
    • 2.2 结合理由训练小模型
  • 3.实验部分
    • 3.1数据集
    • 3.2 实验步骤
  • 4.核心代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档