前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >Swin Transformer:深度解析其架构与代码实现

Swin Transformer:深度解析其架构与代码实现

原创
作者头像
是Dream呀
发布2025-03-03 23:17:44
发布2025-03-03 23:17:44
7700
代码可运行
举报
文章被收录于专栏:总结xyp总结xyp
运行总次数:0
代码可运行

Swin Transformer是一种强大的视觉Transformer模型,它通过引入层次化结构和基于窗口偏移的自注意力机制,有效提升了特征提取的能力。在多个计算机视觉任务中,Swin Transformer已经达到了最先进的性能水平。本文将深入探讨Swin Transformer的架构,并尝试将其网络结构进行复现。

一、Swin Transformer 概述

Swin Transformer通过扩展原始Transformer模型的能力,引入了层次化结构和基于窗口偏移的自注意力机制,使其能够有效处理图像数据,并可应用于图像分类、目标检测和分割等任务。

1.背景介绍

Swin Transformer,由微软亚洲研究院孕育的新星,今年在学术界大放异彩,以其独特的魅力在图像分类、图像分割和目标检测等众多领域中斩获了无数荣誉。

然而,Swin Transformer的诞生并非一帆风顺。在它之前,Transformer在自然语言处理(NLP)领域已经取得了辉煌的成就,但在计算机视觉(CV)的舞台上却未能同样耀眼。Swin Transformer的创造者们深入分析了这一现象,发现主要有两个难题:首先,NLP中的token大小是固定的,而CV中的特征尺度变化莫测,如同变幻莫测的风;其次,CV对于分辨率的要求更高,而使用Transformer的计算复杂度与图像尺寸的平方成正比,这无疑给计算带来了巨大的压力。

为了克服这些挑战,Swin Transformer进行了两项创新性的改进:首先,它借鉴了CNN中常用的层次化构建方式,构建了层次化的Transformer;其次,它引入了locality的概念,对没有重叠的window区域进行self-attention计算,能够精准地聚焦于每一个角落。

Swin Transformer不仅仅是一个技术革新,它更是一个多才多艺的艺术家,能够灵活地应用于图像分类、目标检测和语义分割等任务,成为这些任务的通用骨干网络。有人说,Swin Transformer可能是CNN的完美替代者,但我认为,它更像是一位能够与CNN并肩作战的伙伴,共同推动计算机视觉技术的发展。

2.主要特点

  • 层次化结构:模型采用层次化的设计,逐步降低特征的空间维度,同时增加特征的深度。
  • 移位窗口自注意力:通过在局部窗口内移动注意力焦点,减少计算量,同时捕获更丰富的上下文信息。
  • 多尺度特征学习:模型能够学习从粗粒度到细粒度的多尺度特征表示。

3.对比

下图为Swin Transformer与ViT在处理图片方式上的对比,可以看出,Swin Transformer有着ResNet一样的残差结构和CNN具有的多尺度图片结构。

在这里插入图片描述
在这里插入图片描述

二、具体实现

在这里插入图片描述
在这里插入图片描述

首先Swin-Transformer 以一张图片作为起点,这是它的画布,准备在上面绘制出精彩的图案。

1.Patch Partition 层

在 Patch Partition 层,这张图片被巧妙地拆分成众多小块,就像是将一幅大画卷分解为易于管理的小片段。Patch Partition是模型对输入图像进行预处理的一种重要操作。该操作的主要目的是将原始的连续像素图像分割成一系列固定大小的图像块(patches),以便进一步转化为Transformer可以处理的序列数据。

2.Swin Transfomer

随后,Linear Embedding 层赋予了这些小块以特征的维度,让它们不再是静止的图像,而是活跃的数据点,为之后的表演做好准备。这些特征化的小块进入 Swin Transformer Block,这是第一阶段,它们在这里学会了如何与周围的伙伴协作,共同构建起初步的图像理解。

3.Patch Merging层

接下来的第二至第四阶段,每个阶段开始前,小块们会经历 Patch Merging 的过程,这就像是将多个小故事合并为一个更加宏大的叙事,每一次合并都让图像的表示更加深入和丰富。Patch Merging层主要是进行下采样,产生分层表示。 Patch Merging 是一种减少序列长度并增加每个补丁表示中通道数的操作。

4.AdaptiveAvgPool1d 层和全连接层

在第四阶段的末尾,所有的数据汇集到输出模块,这里有一个 LayerNorm 层,它确保了数据的平衡和稳定,就像是在演出中保持舞者的稳定和优雅。最后,AdaptiveAvgPool1d 层和全连接层相继登场,它们共同作用于数据,最终完成图像的分类,为这场演出画上完美的句点。

三、代码分析

1.ShiftWindowAttentionBlock 类

代码语言:python
代码运行次数:0
复制
class ShiftWindowAttentionBlock(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, x):
    # patch_num补成能够被window_size整除
    if x.size(-2) % self.window_size:
        x = nn.ZeroPad2d((0, 0, 0, self.window_size - x.size(-2) % self.window_size))(x)

    batch, modal_leng, patch_num, input_dim = x.size()
    short_cut = x # resdual

    # 窗口偏移
    if self.shift_size:
        x = torch.roll(x, shifts=-self.shift_size, dims=2) # 只在 patch_num 上 roll   [batch, modal_leng, patch_num, input_dim]

    # 窗口化 
    window_num = patch_num // self.window_size
    window_x = x.reshape(batch, modal_leng, window_num, self.window_size, input_dim) # [batch, modal_leng, window_num, window_size, input_dim]

    # 基于窗口的多头自注意力
    q = self.query(window_x).reshape(batch, modal_leng, window_num, self.window_size, self.head_num, self.att_size).permute(0, 1, 2, 4, 3, 5) 
    ....

ShiftWindowAttentionBlock 类实现了带有窗口移位的自注意力机制。它接收一个输入张量 x,对其进行自注意力操作,并根据是否启用移位来调整窗口的覆盖范围。

2.SwinTransformer 类

train_shape: 总体训练样本的shapecategory: 类别数embedding_dim: embedding 维度patch_size: 一个patch长度head_num: 多头自注意力att_size: QKV矩阵维度window_size: 一个窗口包含多少patchs

对于传感窗口数据来讲,在每个单独的模态轴上对时序轴进行patch切分,例如 uci-har 数据集窗口尺寸为 128, 9,一个patch包含4个数据,那么每个模态轴上的patch_num为32, 总patch数为 32 * 9:

代码语言:python
代码运行次数:0
复制
class SwinTransformer(nn.Module):
    def __init__(self, train_shape, category, embedding_dim=256, patch_size=4, head_num=4, att_size=64, window_size=8):
        super().__init__()
        self.series_leng = train_shape[-2]
        self.modal_leng = train_shape[-1]
        self.patch_num = self.series_leng // patch_size
        
        self.patch_conv = nn.Conv2d(
            in_channels=1,
            out_channels=embedding_dim,
            kernel_size=(patch_size, 1),
            stride=(patch_size, 1),
            padding=0
        )

        # 位置信息
        self.position_embedding = nn.Parameter(torch.zeros(1, self.modal_leng, self.patch_num, embedding_dim))

        # patch_num维度降采样一次后的计算方式
        swin_transformer_block1_input_patch_num = math.ceil(self.patch_num / window_size) * window_size
        swin_transformer_block2_input_patch_num = math.ceil(math.ceil(swin_transformer_block1_input_patch_num / 2) / window_size) * window_size
        swin_transformer_block3_input_patch_num = math.ceil(math.ceil(swin_transformer_block2_input_patch_num / 2) / window_size) * window_size

        # Shift_Window_Attention_Layer
        # 共3个swin_transformer_block,每个swin_transformer_block对时序维降采样1/2,共降采样1/8
        self.swa = nn.Sequential(
            # swin_transformer_block 1
            nn.Sequential( 
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            ),
            # swin_transformer_block 2
            nn.Sequential(
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            ),
            # swin_transformer_block 3
            nn.Sequential(
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            )
        )

        # classification tower
        self.dense_tower = nn.Sequential(
            nn.Linear(self.modal_leng * math.ceil(swin_transformer_block3_input_patch_num / 2) * embedding_dim, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, category)
        )

SwinTransformer 类构建了 Swin Transformer 的完整模型。它接收输入数据的形状 train_shape 和类别数 category,以及其他配置参数。

3.模型组件

  • 块卷积 (patch_conv):将输入数据分割成小块,并将其转换成嵌入维度。
  • 位置嵌入 (position_embedding):为每个块添加位置信息,帮助模型捕获空间关系。
  • Swin Transformer 块 (swa):由多个 ShiftWindowAttentionBlock 组成,逐步降低特征的空间维度,同时增加深度。
  • 分类塔 (dense_tower):在模型的顶层,将特征展平并通过一系列线性层进行分类。

4.前向传播

代码语言:python
代码运行次数:0
复制
def forward(self, x):
    x = self.patch_conv(x) # [batch, embedding_dim, patch_num, modal_leng]
    x = self.position_embedding + x.permute(0, 3, 2, 1) # [batch, modal_leng, patch_num, embedding_dim]
    x = self.swa(x)
    x = nn.Flatten()(x)
    x = self.dense_tower(x)
    return x

forward 方法定义了模型的前向传播过程:

  1. 块卷积:输入数据通过卷积操作转换成嵌入维度。
  2. 位置嵌入:将位置信息添加到块特征中。
  3. Swin Transformer 块:通过多个 Swin Transformer 块进行特征提取。
  4. 分类塔:在模型顶部,将特征展平并通过线性层进行分类。

Swin Transformer 是一种创新的模型,它将 Transformer 架构的优势引入到计算机视觉领域。通过层次化处理和高效的自注意力机制,Swin Transformer 在多个视觉任务上展现出卓越的性能。提供的代码实现了 Swin Transformer 的核心功能,为进一步的研究和应用提供了基础。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、Swin Transformer 概述
    • 1.背景介绍
    • 2.主要特点
    • 3.对比
  • 二、具体实现
    • 1.Patch Partition 层
    • 2.Swin Transfomer
    • 3.Patch Merging层
    • 4.AdaptiveAvgPool1d 层和全连接层
  • 三、代码分析
    • 1.ShiftWindowAttentionBlock 类
    • 2.SwinTransformer 类
    • 3.模型组件
    • 4.前向传播
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档