baichuan-7B 是百川智能开发的一个强大的中文基座模型,然后它本身并不具备对话能力。为了让它能够像 ChatGPT 一样进行对话,我们需要进行对话风格的微调。本项目就是一个简单的尝试,通过一万多条对话数据来使 baichuan-7B 拥有基础的对话能力。
这个小项目是基于Github repo —— LLM-Tuning 实现的,本文涉及到的相关的代码、数据集、教程都在仓库里,建议点个Star⭐️后配合本文食用😃🫡:
项目地址:https://github.com/beyondguo/LLM-Tuning
我们采用 HC3 (Human-ChatGPT Comparison Corpus) 数据集,来作为对话微调样本,见:
HC3 数据集包含了多个不同领域的 QA 问答对,且每个问题都至少包含一个人类和一个 ChatGPT 的回答,因此十分适合用于 ChatBot 的微调,MosaicML 的mpt-7b-chat
模型和 UC Berkeley 的 Koala-13B
模型都使用了 HC3 数据集进行开发。
下面是 HC3-Chinese 数据集的一个截面:
具体数据处理的代码,见GitHub仓库中的 hc3_data_prepare.py
.
最终得到 hc3_chatgpt_zh_specific_qa.json
文件。
执行 sh tokenize.sh
,进行分词:
CUDA_VISIBLE_DEVICES=0 python tokenize_dataset_rows.py \
--model_checkpoint baichuan-inc/baichuan-7B \
--input_file hc3_chatgpt_zh_specific_qa.json \
--prompt_key q \
--target_key a \
--save_name hc3_chatgpt_zh_specific_qa_baichuan-7B \
--max_seq_length 2000 \
--skip_overlength False
执行 sh train.sh
,进行训练:
CUDA_VISIBLE_DEVICES=0,1,2,3 python baichuan_lora_tuning.py \
--tokenized_dataset hc3_chatgpt_zh_specific_qa_baichuan-7B \
--lora_rank 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--num_train_epochs 2 \
--save_steps 200 \
--save_total_limit 2 \
--learning_rate 1e-4 \
--fp16 \
--remove_unused_columns false \
--logging_steps 50 \
--output_dir weights/hc3_chatgpt_zh_specific_qa_baichuan-7B
# 看看训练之后baichuan是否具备了Chat能力:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from transformers import TextStreamer
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", device_map="auto", trust_remote_code=True)
# load LoRA:
model = PeftModel.from_pretrained(model, "weights/hc3_chatgpt_zh_specific_qa_baichuan-7B")
def chat(text):
streamer = TextStreamer(tokenizer,skip_prompt=True,skip_special_tokens=True)
inputs = tokenizer("问:"+text+"答:", return_tensors='pt') # 这里添加 "问:","答:",是为了跟我构造的训练数据对应,从而更好地引导模型进行回答
inputs = inputs.to('cuda:0')
output = model.generate(**inputs, max_new_tokens=1024,repetition_penalty=1.1, streamer=streamer)
输入:哎,最近晚上睡不着!
输入:你是谁开发啊
大家有什么想法欢迎来评论区讨论,或者来我的Github项目的discussion区讨论哦!