
💡💡💡本文独家改进:改进1)重叠空间降维注意力(OSRA),2)混合网络模块(D-Mixer),聚合全局信息和局部细节,分别引入到YOLOv8,做到二次创新;
推荐指数:5颗星

问题点:
本文依旧从经典的 ViTs 说起,即基于 MHSA 构建远距离建模实现全局感受野的覆盖,但缺乏像 CNNs 般的归纳偏差能力。因此在泛化能力上相对较弱,需要大量的训练样本和数据增强策略来弥补。
本文:为了解决上述问题,这篇论文针对性地引入了一种新的混合网络模块,称为Dual Dynamic Token Mixer (D-Mixer),它以一种依赖于输入的方式聚合全局信息和局部细节。具体来说,输入特征被分成两部分,分别经过一个全局自注意力模块和一个依赖于输入的深度卷积模块进行处理,然后将两个输出连接在一起。这种简单的设计可以使网络同时看到全局和局部信息,从而增强了归纳偏差。论文中的实验证明,这种方法在感受野方面表现出色,即网络可以看到更广泛的上下文信息。

提出了一个轻量级的双动态token混频器(D-Mixer),它以一种依赖输入的方式聚合全局信息和局部细节。D-Mixer的工作原理是在均匀分割的特征段上分别应用高效的全局注意模块和输入依赖的深度卷积,从而赋予网络强大的归纳偏置和扩大的有效接受野。用D-Mixer作为基本构建块来设计TransXNet,这是一种新颖的混合CNN-Transformer视觉骨干网络。

如图1所示,提出的TransXNet采用了四个阶段的分层架构。每个阶段由一个patch嵌入层和几个顺序堆叠的块组成。使用7×7卷积层(步幅=4)实现第一个patch嵌入层,然后使用批归一化(BN),而其余阶段的patch嵌入层使用3×3卷积层(步幅=2)和BN。每个块由一个动态位置编码(DPE)层、一个双动态token混频器(D-Mixer)和一个多尺度前馈网络(MS-FFN)组成。

Overlapping Spatial Reduction Attention (OSRA)
空间降维注意(SRA)在前人的研究中得到了广泛的应用,利用稀疏标记区域关系高效提取全局信息。然而,为了减少标记计数而进行的非重叠空间缩减打破了patch边界附近的空间结构,降低了token的质量。为了解决这一问题,在SRA中引入了重叠空间缩减(OSR),通过使用更大的重叠斑块来更好地表示斑块边界附近的空间结构。在实践中,将OSR实例化为深度可分离卷积,其中步幅跟随PVT,内核大小等于步幅加3。

参考:
理论部分详见:CNN 与 ViT 的完美结合 | TransXNet: 结合局部和全局注意力提供强大的归纳偏差和高效感受野 - 知乎 (zhihu.com)
核心代码:
class OSRAAttention(nn.Module): ### OSRA
def __init__(self, dim,
num_heads=1,
qk_scale=None,
attn_drop=0,
sr_ratio=1, ):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.sr_ratio = sr_ratio
self.q = nn.Conv2d(dim, dim, kernel_size=1)
self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1)
self.attn_drop = nn.Dropout(attn_drop)
if sr_ratio > 1:
self.sr = nn.Sequential(
ConvModule(dim, dim,
kernel_size=sr_ratio + 3,
stride=sr_ratio,
padding=(sr_ratio + 3) // 2,
groups=dim,
bias=False,
norm_cfg=dict(type='BN2d'),
act_cfg=dict(type='GELU')),
ConvModule(dim, dim,
kernel_size=1,
groups=dim,
bias=False,
norm_cfg=dict(type='BN2d'),
act_cfg=None, ), )
else:
self.sr = nn.Identity()
self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
def forward(self, x, relative_pos_enc=None):
B, C, H, W = x.shape
q = self.q(x).reshape(B, self.num_heads, C // self.num_heads, -1).transpose(-1, -2)
kv = self.sr(x)
kv = self.local_conv(kv) + kv
k, v = torch.chunk(self.kv(kv), chunks=2, dim=1)
k = k.reshape(B, self.num_heads, C // self.num_heads, -1)
v = v.reshape(B, self.num_heads, C // self.num_heads, -1).transpose(-1, -2)
attn = (q @ k) * self.scale
if relative_pos_enc is not None:
if attn.shape[2:] != relative_pos_enc.shape[2:]:
relative_pos_enc = F.interpolate(relative_pos_enc, size=attn.shape[2:],
mode='bicubic', align_corners=False)
attn = attn + relative_pos_enc
attn = torch.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(-1, -2)
return x.reshape(B, C, H, W)
详见:
https://blog.csdn.net/m0_63774211/category_12289773.html
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。