前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Text Generation Inference源码解读(二):模型加载与推理

Text Generation Inference源码解读(二):模型加载与推理

作者头像
BBuf
发布于 2024-02-22 00:21:08
发布于 2024-02-22 00:21:08
2.1K00
代码可运行
举报
文章被收录于专栏:GiantPandaCVGiantPandaCV
运行总次数:0
代码可运行

1. 前言

本文以TGI对Llama 2的支持为例,解读TGI的模型加载和推理实现,总结其中运用到的推理优化技巧,最后以TGI增加AWQ推理支持为例复盘模型加载逻辑。虽尽力保持行文简洁,但最后成文还是很长,请读者按需跳转阅读。本文所分析TGI代码版本为1.1.1

2. 背景知识

2.1. Llama 2模型结构

图片来源:https://zhuanlan.zhihu.com/p/649756898

上图是Llama 2的模型结构。结合模型结构简述计算流程:假设用户输入的是“<BOS> 新年快”,经分词编码得到词表的映射下标为[0,22,33,44]。Input Embedding负责将前述包含4个元素的Token序列转换为维度为[4, N]的Embedding张量后,数个Transformer Block将Embbeding张量变换得到维度仍为[4, N]的特征张量,将最后一个Token(“快”)对应的特征向量通过最后的Linear升维到词表维度和通过Softmax归一化,得到预测的下一个Token的概率(Tensor对应维度为[1, M],M为词表长度,类似于分类类别数)。如果按贪婪采样的规则选取下一个Token,则词表对应下标为55、概率最高的“乐”就成了下一个Token的预测结果。

本文重点介绍Llama 2的Transformer Block的实现。更详细的结构信息,例如RMSNorm、RoPE(Rotary Position Embedding)等,请参考链接。

2.2. 张量并行与模型切分

Attention的权重切分方案

Feed Forward部分的权重切分

张量并行(Tensor Parallel)的系统学习,可以参考这篇文章。笔者简单提醒2点:

  1. Attention部分和Feed Forward部分的均涉及2次权重切分和1次All Reduce通信。为使数学等价,必须2次权重切分的方向必须是不相同的(不信可以尝试,连续2次按列切分或连续2次按行切分后能否整合出结果);
  2. 切分顺序调换一下(权重先按行切分再按列切分)是否可行?数学上是可行的,但工程上述切分方案性能更好。主要原因是,先列后行得到的Z1、Z1是Z的部分和,为了得到Z需要做加和,分布式的加和通过All Reduce实现;而先行后列得到的Z1、Z2是Z的一部分,为了得到Z需要做拼接操作,分布式的拼接通过All Gather实现。相比较而言,前者在通信上的效率优于优于后者,所以一般使用先列切分后行切分的方式。

2.3. Flash Attention与Paged Attention

Flash Attention 和 Paged Attention用于加速如下的Attention模块红框部分:

Flash Attention 和 Paged Attention加速的部分

两者在提出时解决的问题有所不同:

  1. 上述操作的朴素实现存在的计算量小但不快(访存受限、高端GPU的算力发挥不出来)问题,Flash Attention针对此问题引入Tiling和Recomputation技巧提高计算效率;
  2. LLM推理广泛利用KV Cache技术加速,注意到生成过程中KV Cache是逐渐变长的,如果每个样本都按最大长度一次性分配KV Cache需要的显存,会有显存浪费的问题(实际生成文本的长短差别很大);而如果简单地对KV Cache做动态扩容(类似STL Vector的扩容机制),在大并发下因动态扩容产生显存分配和回收操作频繁且开销不可忽略(显存占用过高时会触发碎片整理),对吞吐率影响很大。Paged Attention借鉴操作系统的虚拟内存和分页思想,实现一个样本内连续的KV Cache离散地存储,在GPU计算时并行地将离散的KV Cache后整合再完成上述操作,从而有效地减少因KV Cahce动态扩容导致的显存管理开销。

工程上,Flash Attention主要有Dao(Flash Attention原创者)、xformers、Faster Transformer 、Pytorch4家实现。Paged Attention主要有vLLM(Paged Attenion原创者)、TensorRT-LLM 2家实现。TGI在Prifill环节使用了Dao版Flash Attention,在Decode环节使用了vLLM版 Paged Attention。原因是虽然vLLM版 Paged Attention的实现采用了Flash Attention的技巧,但缺少各样本query长度不等的Batch推理API(在Prefill环节需要此API)。出于此情况TGI同时使用了两者。

3. 模型加载

3.1. 整体流程

下图是TGI Server层加载一个Llama 2模型时的流程,其中标黑的是重要的类,可以对照上文的“2.1.Llama 2模型结构”进行分析。

图解:最顶上是入口函数,入口函数所在源码文件,入口函数首行在源码文件的位置(行数);子框是核心逻辑,标明位置和被调用函数;红色箭头表示调用和被调用的关系。为了绘图简洁省略了大部分的参数。

Llama2 模型加载流程(点击放大)

最核心的FlashLlamaAttentionLlamaMLP的初始化和权重加载逻辑将在下文具体展开。

3.2 FeedForward(LlamaMLP)

出于行文方便的考虑,按先FeedForward再Attention的顺序介绍。以下直接通过加注释的方式展现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class LlamaMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        # Llama 2的FeedForward的激活函数是SwishGeLU
        # TGI通过'silu(W0*x)''W1*x'点对点相乘实现
        # 在这里,self.act = torch.nn.functional.silu
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else "none",
            )
        )
        # Fuse gate and up proj
        # "up_proj"指的是FeedForward的第一个FFN,该FFN用于升维,所以称“up”
        # "gate_proj"指的是SwishGeLU实现自门控的线性层,所以称“gate”
        # 它们的输入是相同的,所以可以把它们的权重拼接,合并做矩阵乘法
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0, # 沿着按第0维将多个权重拼接在一起
            bias=False,
        )
        #  "down_proj"指的是FeedForward的第二个FFN,该FFN用于降维,所以称“down”
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # 第一个FFN升维后得到的向量长度
        # 这里考虑了模型切分后的实际长度
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )  

注意到,加载gate_up_proj权重的函数是TensorParallelColumnLinear.load_multi(),而加载self.down_proj权重的函数是TensorParallelRowLinear.load()。为什么分别用Column(按列切分)和Row(按行切分)加载,“2.2. 张量并行与模型切分”已说明。至于load_multi()和load(),区别在于前者加载多个权重并在某个维度拼接这些权重,后者仅加载一个权重。

上述权重加载函数的实现都在server/text_generation_server/utils/layers.py,不妨展开看一下TensorParallelRowLinear.load()的实现,同一个文件内的加载方法都大同小异:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/utils/layers.pyclass TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        # Pytorch集合通信初始化后得到process_group
        # 依据process_group与相应的GPU进行集合通信
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        # weights在FlashLlama初始化过程中实例化(详见整体逻辑)
        # prefix是权重名
        # quantize指明量化方法,某些量化还需要额外加载Scale等权重
        # weight是加载到该GPU的Pytorch Tensor
        # 下文详细讲 get_multi_weights_row()的实现
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        # 如果有bias,则只在RANK 0做加偏置的操作
        # 如果每个GPU都加,数学上就不等价了
        if bias and weights.process_group.rank() == 0:
            # Rank is only on the first rank process
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        # get_linear是一个工厂模式方法(下文会详细介绍)
        # 传入weight、bias、quantize实例化一个线性层
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

其中,weights.get_multi_weights_row()的实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/utils/weights.pydef get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            # 如果量化方法为“gptq”,从文件加载若干权重,此处逻辑省略
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        elif quantize == "awq":
            # 与上类似,省略
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            # 未量化权重(float32/float16),走这个加载逻辑
            # 注释在下面
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weightdef get_sharded(self, tensor_name: str, dim: int):
    # Weights类在初始化时,建立了权重名与权重文件的映射关系
    # 通过这个映射找到所在文件与(可能需要修正的)权重名
    filename, tensor_name = self.get_filename(tensor_name)
    # 拿到文件句柄,类型为safetensors
    f = self._get_handle(filename)
    # 根据权重名从safetensors取出对应的Tensor
    # 检查权重被切分的维度能被GPU数整除
    slice_ = f.get_slice(tensor_name)
    world_size = self.process_group.size()
    size = slice_.get_shape()[dim]
    assert (
        size % world_size == 0
    ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
    return self.get_partial_sharded(tensor_name, dim) # 注释在下面def get_partial_sharded(self, tensor_name: str, dim: int):
    # 这里的逻辑与get_sharded()差不多,不再重复
    filename, tensor_name = self.get_filename(tensor_name)
    f = self._get_handle(filename)
    slice_ = f.get_slice(tensor_name)
    world_size = self.process_group.size()
    rank = self.process_group.rank()

    # 按rank计算好切分的偏移量
    size = slice_.get_shape()[dim]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size

    # 切出应加载的权重部分
    if dim == 0:
        tensor = slice_[start:stop]
    elif dim == 1:
        tensor = slice_[:, start:stop]
    else:
        raise NotImplementedError("Let's make that generic when needed")
    # Special case for gptq which shouldn't convert
    # u4 which are disguised as int32
    if tensor.dtype != torch.int32:
        tensor = tensor.to(dtype=self.dtype)
    tensor = tensor.to(device=self.device)
    return tensor

其中,get_linear()的实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/utils/layers.pydef get_linear(weight, bias, quantize):
    # 根据是否量化和量化方式的不同,实例化不同类型的Linear(继承自nn.Module)
    if quantize is None:
        # FastLinear的实现贴在下面
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    # 其他的量化方法实例化Linear的逻辑类似,省略
    # elif quantize == "bitsandbytes":
    # elif quantize == "bitsandbytes-fp4":
    # elif quantize == "bitsandbytes-nf4":
    # elif quantize == "gptq":
    # elif quantize == "awq":
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear# 就是普通的torch.nn.functional.linearclass FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)

这里稍微总结一下,上述逻辑完成了:

  1. 根据权重名找到权重所在文件;
  2. 加载权重并按照模型并行规则(按列切分或按行切分,权重是否融合)切分对应的权重;
  3. 利用切分好的权重实例化用于非量化/量化推理的Linear;
  4. 多个Linear实例和激活函数组成LlamaMLP实例。

3.3. Attention(FlashLlamaAttention)

同样以加注释的方式解读Attention部分的实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        # 用于RoPE(Rotary Position Embedding)的计算
        # self.rotary_emb = PositionRotaryEmbedding.load(
        #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
        # )
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        # load_attention()的实现在下面
        self.query_key_value = load_attention(config, prefix, weights)

        # 对o_proj按行切分并加载
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        # PagedAttention的入参,用于支持GQA/MQA
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def load_attention(config, prefix, weights):
        # GQA/MQA的加载
        if config.num_attention_heads != config.num_key_value_heads:
            return _load_gqa(config, prefix, weights)
        # MHA的加载
        else:
            # Baichuan 和 Llama在Attention部分有一定的差异,使用不同的加载实现
            # 但都是按列切分加载
            if config.model_type == "baichuan":
                return TensorParallelColumnLinear.load_qkv(
                    config,
                    prefix=f"{prefix}.W_pack",
                    weights=weights,
                    bias=False,
                )
            else:
                return TensorParallelColumnLinear.load_multi(
                    config,
                    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                    dim=0,
                    weights=weights,
                    bias=False,
                )

在Llama2中,对于Q_Proj、K_Proj、V_Proj(将输入分别仿射到Query、Key、Value的3个Linear),TGI使用的加载方法是TensorParallelColumnLinear.load_multi(),即把它们的权重拼接在一起并按列切分后加载。对于O_Proj,使用的权重加载方法是TensorParallelRowLinear.load(),即把O_Proj的权重按行切分后加载。

至于为什么先列切分再行切分,笔者在“2.2. 张量并行与模型切分”已作说明。至于TensorParallelColumnLinear.load_multi() 和 TensorParallelRowLinear.load()的实现,笔者在“3.2 FeedForward(LlamaMLP)”一节中做了解读,这里也不再赘述。

3. 模型推理

3.1. 整体流程

同样地,笔者整理了Llama2 推理的流程。其中标黑的是结构中重要的实例,可以对照上文的“2.1.Llama 2模型结构”进行分析。

图解:最顶上是入口函数,入口函数所在源码文件,入口函数首行在源码文件的位置(行数);子框是核心逻辑,标明位置和被调用函数;红色箭头表示调用和被调用的关系。为绘图简洁省略大部分的参数。

Llama2 推理流程(点击放大)

最核心的FlashLlamaAttentionLlamaMLP的推理逻辑将在下文具体展开。

3.2. FeedForward(LlamaMLP)

同样地,推理部分的解读也是通过加注释的方式展现,为方便读者理解摘抄初始化的部分逻辑:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/models/custom_modeling/flash_llama_modeling.pyclass LlamaMLP(nn.Module):
    # __init__()的逻辑在上文注释过,这里不重复注释
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = () # 参数省略
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi() # 参数省略
        self.down_proj = TensorParallelRowLinear.load() # 参数省略
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        # 通过gate_up_proj, 一次矩阵计算同时求出gate_states和up_states
        gate_up_states = self.gate_up_proj(hidden_states)
        # 通过Reshape和Slice操作将gate_up_states分离为gate_states和up_states
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        # 在这里 self.act 为 torch.nn.functional.silu()
        # 通过self.down_proj括号内的计算,把gate_states和up_states组合得到SwishGeLU的结果
        # 最后通过self.down_proj()计算得到FeedForward中第二个Linear的输出
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])

FeedForward部分的逻辑相对简单。其中,gate_up_proj的类型是TensorParallelColumnLinear,down_proj的类型是TensorParallelRowLinear,我们不妨再深入分析一下它们的前向实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/utils/layers.py# SuperLayer是TensorParallelColumnLinear和TensorParallelRowLinear的基类class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        # 持有对应类型(量化/非量化)的linear
        self.linear = linear

    def forward(self, x):
        # 简单地调self.linear的前向
        return self.linear.forward(x)# TensorParallelColumnLinear没有重写forward()# 即调的是SuperLayer的forward()实现class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        # 省略
        pass

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        # 上文分析过
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        # 省略
        pass# TensorParallelRowLinear重写了forward()class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        # 上文分析过
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # Rank is only on the first rank process
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # 调linear的forward()进行矩阵计算
        out = super().forward(input)
        # 如果self.process_group.size() > 1,意味着设置了张量并行(多卡推理)
        # 在必要的地方通过all_reduce整合计算结果,使结果在数学上与单卡推理一致
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out

由上面的分析可知,TensorParallelColumnLinear的前向仅是linear的前向的包装;而TensorParallelRowLinear的前向的前向除了调用了linear的前向,在多卡推理时还会调All Reduce通信。至于为什么有这个差异,我们回到“2.2. 张量并行与模型切分”中,每个Layer只需要一次做All Reduce即可保持结果的一致,而且那次All Reduce安排在权重被按行切分的那个FFN后面。因此,TGI将必要的All Reduce通信整合到TensorParallelRowLinear中。

3.3. Attention(FlashLlamaAttention)

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 位于 server/text_generation_server/models/custom_modeling/flash_llama_modeling.pyclass FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static() # 参数省略

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load() # 参数省略
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        # 将hidden_states仿射成QKV
        qkv = self.query_key_value(hidden_states)
        # 将QKV拆分为QKV并Reshape,这里考虑了GQA\MQA的情况
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # query和key需要加上RoPE(Rotary Position Embedding)
        # cos和sin已提前计算好,并在每个Layer复用 
        self.rotary_emb(query, cos, sin)
        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

        # 将新计算得到的Key和Value的Tensor存入PagedAttention管理的KV Cache中
        # kv是新计算出来的KV,
        # kv_cache是已存在的kv_cache
        # slots指示reshape_and_cache()将kv拷贝到kv_cache合适的位置
        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill和Decode的计算差异,请回看本系列第()篇
        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            # 重点聊一下cu_seqlen_prefill和max_s
            # FlashAttention的这个API支持Batch操作
            # 也就是支持将不同样本的Q(K\V) Tensor拼接在一起
            # 只需要走一次推理,即可得到各样本的Attention结果
            # cu_seqlen_prefill用于指示各样本在拼接Tensor中的位置
            # max_s用于提示其中最长样本的长度,方便FA做调度
            flash_attn.attention(
                query, # Query
                torch.select(kv, dim=1, index=0), # New Key
                torch.select(kv, dim=1, index=1), # New Value
                attn_output, # 预分配好的Output
                cu_seqlen_prefill, 
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output, # 预分配好的Output
                query, # Query
                kv_cache[0], # Key Cache
                kv_cache[1], # Value Cache
                self.kv_head_mapping, # 用于处理GQA/MQA,KV的head与Q的hea数量d不等情况下如何映射
                self.softmax_scale,
                block_tables, # 用于指示这次计算所需KV Cache的存储位置
                input_lengths, # 目前各样本长度(len(KV_Cache)+1)
                max_s, # max_s用于提示其中最长样本的长度,方便FA做调度
            )

        # 最后将Attention计算得到的值送入self.o_proj()做一次仿射
        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

源码中调用的flash_attn.attention()和paged_attention.attention(),分别是TGI对Dao版Flash Attention和vLLM版 Paged Attention的Python封装,读者若感兴趣可分析TGI源码中的flash_attn.py 和paged_attention.py。可能读者会感到疑惑,为什么TGI要使用2家Attention实现?答案是:PagedAttention虽然有Batch推理的API,但要求各样本query长度一样。这个约束在Decode阶段可以满足(Decode阶段,只需要拿各样本最新生成的Token做query即可,即各样本len(query)=1),但在Prefill阶段难以满足(一般情况下每个请求的prompt不相等,Prefill阶段query长度就是prompt长度)。所以,如果希望Prefill阶段做Batch推理以提高效率,PagedAttention是无法满足的,但好在FlashAttention有可以满足此需求的API(具体用的是API)。

4. TGI推理层优化技巧小结

笔者尝试总结一下TGI在推理层用到的优化技巧,可能不全,仅做抛砖引玉:

  1. 算子融合:包括上文提到的整合Attention的q_proj、k_proj、v_proj,整合FeedForward的up_proj和gate_proj。除此之外,还使用融合了RMSNorm与Resdual Add的算子(链接,出自Dao版Flash Attention),Rotary Position Embedding的单算子CUDA实现(链接,同样出自Dao版Flash Attention)等;
  2. 去除冗余计算:除了使用了KV Cache,还包括Llama 2每一个Layer都需要做RoPE的操作,TGI提前计算并缓存了所需要的cos和sin的值(实现在这里),并让各Layer复用;
  3. 灵活使用Attention API:为使Prefill阶段支持Batch操作,在Pefill和Decode阶段分别使用不同开源项目的API;
  4. Batched Sampling:TGI允许每个用户在请求中指定的各自Token采样方法(贪婪/Top P/Top K等,且可以组合),并以Batch的方式统一处理(通过Mask控制Logit Filter是否对某个样本生效)。以Batched Top K为例,实现在这里;
  5. 负载均衡应用于Tokenizer 解码字符:解码新字符时,需要将预测出来的Token ID从GPU拷贝到CPU才能使用Tokenizer解码,这个过程有一定的性能开销。在使用张量并行时,TGI允许各GPU计算进程独立地把Response回传给Router,所以每个GPU计算进程处理Batch中一部分样本的字符解码即可,实现在这里;
  6. 去除冗余数据提高Prefill效率:Prefill环节,除非需要给用户返回Prompt每个Token的概率,正常只需要算序列最后一个Token的logit即可。因此,可以切片切出最后一个Token的hidden states,再进行通信和计算,实现在这里。很明显,当Prompt特别长时这个优化会比较显著。

5. 实例分析:TGI对新量化推理方法的支持

最后,结合以上的解读工作,分析一下#PR1019(TGI对AWQ量化推理支持)。别看这个PR改了21个文件,实际上增加一个量化方法的支持,总结起来就是简单的几个步骤:

第1步:修改Launcher,支持传入新量化方法的关键字

 第一步
第一步

第二步,为AWQ增加一种新Linear的实现,必须实现__init__()和forward()方法

第二步
第二步

第三步,给weigths.py里面的Weights类的初始化函数和成员函数(比如get_multi_weights_row)增加加载AWQ权重的逻辑

第三步
第三步

第四步,给layers.py的get_linear()方法的增加实例化AWQ Linear的逻辑

第四步
第四步

第五步,额外补充量化算法特有参数的加载逻辑

第五步
第五步

第六步,由于AWQ是一个W4A16(即输入输出都是fp16,权重是int4)的量化算法,与其前后对接的模块接口(数据类型、内存布局等)完全兼容,故对AWQ算法而言推理部分完全不需要修改(反之,如果是一个W8A8量化算法,Linear输入输出都是int8类型的,前后模型是fp16类型的,那么推理部分就有得折腾了,涉及很多的反量化\重量化\类型转换等);

第七步,增加测试样例,增加AWQ Kernel编译等杂项。

文章来源:https://zhuanlan.zhihu.com/p/675292919

已获得作者授权

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-02-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 前言
  • 2. 背景知识
    • 2.1. Llama 2模型结构
    • 2.2. 张量并行与模型切分
    • 2.3. Flash Attention与Paged Attention
  • 3. 模型加载
    • 3.1. 整体流程
    • 3.2 FeedForward(LlamaMLP)
    • 3.3. Attention(FlashLlamaAttention)
  • 3. 模型推理
    • 3.1. 整体流程
    • 3.2. FeedForward(LlamaMLP)
    • 3.3. Attention(FlashLlamaAttention)
  • 4. TGI推理层优化技巧小结
  • 5. 实例分析:TGI对新量化推理方法的支持
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档