前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >机器学习|从0开发大模型之SFT训练

机器学习|从0开发大模型之SFT训练

作者头像
用户1904552
发布2025-02-27 10:44:55
发布2025-02-27 10:44:55
12400
代码可运行
举报
文章被收录于专栏:周末程序猿周末程序猿
运行总次数:0
代码可运行

继续写《从0开发大模型》系列文章,上一章主要数据数据预训练,让模型能学到句子接龙和部分语言理解能力,获取基座版本,但是用基座版本的模型的对话能力太弱了,需要用大量的数据微调,本文主要介绍如何用SFT训练模型。

1、什么是SFT

SFT是有监督微调(Supervised Fine-Tuning),指采用预先训练好的网络模型,并针对你自己的专门任务在少量的监督数据上对其进行重新训练的技术。

SFT在大语言模型中的应用有以下重要原因:

任务特定性能提升:预训练语言模型通过大规模的无监督训练学习了语言的统计模式和语义表示,然而它在特定任务下的效果可能并不令人满意,通过在任务特定的有标签数据上进行微调,模型可以进一步学习任务相关的特征和模式,从而提高性能。 领域适应性:预训练语言模型可能在不同领域的数据上表现不一致,通过在特定领域的有标签数据上进行微调,可以使模型更好地适应该领域的特殊术语、结构和语义,提高在该领域任务上的效果。 数据稀缺性:某些任务可能受制于数据的稀缺性,很难获得大规模的标签数据,监督微调可以通过使用有限的标签数据来训练模型,从而在数据有限的情况下取得较好的性能。 防止过拟合:在监督微调过程中,通过使用有标签数据进行有监督训练,可以减少模型在特定任务上的过拟合风险,这是因为监督微调过程中的有标签数据可以提供更具体的任务信号,有助于约束模型的学习,避免过多地拟合预训练过程中的无监督信号。

2、整理SFT数据

整理SFT数据需要遵循以下原则:

  • 按照QA的格式整理数据
  • 如果训练多语言模型,需要准备其他语言的数据,本文训练的模型是中文的,所以只准备中文的数据
  • SFT的数据需要确保QA的数据回答是正确的,否则模型无法学习到正确的答案

(1)数据格式如下(CSV):

代码语言:javascript
代码运行次数:0
复制
history,q,a
[],好的。现在请你将这个文本中的所有的逗号都替换成空格。,"好的,请稍等一下,现在我会将文本中的所有逗号替换为空格。处理后文本为:""这是一个句子 目的是看看是否可以正确地从这个句子中删除关键词。""。处理结果如何?"

其中history是历史的输入,q是问题,a是答案,但是以上数据无法直接用于微调,需要会拼接,比如翻译类型的会这样处理:

代码语言:javascript
代码运行次数:0
复制
instruction:
[USR]:将下列内容翻译成英语:{待翻译文本}
answer
[BOT]:{翻译结果}

拼接后的文本:
<bos_token>[USER]:将下列内容翻译成英语:{待翻译文本}<special token>[BOT]:{翻译结果} <eos_token>

(2)SFT的数据集可以参考以下数据集:

  • BelleGroup/train_3.5M_CN
  • LinkSoul/instruction_merge_set
  • stingning/ultrachat
  • BAAI/COIG-PC-core
  • shibing624/sharegpt_gpt4
  • shareAI/ShareGPT-Chinese-English-90k
  • Tiger Research
  • BelleGroup/school_math_0.25M
  • YeungNLP/moss-003-sft-data

(3)整理数据,将数据压缩:

这里为了数据处理方便,这里数据直接使用:https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data,每次处理2000条数据,然后保存为bin文件,其中需要通过 sft_process_and_write_data 将数据转换为token。

代码语言:javascript
代码运行次数:0
复制
def sft_process_and_write_data(data, max_length = 1024, padding = 0):
    doc_ids = []
    for per in data:
        history, q, a = per['history'], per['q'], per['a']
        if len(q) < 10 or len(a) < 5:
            continue
        if len(q) > 512 or len(a) > 512:
            continue

        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append(
                {"role": 'user', "content": history_message[0][:max_length // 2]}
            )
            messages.append(
                {"role": 'assistant', "content": history_message[1][:max_length // 2]}
            )

        messages += [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a},
        ]
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_id = tokenizer(new_prompt).data['input_ids'][:max_length]
        padding_len = max_length - len(input_id)
        input_id = input_id + [padding] * padding_len
        if len(input_id) >= 5:
            doc_ids += input_id
    
    return doc_ids

def sft_process():
    file_name = 'sft_data.bin'
    chunk_size = 2000  # 每次处理的记录数

    input_doc_ids = []
    datalist = []
    sft_datasets = [f'{basepath}/sft_data_zh.jsonl']
    chunk_num = 0
    for path in sft_datasets:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    datalist.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        'a': obj.get('output', '') + obj.get('a', '')
                    })

                    if len(datalist) >= chunk_size:
                        chunk_num += 1
                        input_doc_ids += sft_process_and_write_data(datalist)
                        arr = np.array(input_doc_ids, dtype=np.uint16)
                        with open(f'{basepath}/{file_name}', 'wb') as f:
                            f.write(arr.tobytes())
                        datalist = []
                        if chunk_num % 100 == 0:
                            print(f'chunk:{chunk_num} process end, and input_doc_ids length:{len(input_doc_ids)}')
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue
                    
            if len(datalist) > 0:
                input_doc_ids += sft_process_and_write_data(datalist)
                arr = np.array(input_doc_ids, dtype=np.uint16)
                with open(f'{basepath}/{file_name}', 'wb') as f:
                    f.write(arr.tobytes())
                datalist = []

3、SFT训练

SFT训练的代码和上一篇预训练的代码差别不大,区别是加载SFT数据集,代码如下(替换上一篇预训练的 PretrainDataset 函数):

代码语言:javascript
代码运行次数:0
复制
class SFTDataset(Dataset):
    def __init__(self, data_path_lst, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        super().__init__()
        self.max_length = max_length
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len

        data_lst = []
        for data_path in data_path_lst:
            with open(data_path, 'rb') as f:
                data = np.fromfile(f, dtype=np.uint16)
                data_lst.append(data)
        data = np.concatenate(data_lst)
        data = data[:max_length * int(len(data) / max_length)]
        self.data = data.reshape(-1, max_length)
        print("train data.shape:{}".format(self.data.shape))
        print("SFTDataset finished.....")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index: int):
        sample = self.data[index]
        X = np.array(sample[:-1]).astype(np.int64)
        Y = np.array(sample[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)

SFT训练的数据集很大,训练时间较长,大概需要2-3天的时间,其中部分输出(从这里看loss值已经开始下降了):

代码语言:javascript
代码运行次数:0
复制
...
Epoch:[7/20](307000/341718) loss:0.644 lr:0.0000070 epoch_Time:13.0min:
Epoch:[7/20](308000/341718) loss:0.856 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](309000/341718) loss:0.424 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](310000/341718) loss:0.524 lr:0.0000070 epoch_Time:12.0min:
Epoch:[7/20](311000/341718) loss:0.272 lr:0.0000070 epoch_Time:11.0min:
Epoch:[7/20](312000/341718) loss:0.373 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](313000/341718) loss:0.387 lr:0.0000069 epoch_Time:11.0min:
Epoch:[7/20](314000/341718) loss:0.560 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](315000/341718) loss:0.365 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](316000/341718) loss:0.226 lr:0.0000069 epoch_Time:10.0min:
Epoch:[7/20](317000/341718) loss:0.666 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](318000/341718) loss:0.504 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](319000/341718) loss:0.534 lr:0.0000069 epoch_Time:9.0min:
Epoch:[7/20](320000/341718) loss:0.403 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](321000/341718) loss:0.445 lr:0.0000069 epoch_Time:8.0min:
Epoch:[7/20](322000/341718) loss:0.581 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](323000/341718) loss:0.655 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](324000/341718) loss:0.606 lr:0.0000069 epoch_Time:7.0min:
Epoch:[7/20](325000/341718) loss:0.480 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](326000/341718) loss:0.696 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](327000/341718) loss:0.634 lr:0.0000069 epoch_Time:6.0min:
Epoch:[7/20](328000/341718) loss:0.852 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](329000/341718) loss:0.717 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](330000/341718) loss:0.680 lr:0.0000069 epoch_Time:5.0min:
Epoch:[7/20](331000/341718) loss:0.415 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](332000/341718) loss:0.617 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](333000/341718) loss:0.647 lr:0.0000069 epoch_Time:4.0min:
Epoch:[7/20](334000/341718) loss:0.554 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](335000/341718) loss:0.746 lr:0.0000069 epoch_Time:3.0min:
Epoch:[7/20](336000/341718) loss:0.499 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](337000/341718) loss:0.318 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](338000/341718) loss:0.651 lr:0.0000069 epoch_Time:2.0min:
Epoch:[7/20](339000/341718) loss:0.424 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](340000/341718) loss:0.567 lr:0.0000069 epoch_Time:1.0min:
Epoch:[7/20](341000/341718) loss:0.568 lr:0.0000069 epoch_Time:1.0min:
...

参考

(1)https://github.com/karpathy/llama2.c/blob/master/train.py

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-11-10,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 周末程序猿 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、什么是SFT
  • 2、整理SFT数据
  • 3、SFT训练
  • 参考
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档