前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >FlashAttention:快速且内存高效的准确注意力机制

FlashAttention:快速且内存高效的准确注意力机制

作者头像
857技术社区
发布2024-07-04 11:17:13
2040
发布2024-07-04 11:17:13
举报
文章被收录于专栏:857-Bigdata857-Bigdata

在深度学习领域,注意力机制是提高模型性能的关键组件。然而,传统的注意力机制在长序列处理时会消耗大量内存和计算资源。为了解决这个问题,Tri Dao等人提出了FlashAttention,这是一种快速且内存高效的注意力机制。本文将介绍FlashAttention及其改进版FlashAttention-2的核心概念、安装方法和使用示例。

论文介绍

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • 作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
  • 论文链接: arxiv.org/abs/2205.14135

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  • 作者: Tri Dao
  • 论文链接: flash2.pdf

安装和特性

环境要求

  • CUDA: 11.6及以上
  • PyTorch: 1.12及以上
  • 操作系统: Linux(从v2.3.2开始有部分Windows的正面反馈,但Windows编译仍需更多测试)

我们推荐使用Nvidia的PyTorch容器,其中包含安装FlashAttention所需的所有工具。

安装步骤

  1. 确保已安装PyTorch
  2. 安装packagingpip install packaging
  3. 安装ninja并确保其正常工作:ninja --version && echo $?应返回退出码0。如果未返回0,重新安装ninja:pip uninstall -y ninja && pip install ninja
使用pip安装
代码语言:javascript
复制
pip install flash-attn --no-build-isolation
从源码编译
代码语言:javascript
复制
python setup.py install
控制并行编译任务数(适用于RAM少于96GB且有多个CPU核心的机器)
代码语言:javascript
复制
MAX_JOBS=4 pip install flash-attn --no-build-isolation

使用示例

FlashAttention主要实现了缩放点积注意力(softmax(Q @ K^T * softmax_scale) @ V)。以下是使用FlashAttention的核心函数:

代码语言:javascript
复制
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

# 当Q, K, V已堆叠为一个张量时,使用flash_attn_qkvpacked_func
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                                window_size=(-1, -1), alibi_slopes=None, deterministic=False)

# 直接使用Q, K, V时,使用flash_attn_func
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                      window_size=(-1, -1), alibi_slopes=None, deterministic=False)

参数说明

  • qkv: (batch_size, seqlen, 3, nheads, headdim)格式的张量,包含Q, K, V
  • dropout_p: float,Dropout概率
  • softmax_scale: float,softmax前QK^T的缩放比例,默认为1 / sqrt(headdim)
  • causal: bool,是否应用因果注意力掩码(如用于自回归建模)
  • window_size: (left, right),如果不为(-1, -1),则实现滑动窗口局部注意力
  • alibi_slopes: (nheads,)或(batch_size, nheads),fp32。对查询i和键j的注意力分数加上一个偏置(-alibi_slope * |i - j|)
  • deterministic: bool,是否使用确定性实现的反向传播(略慢且使用更多内存)

性能表现

加速效果

FlashAttention在A100 80GB SXM5 GPU上使用FP16/BF16格式时的加速效果如下:

  • Head Dimension: 64或128
  • Hidden Dimension: 2048(即32或16个heads)
  • Sequence Length: 512, 1k, 2k, 4k, 8k, 16k
  • Batch Size: 16k / seqlen

内存节省

FlashAttention在处理较长序列时能显著节省内存。与标准注意力机制内存使用随序列长度二次增长不同,FlashAttention的内存使用线性增长。在序列长度为2K时可节省10倍内存,4K时可节省20倍内存。

完整模型代码和训练脚本

已发布了完整的GPT模型实现,并提供了其他层(如MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。整体上,训练速度较基线实现(如Huggingface实现)提高3-5倍,达到每A100 225 TFLOPs/sec,相当于72%的模型FLOPs利用率。

FlashAttention 更新日志

2.0:完全重写,速度提升2倍

FlashAttention在2.0版本中进行了完全重写,速度提升了两倍。本次更新引入了多个更改和改进,包括一些函数名称的更改以及在输入具有相同序列长度的情况下简化了使用方式。 FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本文详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。

算法

FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理matmul操作(尤其是在FP16/BF16格式下)时性能显著优化。该优化的目标是通过尽可能多地执行matmul操作来最大化GPU的吞吐量。

前向传播
  1. 在线Softmax技巧:FlashAttention-2对在线Softmax计算进行了修改,以最小化非matmul浮点操作:
    • 避免通过 diag(ℓ(2))^-1 重新缩放输出更新的两个项。
    • 维持一个“未缩放”的O(2)版本,并保留统计信息 ℓ(2)。
    • 仅在循环结束时,通过 diag(ℓ(last))^-1 缩放最终的O(last)以获得正确的输出。
  2. 最大化matmul FLOPs:为了最大化GPU的性能,FlashAttention-2重点优化了matmul操作,因为现代GPU上的专用单元(如Tensor Cores)在这些操作上表现出色。以Nvidia A100 GPU为例,其FP16/BF16 matmul的理论吞吐量可以达到312 TFLOPs/s,而非matmul FP32的吞吐量仅为19.5 TFLOPs/s。因此,FlashAttention-2通过优化算法,尽可能地减少非matmul操作,从而保持高吞吐量的执行效率。
  3. 算法细节:FlashAttention-2的前向传播通过以下步骤实现:
    • 将输入矩阵Q、K、V分成大小为𝐵𝑟 × 𝑑的𝑇𝑟块,将输出矩阵O和logsumexp𝐿也相应地分块。
    • 在每个线程块内部分配工作以最大化GPU资源的利用。
    • 引入了在线Softmax技巧,通过有效管理和缩放中间结果,减少了不必要的计算开销。

反向传播

FlashAttention-2的反向传播与FlashAttention类似,但也有一些微调:

  • 仅使用逐行logsumexp 𝐿,而不是softmax中的最大值和指数和。
  • 使用类似的分块策略来优化计算和内存访问,以提高反向传播的效率和性能。

FlashAttention-2在并行性和工作分区方面进行了深入优化,以在GPU上实现更高的计算效率和性能。本节详细讨论了FlashAttention-2的并行化策略和工作分区方法。

并行性

前向传播

在FlashAttention-2中,前向传播的并行化策略如下:

  1. 线程块调度:每个注意力头使用一个线程块来处理,总共有batch size × number of heads个线程块。每个线程块被调度到一个流多处理器(SM)上执行。例如,Nvidia A100 GPU上有108个这样的SM。这种调度在大量线程块(如≥ 80)时非常高效,因为可以充分利用GPU的计算资源。
  2. 对长序列的优化:对于长序列(通常意味着较小的batch size或较少的头数),为了更好地利用GPU上的多处理器,FlashAttention-2额外并行化了序列长度维度。这在这种情况下显著提高了性能和效率。
反向传播

在反向传播中,为了避免在不同列块之间的共享计算,FlashAttention-2采用了类似的并行化策略:

  • 线程块调度:每个列块使用一个线程块来处理。通过使用原子加操作来在不同线程块之间进行通信,以更新dQ,从而避免了共享内存的读写冲突。

工作分区

前向传播

在前向传播中,FlashAttention-2改进了工作分区策略,避免了FlashAttention中的"split-K"方案,具体包括:

  • K和V的分割:FlashAttention-2将Q分割到4个线程束(warp)中,同时使得K和V对所有线程束可访问。每个线程束执行矩阵乘法以获取QK>的一部分,并将其与V的一部分相乘,从而获得对应输出的片段。这种改进减少了线程束之间的通信,降低了共享内存的读写次数,从而提升了性能。
反向传播

在反向传播中,为了避免"split-K"方案带来的同步问题,FlashAttention-2选择了适当的线程束分区策略,以优化计算和内存访问效率。

函数重命名

以下函数的名称已更新,以反映其更新后的功能:

  • flash_attn_unpadded_func -> flash_attn_varlen_func
  • flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
  • flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func

如果输入在同一批次中具有相同的序列长度,使用以下函数将更加简单和快速:

  • flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
  • flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

2.1:更改causal标志的行为

如果 seqlen_q != seqlen_k 并且 causal=True,则causal掩码将对齐到注意力矩阵的右下角,而不是左上角。

例如,如果 seqlen_q = 2seqlen_k = 5,则causal掩码(1 = 保留,0 = 掩盖)如下:

v2.0版本:

代码语言:javascript
复制
1 0 0 0 0
1 1 0 0 0

v2.1版本:

代码语言:javascript
复制
1 1 1 1 0
1 1 1 1 1

如果 seqlen_q = 5seqlen_k = 2,则causal掩码如下:

v2.0版本:

代码语言:javascript
复制
1 0
1 1
1 1
1 1
1 1

v2.1版本:

代码语言:javascript
复制
0 0
0 0
0 0
1 0
1 1

如果掩码的行全为零,则输出也将为零。

2.2:针对推理进行优化

在查询序列长度非常短(例如查询序列长度=1)的情况下,针对推理(迭代解码)进行优化。这里的瓶颈是尽可能快地加载KV缓存,我们通过不同线程块分割加载,并使用一个单独的内核来合并结果。

请参阅具有更多推理功能的 flash_attn_with_kvcache 函数(执行旋转嵌入,原地更新KV缓存)。

感谢xformers团队,特别是Daniel Haziza的合作。

2.3:局部(即滑动窗口)注意力

实现滑动窗口注意力(即局部注意力)。感谢Mistral AI团队,特别是Timothée Lacroix的贡献。滑动窗口被用于Mistral 7B模型中。

2.4:ALiBi(线性偏差注意力),确定性反向传播

实现ALiBi(Press等人,2021)。感谢Kakao Brain的Sanghun Cho的贡献。

实现确定性反向传播。感谢美团的工程师们的贡献。

2.5:分页KV缓存

支持分页KV缓存(即PagedAttention)。感谢 @beginlner 的贡献。

代码目录flash_attn/modules/block.py

代码解读:Block

在这篇博客中,我们将逐段解读 Block 类的代码。该类实现了一个通用的块结构,广泛应用于Transformer等模型中。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

在Transformer架构中,最基本的构件是编码器和解码器层(block)。每个层通常包括以下部分:

  1. 多头自注意力机制:用于计算每个词对其他词的注意力权重。
  2. 前馈神经网络:对每个词的表示进行非线性变换。
  3. 残差连接和层归一化:为了稳定训练,添加了残差连接和层归一化。

有两种常见的层结构:

  • Prenorm结构:层归一化在主要操作(注意力或前馈神经网络)之前应用。
  • Postnorm结构:层归一化在主要操作之后应用。
代码实现流程
代码语言:javascript
复制
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):

这段代码定义了 Block 类的构造函数。以下是参数的解释:

  • dim:输入和输出的维度。
  • mixer_cls:用于计算注意力的类。
  • mlp_cls:用于前馈神经网络的类。
  • norm_cls:用于层归一化的类。
  • dropout_cls:用于Dropout的类。
  • prenorm:是否使用Prenorm结构。
  • resid_dropout1resid_dropout2:残差连接的Dropout率。
  • drop_path1drop_path2:用于Stochastic Depth的参数。
  • fused_dropout_add_ln:是否融合Dropout、Add和LayerNorm操作。
  • return_residual:是否在每个子层返回残差。
  • residual_in_fp32:是否使用FP32精度保存残差。
  • sequence_parallel:是否并行处理序列。
  • mark_shared_params:是否标记共享参数。
代码语言:javascript
复制
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

在构造函数中,首先初始化了各个参数。根据 prenormfused_dropout_add_ln 等标志设置了一些断言和默认值。如果没有提供 mixer_clsmlp_cls,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。

接下来,初始化了 mixerdropoutStochasticDepthnorm 层。

代码语言:javascript
复制
        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True

        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

这段代码处理了 fused_dropout_add_lnsequence_parallelmark_shared_params 的情况。如果启用了 fused_dropout_add_ln,则确保安装了 Triton,并且 norm1dropout1 是有效类型。如果启用了 sequence_parallel,则将 norm1norm2 的参数标记为需要序列并行。如果启用了 mark_shared_params,则将这些参数标记为共享参数。

代码语言:javascript
复制
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):

定义了两个方法:

  • allocate_inference_cache:为推理阶段分配缓存。
  • forward:前向传播函数,处理输入的 hidden_statesresidual
代码语言:javascript
复制
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual

如果使用 prenorm,则首先处理 dropoutresidual,然后应用 norm1。根据 fused_dropout_add_ln 的设置,选择是否融合这些操作。之后调用 mixer 层(通常是多头注意力机制)。最后,如果 mlp 不是 Identity,则进行类似的操作处理 mlp 层。

代码语言:javascript
复制
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self

.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
            return hidden_states

对于 postnorm 结构,处理流程类似,但 layer norm 应用于主要操作(注意力和前馈神经网络)之后。这里的 residual 在一开始设为 None。处理 mixer 层,之后处理 mlp 层。

通过这段代码,Block 类可以灵活地支持 prenormpostnorm 结构,以及各种Dropout、残差连接和层归一化的组合。这使得它在实现不同类型的Transformer架构时非常高效和通用。

代码解读博客:ParallelBlock

在这篇博客中,我们将逐段解读 ParallelBlock 类的代码。该类实现了并行的注意力(mixer)和MLP块,类似于GPT-J、GPT-NeoX和PaLM模型的结构。

理论基础

ParallelBlock 类采用了一种略有不同于常规Transformer块的结构。传统的Transformer块通常遵循以下结构:Layer Norm (LN) -> Multi-Head Attention (MHA) / MLP -> Dropout -> Add。而 ParallelBlock 中的结构为:Dropout -> Add -> LN -> MHA / MLP。这种结构的优势在于可以融合dropout、add和LayerNorm操作,从而提升性能。

代码实现流程
代码语言:javascript
复制
class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True

        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

在构造函数中,首先初始化了各个参数。根据 tied_normfused_dropout_add_ln 等标志设置了一些断言和默认值。如果没有提供 mixer_clsmlp_cls,则使用默认的多头注意力机制(MHA)和前馈神经网络(Mlp)。初始化了 mixerdropoutnormmlp 层,并根据条件设置了层归一化参数的并行序列和共享标志。

代码语言:javascript
复制
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest

        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual

forward 方法中,根据 fused_dropout_add_ln 的设置,选择是否融合dropout、add和LayerNorm操作。根据 hidden_states2 是否为 None,决定是否添加dropout到 residual 中。然后应用 norm1norm2,并处理残差。调用 mixermlp 层,最后返回 hidden_states1hidden_states2residual

通过这段代码,ParallelBlock 类实现了并行的注意力和MLP块结构,为模型的性能优化提供了一种有效的方法。

代码目录 flash_attn/modules/mha.py

代码解读:FlashSelfAttention

在这篇博客中,我们将逐段解读 FlashSelfAttention 类的代码。该类实现了一个带有 Softmax 的多头自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

自注意力机制是Transformer架构的核心。其主要原理是通过查询(query)、键(key)和值(value)来计算输入序列中每个元素的重要性权重,并根据这些权重对值进行加权求和。具体来说,自注意力机制通过以下步骤实现:

  1. 线性变换:将输入向量通过线性层分别映射到查询、键和值向量。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常会进行缩放并通过Softmax函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。

多头注意力机制通过多个独立的注意力头来捕捉输入序列中的不同特征,并将这些特征拼接后再通过线性变换进行融合。

代码实现流程
代码语言:javascript
复制
class FlashSelfAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention未安装"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention未安装"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

这段代码定义了 FlashSelfAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。
  • window_size:局部窗口大小。
  • alibi_slopes:用于调整注意力偏置的斜率。
  • deterministic:是否使用确定性操作。

构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer 注册不需要梯度的参数。

代码语言:javascript
复制
    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """实现多头Softmax注意力。
        参数
        ---------
            qkv: 包含查询、键和值的张量。
                如果cu_seqlens为None且max_seqlen为None,则qkv形状为(B, S, 3, H, D)。
                如果cu_seqlens不为None且max_seqlen不为None,则qkv形状为(total, 3, H, D),
                其中total是批次中序列长度的总和。
            causal: 如果传递,将覆盖self.causal
            cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到qkv中。
            max_seqlen: int。批次中最大序列长度。
        返回:
        --------
            out: 如果cu_seqlens不为None且max_seqlen不为None,则形状为(total, H, D),
                否则为(B, S, H, D)。
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

这段代码实现了 forward 方法,即前向传播过程。以下是参数的解释:

  • qkv:包含查询、键和值的张量。
  • causal:如果传递,将覆盖 self.causal
  • cu_seqlens:批次中序列的累计长度,用于索引到 qkv 中。
  • max_seqlen:批次中最大序列长度。

在前向传播过程中,首先检查 qkv 的数据类型和设备类型。然后根据是否有 cu_seqlens 来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_qkvpacked_func 函数计算注意力;否则通过 flash_attn_qkvpacked_func 函数计算注意力。

代码小结

FlashSelfAttention 类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。

通过以上代码,我们可以高效地计算自注意力机制,并在需要时应用Dropout和因果注意力。

代码解读博客:FlashCrossAttention

在这篇博客中,我们将逐段解读 FlashCrossAttention 类的代码。该类实现了带有 Softmax 的缩放点积交叉注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

交叉注意力机制与自注意力机制类似,但它使用不同的查询(query)、键(key)和值(value)来源于不同的序列。其主要步骤包括:

  1. 线性变换:将查询、键和值向量通过线性层进行映射。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过 Softmax 函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。

交叉注意力在许多任务中具有广泛应用,如机器翻译中的编码器-解码器架构。

代码实现流程
代码语言:javascript
复制
class FlashCrossAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention未安装"
        assert flash_attn_kvpacked_func is not None, "FlashAttention未安装"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

这段代码定义了 FlashCrossAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。
  • window_size:局部窗口大小。
  • alibi_slopes:用于调整注意力偏置的斜率。
  • deterministic:是否使用确定性操作。

构造函数中首先通过断言确保所需的FlashAttention函数已安装,然后初始化各个参数,并通过 register_buffer 注册不需要梯度的参数。

代码语言:javascript
复制
    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """实现多头Softmax注意力。
        参数
        ---------
            q: 包含查询的张量。形状为 (B, Sq, H, D)
            kv: 包含键和值的张量。形状为 (B, Sk, 2, H_k, D)
            causal: 如果传递,将覆盖self.causal
            cu_seqlens: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 q 中。
            max_seqlen: int。批次中 q 的最大序列长度。
            cu_seqlens_k: (batch_size + 1,) 形状的张量,类型为torch.int32。批次中序列的累计长度,用于索引到 kv 中。
            max_seqlen_k: int。批次中 k 和 v 的最大序列长度。
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

这段代码实现了 forward 方法,即前向传播过程。以下是参数的解释:

  • q:包含查询的张量,形状为 (B, Sq, H, D)
  • kv:包含键和值的张量,形状为 (B, Sk, 2, H_k, D)
  • causal:如果传递,将覆盖 self.causal
  • cu_seqlens:批次中序列的累计长度,用于索引到 q 中。
  • max_seqlen:批次中 q 的最大序列长度。
  • cu_seqlens_k:批次中序列的累计长度,用于索引到 kv 中。
  • max_seqlen_k:批次中 kv 的最大序列长度。

在前向传播过程中,首先检查 qkv 的数据类型和设备类型。然后根据是否有 cu_seqlens 来确定是否使用未填充的序列。如果使用未填充的序列,则通过 flash_attn_varlen_kvpacked_func 函数计算注意力;否则通过 flash_attn_kvpacked_func 函数计算注意力。

代码小结

FlashCrossAttention 类实现了一个带有Softmax的多头交叉注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量的形状和参数选择适当的函数计算注意力。

通过以上代码,我们可以高效地计算交叉注意力机制,并在需要时应用Dropout和因果注意力。

代码解读博客:SelfAttention

在这篇博客中,我们将逐段解读 SelfAttention 类的代码。该类实现了带有 Softmax 的缩放点积自注意力机制。我们将详细介绍代码实现流程和所用到的理论基础。

理论基础

自注意力机制是Transformer架构的核心,它允许模型在计算每个词的表示时考虑序列中的其他所有词。其主要步骤包括:

  1. 线性变换:将查询(query)、键(key)和值(value)向量通过线性层进行映射。
  2. 计算注意力权重:使用查询和键向量的点积计算注意力权重,通常进行缩放并通过Softmax函数归一化。
  3. 加权求和:将归一化后的注意力权重与值向量相乘,得到输出向量。
代码实现流程
代码语言:javascript
复制
class SelfAttention(nn.Module):
    """实现带有Softmax的缩放点积注意力。
    参数
    ---------
        softmax_scale: 用于Softmax注意力的温度参数。
                      (默认值:1/sqrt(d_keys),其中d_keys在运行时计算)
        attention_dropout: 对注意力应用的Dropout率
                           (默认值:0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

这段代码定义了 SelfAttention 类的构造函数。以下是参数的解释:

  • causal:是否为因果注意力,即是否考虑序列的时间顺序。
  • softmax_scale:Softmax的温度参数,用于缩放点积结果。
  • attention_dropout:注意力机制中的Dropout率。

构造函数中初始化了因果注意力标志、Softmax缩放参数和Dropout层。

代码语言:javascript
复制
    def forward(self, qkv, causal=None, key_padding_mask=None):
        """实现多头Softmax注意力。
        参数
        ---------
            qkv: 包含查询、键和值的张量。形状为 (B, S, 3, H, D)
            causal: 如果传递,将覆盖self.causal
            key_padding_mask: 布尔掩码,用于对注意力权重进行掩码处理。True表示保留,False表示掩码。形状为 (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])

在前向传播过程中,首先获取批次大小和序列长度。如果传递了 causal 参数,则覆盖 self.causal。然后将 qkv 张量沿第三个维度进行拆分,得到查询、键和值。最后,计算Softmax的缩放参数。

代码语言:javascript
复制
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

这行代码使用爱因斯坦求和约定(einsum)计算查询和键的点积,并进行缩放。scores 张量形状为 (B, H, T, S),表示每个查询与所有键的相似度。

代码语言:javascript
复制
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

如果提供了 key_padding_mask,则对注意力分数进行掩码处理。首先创建一个填充掩码,将指定位置的值设为一个非常小的数(例如-10000.0)。然后,将这个掩码应用到注意力分数上,掩盖掉不需要考虑的值。

代码语言:javascript
复制
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)

如果是因果注意力,则构造一个上三角掩码(只保留对角线及其以下的元素),掩盖掉未来的时间步。这个掩码用于确保当前时间步只能关注自己和之前的时间步,防止信息泄漏。

代码语言:javascript
复制
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

计算Softmax得到注意力权重,然后应用Dropout。最后,使用注意力权重对值进行加权求和,得到最终的输出。

代码小结

SelfAttention 类实现了一个带有Softmax的多头自注意力机制。其主要步骤包括:

  1. 初始化参数。
  2. 在前向传播过程中,根据输入张量计算注意力分数。
  3. 应用键填充掩码和因果掩码。
  4. 计算Softmax得到注意力权重,并应用Dropout。
  5. 使用注意力权重对值进行加权求和,得到最终输出。

结论

FlashAttention及其改进版FlashAttention-2为注意力机制在深度学习中的应用提供了显著的速度和内存优化,使得处理长序列数据变得更加高效。希望本文对您了解和使用FlashAttention有所帮助。

如果您对FlashAttention有任何问题或建议,欢迎通过GitHub issue与我们联系。

参考链接:

  • FlashAttention论文
  • FlashAttention-2论文
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-06-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 857Hub 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 论文介绍
    • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
      • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
      • 安装和特性
        • 环境要求
          • 安装步骤
          • 使用示例
            • 参数说明
            • 性能表现
              • 加速效果
                • 内存节省
                • 完整模型代码和训练脚本
                • FlashAttention 更新日志
                  • 2.0:完全重写,速度提升2倍
                    • 算法
                    • 反向传播
                    • 并行性
                    • 工作分区
                  • 函数重命名
                    • 2.1:更改causal标志的行为
                      • 2.2:针对推理进行优化
                        • 2.3:局部(即滑动窗口)注意力
                          • 2.4:ALiBi(线性偏差注意力),确定性反向传播
                            • 2.5:分页KV缓存
                              • 代码解读:Block
                              • 代码解读博客:ParallelBlock
                              • 代码解读:FlashSelfAttention
                              • 代码解读博客:FlashCrossAttention
                              • 代码解读博客:SelfAttention
                            • 结论
                            相关产品与服务
                            机器翻译
                            机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
                            领券
                            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档