本文以TGI对Llama 2的支持为例,解读TGI的模型加载和推理实现,总结其中运用到的推理优化技巧,最后以TGI增加AWQ推理支持为例复盘模型加载逻辑。虽尽力保持行文简洁,但最后成文还是很长,请读者按需跳转阅读。本文所分析TGI代码版本为1.1.1。
图片来源:https://zhuanlan.zhihu.com/p/649756898
上图是Llama 2的模型结构。结合模型结构简述计算流程:假设用户输入的是“<BOS> 新年快”,经分词编码得到词表的映射下标为[0,22,33,44]。Input Embedding负责将前述包含4个元素的Token序列转换为维度为[4, N]的Embedding张量后,数个Transformer Block将Embbeding张量变换得到维度仍为[4, N]的特征张量,将最后一个Token(“快”)对应的特征向量通过最后的Linear升维到词表维度和通过Softmax归一化,得到预测的下一个Token的概率(Tensor对应维度为[1, M],M为词表长度,类似于分类类别数)。如果按贪婪采样的规则选取下一个Token,则词表对应下标为55、概率最高的“乐”就成了下一个Token的预测结果。
本文重点介绍Llama 2的Transformer Block的实现。更详细的结构信息,例如RMSNorm、RoPE(Rotary Position Embedding)等,请参考链接。
Attention的权重切分方案
Feed Forward部分的权重切分
张量并行(Tensor Parallel)的系统学习,可以参考这篇文章。笔者简单提醒2点:
Flash Attention 和 Paged Attention用于加速如下的Attention模块红框部分:
Flash Attention 和 Paged Attention加速的部分
两者在提出时解决的问题有所不同:
工程上,Flash Attention主要有Dao(Flash Attention原创者)、xformers、Faster Transformer 、Pytorch4家实现。Paged Attention主要有vLLM(Paged Attenion原创者)、TensorRT-LLM 2家实现。TGI在Prifill环节使用了Dao版Flash Attention,在Decode环节使用了vLLM版 Paged Attention。原因是虽然vLLM版 Paged Attention的实现采用了Flash Attention的技巧,但缺少各样本query长度不等的Batch推理API(在Prefill环节需要此API)。出于此情况TGI同时使用了两者。
下图是TGI Server层加载一个Llama 2模型时的流程,其中标黑的是重要的类,可以对照上文的“2.1.Llama 2模型结构”进行分析。
图解:最顶上是入口函数,入口函数所在源码文件,入口函数首行在源码文件的位置(行数);子框是核心逻辑,标明位置和被调用函数;红色箭头表示调用和被调用的关系。为了绘图简洁省略了大部分的参数。
Llama2 模型加载流程(点击放大)
最核心的FlashLlamaAttention和LlamaMLP的初始化和权重加载逻辑将在下文具体展开。
出于行文方便的考虑,按先FeedForward再Attention的顺序介绍。以下直接通过加注释的方式展现:
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
# Llama 2的FeedForward的激活函数是SwishGeLU
# TGI通过'silu(W0*x)'与'W1*x'点对点相乘实现
# 在这里,self.act = torch.nn.functional.silu
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
)
# Fuse gate and up proj
# "up_proj"指的是FeedForward的第一个FFN,该FFN用于升维,所以称“up”
# "gate_proj"指的是SwishGeLU实现自门控的线性层,所以称“gate”
# 它们的输入是相同的,所以可以把它们的权重拼接,合并做矩阵乘法
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0, # 沿着按第0维将多个权重拼接在一起
bias=False,
)
# "down_proj"指的是FeedForward的第二个FFN,该FFN用于降维,所以称“down”
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
# 第一个FFN升维后得到的向量长度
# 这里考虑了模型切分后的实际长度
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
注意到,加载gate_up_proj权重的函数是TensorParallelColumnLinear.load_multi(),而加载self.down_proj权重的函数是TensorParallelRowLinear.load()。为什么分别用Column(按列切分)和Row(按行切分)加载,“2.2. 张量并行与模型切分”已说明。至于load_multi()和load(),区别在于前者加载多个权重并在某个维度拼接这些权重,后者仅加载一个权重。
上述权重加载函数的实现都在server/text_generation_server/utils/layers.py,不妨展开看一下TensorParallelRowLinear.load()的实现,同一个文件内的加载方法都大同小异:
# 位于 server/text_generation_server/utils/layers.pyclass TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
# Pytorch集合通信初始化后得到process_group
# 依据process_group与相应的GPU进行集合通信
self.process_group = process_group
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
# weights在FlashLlama初始化过程中实例化(详见整体逻辑)
# prefix是权重名
# quantize指明量化方法,某些量化还需要额外加载Scale等权重
# weight是加载到该GPU的Pytorch Tensor
# 下文详细讲 get_multi_weights_row()的实现
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
# 如果有bias,则只在RANK 0做加偏置的操作
# 如果每个GPU都加,数学上就不等价了
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
# get_linear是一个工厂模式方法(下文会详细介绍)
# 传入weight、bias、quantize实例化一个线性层
return cls(
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)
其中,weights.get_multi_weights_row()的实现:
# 位于 server/text_generation_server/utils/weights.pydef get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
# 如果量化方法为“gptq”,从文件加载若干权重,此处逻辑省略
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
# 与上类似,省略
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
# 未量化权重(float32/float16),走这个加载逻辑
# 注释在下面
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weightdef get_sharded(self, tensor_name: str, dim: int):
# Weights类在初始化时,建立了权重名与权重文件的映射关系
# 通过这个映射找到所在文件与(可能需要修正的)权重名
filename, tensor_name = self.get_filename(tensor_name)
# 拿到文件句柄,类型为safetensors
f = self._get_handle(filename)
# 根据权重名从safetensors取出对应的Tensor
# 检查权重被切分的维度能被GPU数整除
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim) # 注释在下面def get_partial_sharded(self, tensor_name: str, dim: int):
# 这里的逻辑与get_sharded()差不多,不再重复
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
# 按rank计算好切分的偏移量
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
# 切出应加载的权重部分
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
其中,get_linear()的实现:
# 位于 server/text_generation_server/utils/layers.pydef get_linear(weight, bias, quantize):
# 根据是否量化和量化方式的不同,实例化不同类型的Linear(继承自nn.Module)
if quantize is None:
# FastLinear的实现贴在下面
linear = FastLinear(weight, bias)
elif quantize == "eetq":
if HAS_EETQ:
linear = EETQLinear(weight, bias)
else:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
# 其他的量化方法实例化Linear的逻辑类似,省略
# elif quantize == "bitsandbytes":
# elif quantize == "bitsandbytes-fp4":
# elif quantize == "bitsandbytes-nf4":
# elif quantize == "gptq":
# elif quantize == "awq":
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear# 就是普通的torch.nn.functional.linearclass FastLinear(nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = nn.Parameter(weight)
if bias is not None:
self.bias = nn.Parameter(bias)
else:
self.bias = None
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight, self.bias)
这里稍微总结一下,上述逻辑完成了:
同样以加注释的方式解读Attention部分的实现:
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
# 用于RoPE(Rotary Position Embedding)的计算
# self.rotary_emb = PositionRotaryEmbedding.load(
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
# load_attention()的实现在下面
self.query_key_value = load_attention(config, prefix, weights)
# 对o_proj按行切分并加载
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
# PagedAttention的入参,用于支持GQA/MQA
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def load_attention(config, prefix, weights):
# GQA/MQA的加载
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
# MHA的加载
else:
# Baichuan 和 Llama在Attention部分有一定的差异,使用不同的加载实现
# 但都是按列切分加载
if config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=False,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
在Llama2中,对于Q_Proj、K_Proj、V_Proj(将输入分别仿射到Query、Key、Value的3个Linear),TGI使用的加载方法是TensorParallelColumnLinear.load_multi(),即把它们的权重拼接在一起并按列切分后加载。对于O_Proj,使用的权重加载方法是TensorParallelRowLinear.load(),即把O_Proj的权重按行切分后加载。
至于为什么先列切分再行切分,笔者在“2.2. 张量并行与模型切分”已作说明。至于TensorParallelColumnLinear.load_multi() 和 TensorParallelRowLinear.load()的实现,笔者在“3.2 FeedForward(LlamaMLP)”一节中做了解读,这里也不再赘述。
同样地,笔者整理了Llama2 推理的流程。其中标黑的是结构中重要的实例,可以对照上文的“2.1.Llama 2模型结构”进行分析。
图解:最顶上是入口函数,入口函数所在源码文件,入口函数首行在源码文件的位置(行数);子框是核心逻辑,标明位置和被调用函数;红色箭头表示调用和被调用的关系。为绘图简洁省略大部分的参数。
Llama2 推理流程(点击放大)
最核心的FlashLlamaAttention和LlamaMLP的推理逻辑将在下文具体展开。
同样地,推理部分的解读也是通过加注释的方式展现,为方便读者理解摘抄初始化的部分逻辑:
# 位于 server/text_generation_server/models/custom_modeling/flash_llama_modeling.pyclass LlamaMLP(nn.Module):
# __init__()的逻辑在上文注释过,这里不重复注释
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = () # 参数省略
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi() # 参数省略
self.down_proj = TensorParallelRowLinear.load() # 参数省略
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
# 通过gate_up_proj, 一次矩阵计算同时求出gate_states和up_states
gate_up_states = self.gate_up_proj(hidden_states)
# 通过Reshape和Slice操作将gate_up_states分离为gate_states和up_states
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
# 在这里 self.act 为 torch.nn.functional.silu()
# 通过self.down_proj括号内的计算,把gate_states和up_states组合得到SwishGeLU的结果
# 最后通过self.down_proj()计算得到FeedForward中第二个Linear的输出
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
FeedForward部分的逻辑相对简单。其中,gate_up_proj的类型是TensorParallelColumnLinear,down_proj的类型是TensorParallelRowLinear,我们不妨再深入分析一下它们的前向实现:
# 位于 server/text_generation_server/utils/layers.py# SuperLayer是TensorParallelColumnLinear和TensorParallelRowLinear的基类class SuperLayer(nn.Module):
def __init__(self, linear):
super().__init__()
# 持有对应类型(量化/非量化)的linear
self.linear = linear
def forward(self, x):
# 简单地调self.linear的前向
return self.linear.forward(x)# TensorParallelColumnLinear没有重写forward()# 即调的是SuperLayer的forward()实现class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
# 省略
pass
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
# 上文分析过
return cls.load_multi(config, [prefix], weights, bias, dim=0)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
# 省略
pass# TensorParallelRowLinear重写了forward()class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
# 上文分析过
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
# 调linear的forward()进行矩阵计算
out = super().forward(input)
# 如果self.process_group.size() > 1,意味着设置了张量并行(多卡推理)
# 在必要的地方通过all_reduce整合计算结果,使结果在数学上与单卡推理一致
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
由上面的分析可知,TensorParallelColumnLinear的前向仅是linear的前向的包装;而TensorParallelRowLinear的前向的前向除了调用了linear的前向,在多卡推理时还会调All Reduce通信。至于为什么有这个差异,我们回到“2.2. 张量并行与模型切分”中,每个Layer只需要一次做All Reduce即可保持结果的一致,而且那次All Reduce安排在权重被按行切分的那个FFN后面。因此,TGI将必要的All Reduce通信整合到TensorParallelRowLinear中。
# 位于 server/text_generation_server/models/custom_modeling/flash_llama_modeling.pyclass FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static() # 参数省略
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load() # 参数省略
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
# 将hidden_states仿射成QKV
qkv = self.query_key_value(hidden_states)
# 将QKV拆分为Q和KV并Reshape,这里考虑了GQA\MQA的情况
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# query和key需要加上RoPE(Rotary Position Embedding)
# cos和sin已提前计算好,并在每个Layer复用
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# 将新计算得到的Key和Value的Tensor存入PagedAttention管理的KV Cache中
# kv是新计算出来的KV,
# kv_cache是已存在的kv_cache
# slots指示reshape_and_cache()将kv拷贝到kv_cache合适的位置
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill和Decode的计算差异,请回看本系列第(一)篇
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
# 重点聊一下cu_seqlen_prefill和max_s
# FlashAttention的这个API支持Batch操作
# 也就是支持将不同样本的Q(K\V) Tensor拼接在一起
# 只需要走一次推理,即可得到各样本的Attention结果
# cu_seqlen_prefill用于指示各样本在拼接Tensor中的位置
# max_s用于提示其中最长样本的长度,方便FA做调度
flash_attn.attention(
query, # Query
torch.select(kv, dim=1, index=0), # New Key
torch.select(kv, dim=1, index=1), # New Value
attn_output, # 预分配好的Output
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output, # 预分配好的Output
query, # Query
kv_cache[0], # Key Cache
kv_cache[1], # Value Cache
self.kv_head_mapping, # 用于处理GQA/MQA,即KV的head与Q的hea数量d不等情况下如何映射
self.softmax_scale,
block_tables, # 用于指示这次计算所需KV Cache的存储位置
input_lengths, # 目前各样本长度(len(KV_Cache)+1)
max_s, # max_s用于提示其中最长样本的长度,方便FA做调度
)
# 最后将Attention计算得到的值送入self.o_proj()做一次仿射
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
源码中调用的flash_attn.attention()和paged_attention.attention(),分别是TGI对Dao版Flash Attention和vLLM版 Paged Attention的Python封装,读者若感兴趣可分析TGI源码中的flash_attn.py 和paged_attention.py。可能读者会感到疑惑,为什么TGI要使用2家Attention实现?答案是:PagedAttention虽然有Batch推理的API,但要求各样本query长度一样。这个约束在Decode阶段可以满足(Decode阶段,只需要拿各样本最新生成的Token做query即可,即各样本len(query)=1),但在Prefill阶段难以满足(一般情况下每个请求的prompt不相等,Prefill阶段query长度就是prompt长度)。所以,如果希望Prefill阶段做Batch推理以提高效率,PagedAttention是无法满足的,但好在FlashAttention有可以满足此需求的API(具体用的是API)。
笔者尝试总结一下TGI在推理层用到的优化技巧,可能不全,仅做抛砖引玉:
最后,结合以上的解读工作,分析一下#PR1019(TGI对AWQ量化推理支持)。别看这个PR改了21个文件,实际上增加一个量化方法的支持,总结起来就是简单的几个步骤:
第1步:修改Launcher,支持传入新量化方法的关键字
第二步,为AWQ增加一种新Linear的实现,必须实现__init__()和forward()方法
第三步,给weigths.py里面的Weights类的初始化函数和成员函数(比如get_multi_weights_row)增加加载AWQ权重的逻辑
第四步,给layers.py的get_linear()方法的增加实例化AWQ Linear的逻辑
第五步,额外补充量化算法特有参数的加载逻辑
第六步,由于AWQ是一个W4A16(即输入输出都是fp16,权重是int4)的量化算法,与其前后对接的模块接口(数据类型、内存布局等)完全兼容,故对AWQ算法而言推理部分完全不需要修改(反之,如果是一个W8A8量化算法,Linear输入输出都是int8类型的,前后模型是fp16类型的,那么推理部分就有得折腾了,涉及很多的反量化\重量化\类型转换等);
第七步,增加测试样例,增加AWQ Kernel编译等杂项。
文章来源:https://zhuanlan.zhihu.com/p/675292919
已获得作者授权
本文分享自 GiantPandaCV 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有