Spatial Cross-Attention is built on Deformable Attention, so we first review Deformable Attention.
Paper: Deformable DETR: Deformable Transformers for End-to-End Object Detection (https://arxiv.org/pdf/2010.04159.pdf)
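Deformable DETR targets two shortcomings of attention in DETR:
1) convergence is slow;
2) detection performance on small objects is poor.
Deformable Attention addresses the slow convergence;
Multi-scale Deformable Attention addresses the poor small-object detection.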
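Spatial Cross-Attention (SCA) is defined as:
$$\mathrm{SCA}(Q_p, F_t) = \frac{1}{|\mathcal{V}_{hit}|} \sum_{i \in \mathcal{V}_{hit}} \sum_{j=1}^{N_{ref}} \mathrm{DeformAttn}\big(Q_p, \mathcal{P}(p, i, j), F_t^i\big)$$
where $i$ is the camera view index and $j$ is the reference point index;
$\mathcal{V}_{hit}$ is the set of camera views that the reference points of the BEV query $Q_p$ project onto;
$N_{ref}$ is the total number of reference points for each BEV query;
$F_t^i$ is the features of the $i$-th camera view;
$\mathcal{P}(p, i, j)$ is the projection function, which projects the $j$-th reference point onto the $i$-th view image.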
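To make the projection function more concrete, here is a minimal sketch in plain Python (no PyTorch, so it runs standalone). It assumes each camera is described by a single 3x4 projection matrix; the helper names `project_point` and `hit_views`, the toy camera matrices, and the 100x100 image size are all hypothetical and chosen only for illustration.

```python
def project_point(point_3d, proj, eps=1e-5):
    """Project a 3D point with a 3x4 matrix; return (u, v, depth)."""
    x, y, z = point_3d
    homo = (x, y, z, 1.0)
    # Each row of the matrix yields one homogeneous image coordinate.
    u_d, v_d, depth = (sum(row[k] * homo[k] for k in range(4)) for row in proj)
    denom = max(depth, eps)  # guard against division by zero or negative depth
    return u_d / denom, v_d / denom, depth


def hit_views(point_3d, projs, img_w, img_h, eps=1e-5):
    """Return indices of the cameras whose image contains the projected point."""
    hits = []
    for i, proj in enumerate(projs):
        u, v, depth = project_point(point_3d, proj, eps)
        # A view is "hit" only if the point is in front of the camera
        # and lands inside the image bounds.
        if depth > eps and 0.0 < u < img_w and 0.0 < v < img_h:
            hits.append(i)
    return hits


# Two toy cameras (hypothetical values): one looks along +z, one along -z.
f, cx, cy = 100.0, 50.0, 50.0
cam_front = [[f, 0, cx, 0], [0, f, cy, 0], [0, 0, 1, 0]]
cam_back = [[-f, 0, -cx, 0], [0, f, -cy, 0], [0, 0, -1, 0]]

front_hits = hit_views((0.1, 0.0, 2.0), [cam_front, cam_back], 100, 100)
back_hits = hit_views((0.1, 0.0, -2.0), [cam_front, cam_back], 100, 100)
```

Each point lands in only one of the two toy cameras, which is the property that lets the attention skip the views a reference point never reaches.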
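Each 3D coordinate in BEV space projects onto only a few of the view images, not onto all of them, and this property can be exploited to cut the computation substantially. The code below removes the redundant positions from bev_query and reference_point and re-batches what remains: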
indexes = []
for i, mask_per_img in enumerate(bev_mask):
    # Find the indices of the BEV queries that are valid for this image
    index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
    indexes.append(index_query_per_img)
# Longest per-camera query list; shorter lists are zero-padded to this length
max_len = max([len(each) for each in indexes])

# each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory
queries_rebatch = query.new_zeros([bs, self.num_cams, max_len, self.embed_dims])
reference_points_rebatch = reference_points_cam.new_zeros([bs, self.num_cams, max_len, D, 2])

for j in range(bs):
    for i, reference_points_per_img in enumerate(reference_points_cam):
        index_query_per_img = indexes[i]
        # Re-batch the bev_query features into queries_rebatch
        queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
        # Re-batch the reference_point sampling locations into reference_points_rebatch
        reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
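The re-batching idea can be illustrated on a toy example. The sketch below uses plain Python lists instead of tensors and a single sample (no batch dimension); the function name `rebatch_queries` and the mask layout `masks[i][q]` are assumptions made for this illustration, not part of the original code.

```python
def rebatch_queries(queries, masks):
    """Gather, per camera, only the BEV queries that project into it.

    queries: list of feature vectors, one per BEV query.
    masks:   masks[i][q] is True if query q is visible in camera i.
    Returns the per-camera index lists and zero-padded per-camera batches.
    """
    indexes = [[q for q, hit in enumerate(m) if hit] for m in masks]
    max_len = max(len(idx) for idx in indexes)
    dim = len(queries[0])
    rebatch = []
    for idx in indexes:
        rows = [list(queries[q]) for q in idx]
        # Pad with zero vectors so every camera has max_len rows.
        rows += [[0.0] * dim for _ in range(max_len - len(idx))]
        rebatch.append(rows)
    return indexes, rebatch


queries = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
masks = [
    [True, True, False, False],   # camera 0 sees queries 0 and 1
    [False, True, True, True],    # camera 1 sees queries 1, 2 and 3
    [False, False, False, True],  # camera 2 sees query 3 only
]
indexes, rebatch = rebatch_queries(queries, masks)
```

Here every camera processes at most 3 queries instead of all 4; with real BEV grids and six cameras the saving is far larger, because most queries are invisible to most cameras.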
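The Multi-scale Deformable Attention procedure follows, as illustrated in the figure above.
Here, num_heads is the number of attention heads; num_levels is the number of multi-scale feature maps; num_points is the number of sampling points, which depends on the Deformable Attention configuration.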
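# Predict a 2D sampling offset for every head, level and point from the query
self.sampling_offsets = nn.Linear(
    embed_dims,
    num_heads * num_levels * num_points * 2)
sampling_offsets = self.sampling_offsets(query)

# Linear projection of the value features
self.value_proj = nn.Linear(embed_dims, embed_dims)
value = self.value_proj(value)

# Predict one attention weight per head, level and point from the query
self.attention_weights = nn.Linear(embed_dims,
                                   num_heads * num_levels * num_points)
attention_weights = self.attention_weights(query)
# Simplified here: the full implementation first reshapes the weights so that
# softmax normalizes over num_levels * num_points within each head
attention_weights = attention_weights.softmax(-1)

# Sampling locations are the reference points shifted by the predicted offsets
sampling_locations = reference_points + sampling_offsets
output = multi_scale_deformable_attn_pytorch(
    value, spatial_shapes, sampling_locations, attention_weights)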
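The sampling-and-weighting step above can be sketched for a single query and a single head in plain Python. This is a simplified illustration, not the library kernel: features are scalars rather than vectors, locations are normalized to [0, 1], and the helper names `bilinear_sample`, `softmax`, and `deform_attn_single` are hypothetical.

```python
import math


def bilinear_sample(fmap, x, y):
    """Bilinearly sample a 2D grid at pixel coordinates (x, y), zero outside."""
    h, w = len(fmap), len(fmap[0])
    x0, y0 = math.floor(x), math.floor(y)
    out = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= xx < w and 0 <= yy < h:
                out += (1 - abs(x - xx)) * (1 - abs(y - yy)) * fmap[yy][xx]
    return out


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def deform_attn_single(levels, ref_xy, offsets, logits):
    """Deformable attention for one query and one head.

    levels:        list of 2D feature maps (the multi-scale pyramid).
    ref_xy:        reference point, normalized to [0, 1].
    offsets[l][k]: normalized (dx, dy) offset for level l, point k.
    logits[l][k]:  unnormalized attention weight for the same sample.
    """
    # Softmax is taken jointly over all levels and points.
    weights = softmax([v for lvl in logits for v in lvl])
    out, n = 0.0, 0
    for l, fmap in enumerate(levels):
        h, w = len(fmap), len(fmap[0])
        for dx, dy in offsets[l]:
            # Normalized location -> pixel coordinates (pixel centers at +0.5).
            px = (ref_xy[0] + dx) * w - 0.5
            py = (ref_xy[1] + dy) * h - 0.5
            out += weights[n] * bilinear_sample(fmap, px, py)
            n += 1
    return out


levels = [[[2.0] * 4 for _ in range(4)], [[2.0] * 2 for _ in range(2)]]
offsets = [[(0.0, 0.0), (0.1, 0.0)], [(0.0, 0.0), (0.1, 0.0)]]
logits = [[0.3, -1.0], [2.0, 0.5]]
result = deform_attn_single(levels, (0.5, 0.5), offsets, logits)
```

Because the weights sum to one, sampling constant feature maps returns that constant regardless of the offsets and logits, which is a handy sanity check. The cost per query depends only on num_levels * num_points, not on the feature map size.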
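# queries: deformable attention output, shaped [bs, num_cams, max_len, embed_dims]
# slots: zero-initialized tensor with the same shape as the original query
# Scatter each camera's outputs back to the BEV queries they came from
for j in range(bs):
    for i, index_query_per_img in enumerate(indexes):
        slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]

# Number of cameras that each BEV query hits
count = bev_mask.sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
# Avoid division by zero for queries that no camera sees
count = torch.clamp(count, min=1.0)
# Average over the cameras that were hit
slots = slots / count[..., None]
slots = self.output_proj(slots)

return self.dropout(slots) + inp_residual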
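The scatter-and-average step can be checked on a toy example. The sketch below uses plain Python lists and a single sample; the function name `merge_camera_outputs` is hypothetical, and the hit count is accumulated during the scatter instead of being derived from bev_mask, which gives the same result in this simplified setting.

```python
def merge_camera_outputs(num_queries, dim, indexes, cam_outputs):
    """Scatter per-camera outputs back to BEV queries and average them.

    indexes[i]:     BEV query indices visible in camera i.
    cam_outputs[i]: outputs for camera i, aligned with indexes[i]
                    (any padded rows beyond len(indexes[i]) are ignored).
    """
    slots = [[0.0] * dim for _ in range(num_queries)]
    count = [0] * num_queries
    for idx, outs in zip(indexes, cam_outputs):
        for pos, q in enumerate(idx):
            for c in range(dim):
                slots[q][c] += outs[pos][c]
            count[q] += 1
    # Clamp the count to at least 1 so unseen queries stay at zero
    # instead of triggering a division by zero.
    return [[v / max(cnt, 1) for v in row] for row, cnt in zip(slots, count)]


indexes = [[0, 1], [1, 2]]                     # camera 0 and camera 1
cam_outputs = [[[1.0], [2.0]], [[4.0], [6.0]]]
merged = merge_camera_outputs(4, 1, indexes, cam_outputs)
```

Query 1 is seen by both cameras, so its two outputs (2.0 and 4.0) are averaged to 3.0; query 3 is seen by none and stays at zero.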
That is the overall logic of the Spatial Cross-Attention module.