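The previous article, "Developing Large Models from Scratch: DeepSeek's GRPO", introduced GRPO and implemented a simple version of it in code. From an engineering standpoint, however, that did not reproduce DeepSeek-R1. So I recently got access to a GPU with 48 GB of memory and, building on several open-source projects, reproduced the "aha moment". This article gives the complete code and toolchain.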
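aha moment
The DeepSeek-R1 paper says the model let the authors witness "the power and beauty of reinforcement learning". In an intermediate version of DeepSeek-R1-Zero, an "aha moment" occurred: the model learned to re-examine its own reasoning in an anthropomorphic tone.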
Figure: aha moment
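The training data used here, AI-MO/NuminaMath-TIR (see script_config in the full code), is distilled from the NuminaMath-CoT dataset.
The previous article, "Developing Large Models from Scratch: DeepSeek's GRPO", already explained how GRPO works, including the role of reward functions, so reward design is not repeated here. Following other open-source R1 reproduction projects, this article uses five reward functions, including:
(1) Format reward: returns 1 if the completion matches ^<think>.*?</think><answer>.*?</answer>$, otherwise 0;
(2) Reasoning-steps reward: counts matches of (Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,). Three or more matches earn the full reward of 1, and fewer matches earn count / 3;
(3) N-gram repetition penalty: the larger the share of repeated n-grams in a completion, the more negative the reward.
The remaining rewards (answer accuracy checked with math_verify, and length-based rewards) are defined in the full training script at the end.
To improve performance and save GPU memory, vLLM is used here. vLLM is an open-source inference acceleration framework for large models. Its PagedAttention mechanism manages the key/value tensors cached by attention (the KV cache) efficiently, which gives 14-24x the throughput of HuggingFace Transformers. In this experiment, a run that previously needed 60 GB of GPU memory fit in roughly 40 GB.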
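PagedAttention is where the memory savings come from. A naive allocator reserves one contiguous KV-cache region per sequence, sized for the maximum length. PagedAttention instead splits the KV cache into fixed-size blocks of block_size tokens. Blocks are handed out on demand and tracked in a per-sequence block table, so at most the last block of each sequence is partly empty. The sketch below is a toy allocator written only to illustrate that bookkeeping. ToyBlockManager and its methods are hypothetical names, not vLLM's internal API, and the sketch does not model the attention computation itself.

```python
import math
from dataclasses import dataclass, field


@dataclass
class ToyBlockManager:
    """Toy paged KV-cache allocator (hypothetical, for illustration only)."""
    num_blocks: int
    block_size: int = 32  # tokens per block, like block_size=32 in the vLLM example
    free_blocks: list = field(default_factory=list, init=False)
    block_tables: dict = field(default_factory=dict, init=False)  # seq_id -> physical block ids
    num_tokens: dict = field(default_factory=dict, init=False)    # seq_id -> tokens stored

    def __post_init__(self):
        self.free_blocks = list(range(self.num_blocks))

    def append_tokens(self, seq_id: str, n: int) -> None:
        """Make room for n more tokens, taking a new block only when the last one is full."""
        table = self.block_tables.setdefault(seq_id, [])
        total = self.num_tokens.get(seq_id, 0) + n
        while len(table) < math.ceil(total / self.block_size):
            if not self.free_blocks:
                raise MemoryError("KV cache is full; a real engine would preempt a sequence")
            table.append(self.free_blocks.pop(0))
        self.num_tokens[seq_id] = total

    def wasted_slots(self, seq_id: str) -> int:
        """Unused token slots, which can only sit in the sequence's last block."""
        return len(self.block_tables[seq_id]) * self.block_size - self.num_tokens[seq_id]

    def free(self, seq_id: str) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.num_tokens.pop(seq_id, None)


mgr = ToyBlockManager(num_blocks=8, block_size=32)
mgr.append_tokens("a", 40)   # 40 tokens -> 2 blocks
mgr.append_tokens("b", 10)   # 10 tokens -> 1 block
mgr.append_tokens("a", 30)   # 70 tokens -> "a" gets a third, non-contiguous block
print(mgr.block_tables)      # {'a': [0, 1, 3], 'b': [2]}
waste = mgr.wasted_slots("a")
print(waste)                 # 26 = 3 * 32 - 70
mgr.free("b")                # block 2 becomes reusable by any other sequence
```

For comparison, reserving a contiguous max_model_len-sized region per sequence up front would leave most of that region unused for short completions, and this is the waste the block table avoids.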
Because vLLM loads Hugging Face models directly, you can try it with the following code:
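from vllm import LLM, SamplingParams
if __name__ == '__main__':
    model_path = "{model name}"  # Hugging Face model ID or local path
    model = LLM(model=model_path,
                tensor_parallel_size=1,      # number of GPUs to shard the model over
                trust_remote_code=True,
                max_model_len=10000,         # maximum context length (prompt + output)
                enforce_eager=True,          # disable CUDA graph capture
                gpu_memory_utilization=0.5,  # fraction of GPU memory vLLM may claim (weights + KV cache)
                block_size=32)               # tokens per PagedAttention KV-cache block
    # max_tokens=1 would stop after a single token; allow a full answer instead
    sampling_params = SamplingParams(temperature=0, max_tokens=512)
    prompt = "How is vLLM implemented?"
    response = model.generate(prompt, sampling_params, use_tqdm=False)[0]
    print(response.outputs[0].text)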
Accelerating training with Accelerate and DeepSpeed
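Accelerate is Hugging Face's library for running PyTorch training across multiple devices, and DeepSpeed is Microsoft's distributed training library. The main difference is scale. DeepSpeed targets larger models and provides additional optimizations and tools, such as ZeRO (partitioning optimizer states, gradients and parameters across GPUs) and offloading (moving them to CPU memory). Accelerate is more stable and easier to use, and it suits small and medium training jobs. Accelerate can also use DeepSpeed as its backend, and because Hugging Face has integrated DeepSpeed, existing training code needs only a few changed lines: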
#!pip install accelerate
#!pip install deepspeed
import torch
import torch.nn.functional as F
from datasets import load_dataset
# Import the accelerate library
from accelerate import Accelerator
# Create the Accelerator
accelerator = Accelerator()
# Use the device chosen by Accelerate instead of a hard-coded one
device = accelerator.device
# Schematic: the model and dataset below are placeholders
model = torch.nn.Transformer().to(device)
optimizer = torch.optim.Adam(model.parameters())
dataset = load_dataset("{dataset to load}")
data = torch.utils.data.DataLoader(dataset, shuffle=True)
# Let Accelerate wrap the model, optimizer and dataloader
model, optimizer, data = accelerator.prepare(model, optimizer, data)
model.train()
for epoch in range(10):
    for source, targets in data:
        # The prepared dataloader already places batches on the device, so these calls are harmless no-ops
        source = source.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        output = model(source)
        loss = F.cross_entropy(output, targets)
        # Backward through Accelerate (replaces loss.backward())
        accelerator.backward(loss)
        optimizer.step()
For the related settings, see the zero3.yaml file below, or run accelerate config to generate one interactively.
Python >= 3.10 and the following libraries are required:
pip install transformers
pip install trl
pip install --upgrade trl
pip install latex2sympy2_extended math_verify
pip install flash_attn
pip install vllm
pip install deepspeed
pip install accelerate
Command to run:
accelerate launch --config_file zero3.yaml 0-grpotrainer_r1.py
The contents of zero3.yaml:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
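A rough memory budget shows why this config offloads to the CPU even though num_processes is 1. Following the ZeRO paper, mixed-precision Adam needs 2 bytes of fp16/bf16 weights, 2 bytes of gradients and 12 bytes of fp32 optimizer state (master weights, momentum, variance) per parameter. ZeRO stages 1 to 3 partition the optimizer state, then the gradients, then the weights across the data-parallel GPUs. On a single GPU, partitioning alone therefore saves nothing, and only the CPU offload settings free GPU memory. The helper below is a hypothetical back-of-the-envelope calculator. It ignores activations, the vLLM engine's own memory, fragmentation and offloading.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU memory (in GB = 1e9 bytes) for model states with mixed-precision Adam.

    Uses the ZeRO paper's accounting per parameter: 2 bytes for fp16/bf16 weights,
    2 bytes for gradients and 12 bytes of fp32 optimizer state. Activations,
    temporary buffers, fragmentation and CPU offloading are ignored.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain data parallelism: every GPU holds everything
        total = weights + grads + optim
    elif stage == 1:  # optimizer state partitioned across GPUs
        total = weights + grads + optim / num_gpus
    elif stage == 2:  # gradients partitioned as well
        total = weights + (grads + optim) / num_gpus
    elif stage == 3:  # weights partitioned as well
        total = (weights + grads + optim) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


# Roughly the size of Qwen2.5-1.5B-Instruct
for stage in range(4):
    print(f"ZeRO-{stage}: 1 GPU {zero_model_state_gb(1.5e9, 1, stage):.2f} GB, "
          f"4 GPUs {zero_model_state_gb(1.5e9, 4, stage):.2f} GB")
```

For a model of roughly 1.5B parameters, this gives about 24 GB of model states on one GPU at every stage. With ZeRO-3 across four GPUs, it drops to 6 GB per GPU.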
The full training code is long, so it is included at the end of this article.
aha moment
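As the figure above shows, the model's first direct attempt did not solve the problem. After it repeatedly added further reasoning steps, it reached the correct answer.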
(1) Error during installation: ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
Solution:
pip install -U flash-attn
(2) Error during installation: ImportError: vLLM is not available and use_vllm is set to True. Please install vLLM with pip install vllm to use it.
Solution:
pip install -U vllm
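(3) How do I convert the trained model into one that can be loaded and run? Solution: with ZeRO-3, the checkpoint stores parameters in DeepSpeed's partitioned format, so first merge them into a single fp32 state dict: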
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir="./output/GRPO-R1-1.5B",
    output_dir="./output/GRPO-R1-1.5B",  # older DeepSpeed releases name this argument output_file
    tag="global_step9055",  # name of the saved step folder
)
(4) How do I test the model? Solution:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the base Qwen model
# model_name = "Qwen/Qwen2.5-1.5B"
# Load the locally trained model
model_name = "./output/GRPO-R1-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)
model.to(device)
chat_history_ids = None
while True:
    user_input = input("User: ")
    if user_input.lower() == "exit":
        break
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt').to(device)
    if chat_history_ids is not None:
        input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        input_ids = new_user_input_ids
    # max_new_tokens limits only the reply, so the growing chat history cannot use up the generation budget
    chat_history_ids = model.generate(input_ids, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
    bot_response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    print("Bot: ", bot_response)
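The regex-based rewards defined in the full script that follows are easy to sanity-check on toy strings before a long training run. The sketch below re-implements the format, reasoning-steps, cosine-length and N-gram repetition logic on plain strings, whereas the script's versions take TRL-style completion lists. The helper names and the sample text are made up for illustration. The sketch also exposes a pitfall: without re.DOTALL, `.` in the format pattern does not match newlines, so a multi-line <think> block scores 0.

```python
import math
import re

# Patterns copied from format_reward and reasoning_steps_reward in the training script
FORMAT_PATTERN = r"^<think>.*?</think><answer>.*?</answer>$"
STEPS_PATTERN = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"


def format_score(text: str, flags: int = 0) -> float:
    return 1.0 if re.match(FORMAT_PATTERN, text, flags) else 0.0


def steps_score(text: str, flags: int = 0) -> float:
    # Three or more step markers earn the full reward; fewer earn count / 3
    return min(1.0, len(re.findall(STEPS_PATTERN, text, flags)) / 3)


def cosine_score(length: int, correct: bool, max_len: int = 1000,
                 min_wrong: float = -1.0, max_wrong: float = -0.5,
                 min_correct: float = 0.5, max_correct: float = 1.0) -> float:
    # Formula and defaults from cosine_scaled_reward; progress is clamped at 1.0 here
    cosine = math.cos(min(length / max_len, 1.0) * math.pi)
    lo, hi = (min_correct, max_correct) if correct else (max_wrong, min_wrong)
    return lo + 0.5 * (hi - lo) * (1.0 + cosine)


def repetition_score(text: str, ngram_size: int = 3, max_penalty: float = -1.0) -> float:
    # Share of repeated n-grams, scaled by the (negative) maximum penalty
    words = text.lower().split()
    if len(words) < ngram_size:
        return 0.0
    ngrams = list(zip(*[words[i:] for i in range(ngram_size)]))
    return (1 - len(set(ngrams)) / len(ngrams)) * max_penalty


sample = "<think>\nStep 1: 2 + 2 = 4\nStep 2: check it\n</think><answer>4</answer>"
print(format_score(sample))                               # 0.0, '.' stops at the newline
print(format_score(sample, re.DOTALL))                    # 1.0
print(steps_score(sample))                                # two markers -> 2/3
print(cosine_score(0, True), cosine_score(1000, True))    # 1.0 0.5: short correct answers score higher
print(cosine_score(0, False), cosine_score(1000, False))  # -1.0 -0.5: short wrong answers are penalized more
print(repetition_score("a b c a b c a b c"))             # 3 distinct of 7 trigrams -> -(1 - 3/7)
```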
from typing import Optional, Dict
import re, logging, os, sys, torch, math
import transformers
from transformers import (
AutoModelForCausalLM,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
import datasets
from datasets import load_dataset
from trl import ModelConfig, ScriptArguments, GRPOConfig, GRPOTrainer, get_peft_config
from dataclasses import dataclass, field
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
logger = logging.getLogger(__name__)
def verify_answer(contents, solution):
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        print('-'*100)
        print(f'\ncontent:{content}\nsol:{sol}')
        if len(gold_parsed) != 0:
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
            print('-'*100)
            print(f'\nanswer_parsed:{answer_parsed}\ngold_parsed:{gold_parsed}\nreward:{reward}')
        else:
            # The reference solution cannot be parsed: give every completion full reward,
            # so this sample adds no accuracy signal
            reward = 1.0
            print(f'Failed to parse gold solution: {sol}')
        rewards.append(reward)
    return rewards
def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = verify_answer(contents, solution)
    print(f'\naccuracy rewards:{rewards}')
    return rewards
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets '.' match newlines too, so multi-line reasoning inside <think> can still match
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    rewards = [1.0 if match else 0.0 for match in matches]
    print('-'*100)
    print('\nformat rewards:', rewards)
    return rewards
def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.MULTILINE makes ^ match at the start of every line, as the docstring intends
    matches = [len(re.findall(pattern, content, re.MULTILINE)) for content in completion_contents]
    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]
def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.
    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
    Args:
        completions: List of model completions
        solution: List of ground truth solutions
    Returns:
        List of rewards where:
        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
    """
    contents = [completion[0]["content"] for completion in completions]
    # First check correctness of answers
    correctness = verify_answer(contents, solution)
    # Calculate lengths
    lengths = [len(content) for content in contents]
    min_len = min(lengths)
    max_len = max(lengths)
    # If all responses have the same length, return zero rewards
    if max_len == min_len:
        return [0.0] * len(completions)
    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
        reward = lambda_val if is_correct > 0.0 else min(0, lambda_val)
        rewards.append(float(reward))
    return rewards
def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.
        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.
        Args:
            completions: List of model completions
            solution: List of ground truth solutions
        This function is parameterized by the following arguments:
            min_value_wrong: Minimum reward for wrong answers
            max_value_wrong: Maximum reward for wrong answers
            min_value_correct: Minimum reward for correct answers
            max_value_correct: Maximum reward for correct answers
            max_len: Maximum length for scaling (measured in characters here)
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        correctness = verify_answer(contents, solution)
        lengths = [len(content) for content in contents]
        for gen_len, is_correct in zip(lengths, correctness):
            # Apply cosine scaling based on length; clamp so completions longer than
            # max_len do not climb back up the cosine curve
            progress = min(gen_len / max_len, 1.0)
            cosine = math.cos(progress * math.pi)
            if is_correct > 0:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong
            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))
        return rewards
    return cosine_scaled_reward
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
    Args:
        ngram_size: size of the n-grams
        max_penalty: Maximum (negative) penalty for repeated n-grams
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")
    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])
    def repetition_penalty_reward(completions, **kwargs) -> list[float]:
        """
        Reward function that penalizes repetitions.
        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
        Args:
            completions: List of model completions
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue
            ngrams = set()
            total = 0
            for ng in zipngram(completion, ngram_size):
                ngrams.add(ng)
                total += 1
            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty
            rewards.append(reward)
        return rewards
    return repetition_penalty_reward
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
@dataclass
class R1GRPOScriptArguments(ScriptArguments):
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Available options: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    cosine_min_value_wrong: float = field(
        default=-1.0,  # must not exceed cosine_max_value_wrong; matches get_cosine_scaled_reward's default
        metadata={"help": "Minimum reward for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Maximum reward for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.5,
        metadata={"help": "Minimum reward for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for scaling"},
    )
    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-1.0,
        metadata={"help": "Maximum (negative) penalty for repetition penalty reward"},
    )
@dataclass
class R1GRPOConfig(GRPOConfig):
    """
    args for callbacks, benchmarks etc
    """
    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
    )
def main(script_args, training_args, model_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)
    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Data parameters {training_args}")
    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    # Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }
    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")
    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    training_args.gradient_checkpointing = True
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                 load_in_4bit=False, **model_kwargs)
    print(model_args.model_name_or_path)
    #############################
    # Initialize the R1GRPO trainer
    #############################
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
    )
    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["GRPOTrainer-R1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)
script_config = {
"dataset_name": "AI-MO/NuminaMath-TIR",
"dataset_config": "default",
"reward_funcs": [
"accuracy",
"format",
"reasoning_steps",
]
}
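training_config = {
    "output_dir": "output/GRPO-R1-1.5B",  # model output directory
    "overwrite_output_dir": True,  # overwrite the output directory
    "do_eval": True,  # run evaluation
    "eval_strategy": "steps",  # evaluate by step count
    "eval_steps": 100,  # evaluate every 100 steps
    "per_device_train_batch_size": 4,  # training batch size per device
    "per_device_eval_batch_size": 4,  # evaluation batch size per device
    "gradient_accumulation_steps": 8,  # gradient accumulation steps
    "learning_rate": 1.0e-06,  # learning rate
    "num_train_epochs": 1.0,  # total number of training epochs
    "max_steps": -1,  # maximum training steps; -1 means no limit
    "lr_scheduler_type": "cosine",  # learning-rate scheduler: cosine annealing
    "warmup_ratio": 0.1,  # fraction of training used for warmup
    "log_level": "info",  # logging level
    "logging_strategy": "steps",  # log by step count
    "logging_steps": 100,  # log every 100 steps
    "save_strategy": "no",  # do not save intermediate checkpoints
    "seed": 42,  # random seed
    "bf16": True,  # use bfloat16 precision
    "gradient_checkpointing": True,  # use gradient checkpointing
    "gradient_checkpointing_kwargs": {
        "use_reentrant": False  # extra gradient checkpointing argument: do not use reentrant mode
    },
    "max_prompt_length": 128,  # maximum prompt length (tokens)
    "num_generations": 4,  # completions sampled per prompt (the GRPO group size)
    "max_completion_length": 256,  # maximum completion length (tokens)
    "use_vllm": True,  # generate completions with vLLM
    "vllm_device": "auto",  # vLLM device, chosen automatically
    "vllm_gpu_memory_utilization": 0.8,  # fraction of GPU memory vLLM may use
    # Resume from this checkpoint. If the folder has no "latest" file, add one containing the step
    # folder name, e.g. global_step9055. Set it to None on the first run, when no checkpoint exists yet.
    "resume_from_checkpoint": "output/GRPO-R1-1.5B",
}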
model_config = {
"model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct",
"model_revision": "main",
"torch_dtype": "bfloat16",
"attn_implementation": "flash_attention_2",
}
if __name__ == "__main__":
    script_args = R1GRPOScriptArguments(**script_config)
    training_args = R1GRPOConfig(**training_config)
    model_args = ModelConfig(**model_config)
    main(script_args, training_args, model_args)
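The trainer does not use the rewards returned by these functions directly. With num_generations=4, GRPO samples four completions per prompt and sums the rewards from all reward functions for each completion. It then normalizes each sum against the other completions of the same prompt. The sketch below shows that group-relative normalization in plain Python. group_advantages is a hypothetical helper. The epsilon of 1e-4 and the use of the sample standard deviation follow early TRL GRPOTrainer releases; treat them as assumptions, since they may differ in the TRL version you install.

```python
import statistics


def group_advantages(rewards, num_generations, eps=1e-4):
    """Normalize each completion's total reward within its prompt group.

    rewards is a flat list ordered prompt by prompt, num_generations entries per prompt.
    """
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("need num_generations >= 2 and len(rewards) divisible by it")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.mean(group)
        std = statistics.stdev(group)  # sample (Bessel-corrected) std, like torch.Tensor.std by default
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts, four completions each; each value is accuracy + format + reasoning_steps
rewards = [2.0, 0.0, 0.0, 0.0,   # one completion was right and well formatted
           1.0, 1.0, 1.0, 1.0]   # all four equally good
advantages = group_advantages(rewards, num_generations=4)
print([round(a, 3) for a in advantages])  # [1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 0.0, 0.0]
```

The second prompt shows why a group with identical rewards gives no learning signal: every advantage is 0. This is also why verify_answer's fallback of 1.0 for an unparsable reference solution effectively removes that sample from the accuracy signal.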
References:
(1) https://github.com/agentica-project/deepscaler
(2) https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset
(3) https://zhuanlan.zhihu.com/p/21393382793
(4) https://github.com/hkust-nlp/simpleRL-reason
(5) https://mp.weixin.qq.com/s/RbQnInTa00ZISvJL7vORzA
(6) https://zhuanlan.zhihu.com/p/629644249