2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。
2020 年 10 月,谷歌提出了Vit(Vision Transformer),利用 Transformer 对图像进行分类,而不需要卷积网络。
2021年1月,OpenAI 提出两个模型:DALL·E 基于本文直接生成图像,CLIP将图像映射到文本描述的类别中。两个模型都利用 Transformer 。
2021年3月,微软提出Swin Transformer,把CV各大任务给屠榜了。。。。
我能放过它?我不能。。。总结下前段时间看了论文和代码梳理出来的swin_transformer框架和实现。
论文: https://arxiv.org/abs/2103.14030
代码: https://github.com/microsoft/Swin-Transformer
swin_transformer对比之前Vit有两个改进点:
1.引入了CNN里常用的多层次transformers结构
Vit的尺度是不变的,不易于接入到下游任务中,比如分割的encoder阶段可以方便的接入resnet等backbone网络,而Vit的特征图尺寸是不变的下图(b)。swin_transfomer通过合并image_patchesd的方式引入多层次结构,如下图(a)。
2、降低计算复杂度和内存占用
论文中定义上图中灰色块为patch,红色块定义为window。swin_transfomer通过切分窗口,计算self_attention是针对这些局部的无重叠的window。原始的MSA和论文中W-MSA的计算复杂度如下图公式,其中M是窗口包含patch的个数,也就是window_size,其大小是远小于h,w的。通过公式可以看出其计算复杂度和hw是线性关系。这里复杂度计算方法,我们后续分析源码后可以更清晰了解。
针对第一个优化点,论文使用的网络架构如下:
结构分为4个stage,stages中特征图大小分别缩小为1/4,1/8,1/16,1/32。
针对第二个优化点,论文指出仅仅对FM切分windows,然后对每个window进行self_attention有一个缺点,就是窗口之间是无沟通的。所以提出使用串联W-MSA和SW-MSA的方式。
W-MSA就是无重叠的窗口self_attention计算,而cyclic shift就如下图,对窗口进行一个shift。本来2*2的窗口个数,不等比切分为3*3个窗口。但是这样计算量会增大1.5*1.5倍。作者提出一个替换方法是进行一个roll操作,将2*2的窗口向左向上移动,移动后的窗口就包含了上层其他区域窗口的信息了。但是ABC区域本不该是邻近区域,所以还需要进行一个mask操作。
最后记得反shift把整个窗口移回去~
结果就是把CV几个大任务屠榜了。。
下面介绍从代码角度深入了解swin_transformer
先了解主要类:BasicLayer实现stage的流程,SwinTransformerBlock是BasicLayer的主要逻辑模块也是论文核心模块,WindowAttention是SwinTransformerBlock中实现attention的模块。
depths:(2,2,6,4)决定每个layer的SwinTransformerBlock执行次数。
论文提出了4套参数模型,我们下面以Swin-T为例介绍。
代码模块逻辑:
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*4)
---WindowAttention
主要模块的代码逻辑:
首先进行一次patch_embed,patch_embed就是把输入按patch进行一次向量映射。我认为就是卷积操作(标题swin_transfomer,第一步就是卷积~卷积yyds)
设定输入:(3,256,256),patch_size=4,embeding_dim=96
(1)分辨率不够4整除就pad到4的倍数
(2)通用卷积kernel=4,stride=4,将image映射为无重叠的4*4的patchs:(96,64,64)
(3)如果需要norm,再进行一次layerNorm
(4)(3,256,256) 通过patch_embed,特征为(96,64,64)
如果有position_embeding步骤,需要学习一个96,64,64的pos_emded参数。和patch_embed进行concat.
将emded矩阵进行flatten+transpose-->64*64, 96
对分辨率缩小*4的特征图进行4个stage的-BasicLayer
设定window_size=7,以stage1为例输入特征图大小为(64,64)。img_mask初始为(70,70),那么通过window_partition就把特征图切分为100个7*7的窗口。
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask:, h, w, : = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
以上代码目的是得到100个49*49的attn_mask。
这里的attn_mask是为后续的cyclic shift,也就是SW-MSA使用。
首先,对img_mask70*70的图进行切分9大块赋值
63*63=0 4*63=1 3*64=2
63*4=3 4*4=4 3*4=5
64*3=6 4*3=7 3*3=8
然后通过将window_partition将窗口切分为100个7*7窗口,对数据平铺,得到100*49,每个窗口和其他窗口进行相减,得到100*49*49,再将不为0的值赋值-100。这些不为0位置含义可以理解为和相对位置不为上图中划分的同一个区域。结合cyclic shift,表示cyclic shift中在一个window内,特征不相邻的sub_window的位置,所以需要mask掉。
对输入64*64, 96进行layer_norm+reshape+pad操作。pad作用是要FM的H,W是window_size的倍数。对stage1:64*64, 96-->70,70,96
先看第一阶段W-MSA blcok,也就是不加入cyclic shift。
(a)进行window_partition,将特征图切分为window_size*window_size的patch,1,70*70,96切分为100,7,7,96,再reshape100,49,96
(b) WindowAttention
计算self_attention
step1:获取QKV矩阵。X:100,49,64-->Q,K,V:100,3,49,32
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv0, qkv1, qkv2
具体操作:输入全连接C通道扩展到3C,再根据multi_head将FM切分为head_num份,最后slipe得到qkv矩阵。100,3,49,32表示窗口个数,attention头,窗口长度,C/head
step2:计算attention。
attn = (q @ k.transpose(-2, -1))
100,3,49,32*100,3,32,49-->:100,3,49,49 。self_attention方面的原理可以查看transformers论文,这里就不详细介绍了。
step3:计算relative_position_bias
论文提出,增加相对位置编码效果更好。也就是在step2计算出的attn加上relative_position_bias。和attn一样,大小应该为(3,49*49)的矩阵。
下面看如何计算relative_position_bias。
#define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size0 - 1) * (2 * window_size1 - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size0)
coords_w = torch.arange(self.window_size1)
coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten:, :, None - coords_flatten:, None, : # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords:, :, 0 += self.window_size0 - 1 # shift to start from 0
relative_coords:, :, 1 += self.window_size1 - 1
relative_coords:, :, 0 *= 2 * self.window_size1 - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_bias = self.relative_position_bias_tableself.relative_position_index.view(-1).view(
self.window_size0 * self.window_size1, self.window_size0 * self.window_size1, -1)
我们假设窗口大小为2,方便理解计算相对位置编码逻辑。
首先建立坐标系:
然后在X和Y方向计算relative_coords。计算relative_coords第一步加(window_size-1)是为了让值都为正数,在X方向再*(2*window_size-1)是为了后续求和能区分(0,1)和(1,0)这类坐标。
最后将X和Y方向坐标值值求和,得到relative_position_index 。
根据以上计算过程,也可以知道,我们的relative_position_bias_table(需要学习的参数)最大值应该是(window_size+(window_size-1))*(2*window_size-1)。
有了relative_position_index和relative_position_bias_table后,relative_position_bias就可以通过查表方式获取。
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
step4:计算attn_out
attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
根据self_attntion的公式:
softmax(q*KT)*V-->:100,3,49,49*100,3,49,32-->100,3,49,32
step5:进行全连接
reshape+proj -->100,49,96
计算self_attention和transformer里attention机制一样。在NLP领域,输入为BLC,计算的attn是L*L表示每个pos的token对另一个pos的attention值。在这里CV领域,之前将特征图划分为不同窗口,每个窗口大小windowsize*windowsize,所以L对应windowsize*windowsize的长度,也就是一个窗口内每个点对其他点的attention值,是对每个窗口计算self_attention。
以上过程是通过window_partition后处理,这里需要进行window_reverse,把100,49,96还原到1,70,70,96
reverse后的FM和SwinTransformerBlock最初的输入进行一次shortcut。SwinTransformerBlock模块流程结束~了么?没有。之前我们避开了cyclic shift。
在执行block中,对shift_size是
shift_size=0 if (i % 2 == 0) else window_size // 2,
所以第二个迭代 block,我们是需要进行cyclic shift的。
执行逻辑还是以上的(1)-(4),主要不同在于步骤(2),下面主要讲解,shift_size不为0时,步骤(2)的流程。
看第二阶段SW-MSA blcok,也就是加入cyclic shift。
(a)同样进行window_partition,得到b,100,49,96的特征图。然后
cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
这行代码的含义就是,将x向左移动shift_size,向上移动shift_size。也就是下图中的cyclic shift。执行这个操作的目的是,通过window_partition后进行W-MSA,窗口和窗口之间是没有重叠的,使用SW-MSA就可以让窗口之间有关联,但是这里存在的一个问题是下图中ABC区域和邻近窗口其实是不相邻的,是通过roll操作后赋值在这个区域。
(b)windowAttention
计算attention和上诉步骤一致,只是在步骤a中我们提到了,ABC区域在计算attention时需要mask掉,这里的mask就是我们BasicLayer的第一步获取的attn_mask(100,49,49)~
if mask is not None:
nW = mask.shape0
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
mask主要逻辑,attn假设目前是200,3,49,49,我们计算的attn_mask是(100,49,49),因为是针对窗口位置mask和bs和head_num无关,所以将attn和mask分别reshape到(2, 100, 3, 49, 49)和(1,100,1,49,49)就好了。
最后记得window_rever后,记得把shift_x给sereverse回去。
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
以上就将最复杂的SwinTransformerBlock模块介绍完了~
downsamp(最后一个stage不需要)使用的是PatchMerging.对FM进行间隔采样达到降采样的目的,再concat低分辨率FM后,通过全连接对C通道裁剪。很像pixelShuffle的反向操作。
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = x.view(B, H, W, C)
padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x:, 0::2, 0::2, : # B H/2 W/2 C
x1 = x:, 1::2, 0::2, : # B H/2 W/2 C
x2 = x:, 0::2, 1::2, : # B H/2 W/2 C
x3 = x:, 1::2, 1::2, : # B H/2 W/2 C
x = torch.cat(x0, x1, x2, x3, -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
以上就是一个basicLayer的逻辑,通过四个stage得到不同尺度的特征图(Swin-T)
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
有了这个四个特征图就可以和resnet等结构一样,接入到下游任务了~
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有