前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >带你玩转 3D 检测和分割(一):MMDetection3D 整体框架介绍

带你玩转 3D 检测和分割(一):MMDetection3D 整体框架介绍

作者头像
OpenMMLab 官方账号
发布2022-04-08 18:12:30
3K0
发布2022-04-08 18:12:30
举报
文章被收录于专栏:OpenMMLab

由于 3D 本身数据的复杂性和 MMDetection3D 支持任务(点云 3D 检测、单目 3D 检测、多模态 3D 检测和点云 3D 语义分割等)和场景(室内和室外)的多样性,整个框架结构相对复杂,新人用户的上手门槛相对较高。所以我们推出新的系列文章,让各个细分方向的用户都能轻松上手 MMDetection3D,基于框架进行自己的研究和开发。在系列文章的初期,我们会先带大家了解整个框架的设计流程,分析框架中的各种核心组件,介绍数据集的处理方法,然后再对各个细分任务及经典模型进行具体细节的代码层级介绍。同时也欢迎大家在评论区提出自己的需求,我们会收集各位的反馈补充更多的文章教程 ~

我们首先为大家介绍整个代码库的目录结构,让大家有个初步的认识:

代码语言:javascript
复制
# MMDetection3D 代码目录结构,展示主要部分
mmdetection3d
   |
   |- configs                    # 配置文件
   |- data                       # 原始数据及预处理后数据文件
   |- mmdet3d 
   |     |- ops                  # cuda 算子(即将迁移到 mmcv 中)
   |     |- core                 # 核心组件
   |     |- datasets             # 数据集相关代码
   |     |- models               # 模型相关代码
   |     |- utils                # 辅助工具
   |     |- ...
   |- tools
   |     |- analysis_tools       # 分析工具,包括可视化、计算flops等
   |     |- data_converter       # 各个数据集预处理转换脚本
   |     |- create_data.py       # 数据预处理入口
   |     |- train.py             # 训练脚本
   |     |- test.py              # 测试脚本
   |     |- ...                      
   |- ...

作为开篇文章,笔者将从任务介绍、算法模型支持、数据预处理、模块抽象以及训练和测试流程给大家带来介绍。

1. 任务介绍

3D 目标检测按照输入数据模态划分可以分为:点云 3D 检测、纯视觉 3D 检测以及多模态 3D 检测(点云+图片)

点云 3D 检测

单目 3D 检测

从目前来说,基于纯视觉(例如单目)的 3D 检测方法在性能上和基于点云的 3D 检测方法仍然有比较大的差距,但是其胜在便捷性和低成本;同时,多模态 3D 检测也是一个在学术界和工业界都很火热的方向,对不同模态的数据各取所长,相互配合从而达到更好的检测效果。

上述描述的主要还是室外场景的 3D 检测,最广泛的实际应用场景就是最近火热的自动驾驶领域;而室内场景的 3D 检测同样也有广阔的应用前景,例如室内机器人(扫地机器人)、室内导航等等,而目前室内 3D 检测仍然以点云数据为主。

除此以外,MMDetection3D 还拓展到了点云 3D 语义分割领域,目前已经支持了室内点云语义分割,同时会在将来支持室外点云语义分割

2. 算法模型支持

所有模型相关代码位于 mmdet3d/models 下,MMDetection3D 支持的各个方向的模型大体可以归类如下:

总体来说,由于 MMDetection3D 依赖于 MMDetection 和 MMSegmentation, 所以很多的模型及组件都是直接复用或者继承而来。目前在 MMDetection3D 内,整体模型的构建方式会根据任务类型被划分为三种方式,具体如下图所示 (PS: 我们正在进行整体代码的重构,统一所有任务的模型构建方式

点云 3D 检测(包含多模态 3D 检测)

对于点云 3D 检测(多模态 3D 检测),我们继承自 MMDetection 中的 BaseDetector 构建了适用于 3D 检测的 Base3DDetector ,再根据检测中的单阶段和二阶段分别构造,需要注意的是不同于 SingleStage3DDetector,为了尽可能的复用已有的代码组件,二阶段检测器 TwoStage3DDetector 同时继承自 Base3DDetector 和 TwoStageDetector。而由于多模态任务的特殊性,我们专门为多模态检测方法设计了 MVXTwoStage3DDetector,图中只列出了部分支持的模型算法。

单目 3D 检测:

对于单目 3D 检测,考虑到和 2D 检测输入数据的一致性,同时方便做 2D 检测的同学能快速的上手单目 3D 检测,我们继承自 MMDetection 中的 SingleStageDetector, 构建了 SingleStageMono

3DDetector,目前所支持的单目 3D 检测算法都是基于该类构建的。

点云 3D 语义分割 :

对于点云 3D 语义分割,我们继承自 MMSegmentation 中的 BaseSegmentor 构建了适用于点云分割的 Base3DSegmentor,而目前所支持的点云分割算法都是遵循 EncoderDecoder3D 模式。

3. 数据预处理

该部分对应于 tools/create_data.py ,各个数据集预处理脚本位于 tools/data_converter 目录下。由于 3D 数据集的多样性,MMDetection3D 会对数据集做预处理。我们在官方文档里面介绍了不同的数据集的格式转换方法和命令,在这里我们从整体视角来看一下数据预处理的文件生成过程:

在 MMDetection3D 中,不同的任务和不同的场景(室内或室外)的数据预处理都会存在一定的区别,如上图所示,会产生不同的预处理后的文件,便于后续训练。

  • 对所有的任务和场景,我们统一使用数据处理脚本转换后的 pkl 文件,该文件包含数据集的各种信息,包括数据集路径、calib 信息和标注信息等等,从而做到各个数据集内部格式尽可能的统一。
  • 对于点云 (多模态)3D 检测,室内和室外数据集生成的文件是不一样的: 对于某些室外数据集,我们会借助 pkl 文件的信息进一步提取 reduced_point_cloud 和 gt_database:前者是仅包含前方视野的点云文件,通常存在于 kitti 数据集处理过程中,因为 kitti 数据集仅包含前方视野的标注;后者则是将包含在训练数据集的每个 3D 边界框中的点云数据分别提取出来得到的各个物体的点云文件,常用来在数据增强时使用(copy-paste)。 而对于室内数据集,由于室内点云较为密集的特点,通常会进行点云的下采样处理,保存在points内。
  • 对于单目 3D 检测,由于在前面提到,整个模型构建的流程是遵循 2D 检测的,同样的在数据处理的过程中,在生成基本的 pkl 文件后,还需要将其转换为 coco 标注格式的 json 文件,该过程中会对 pkl 的标注信息做相应处理,实际在该任务中,pkl 文件用来提供 data 信息,json 文件提供标注信息。
  • 对于点云 3D 语义分割,目前 MMDetection3D 仅支持室内点云分割,相对于检测任务,如图所示需要生成额外的文件:instance_mask 包含每个点云的实例标签,semantic_mask 包含每个点云的语义标签,seg_info 包含额外的辅助训练的信息。

我们在这里对数据预处理生成的文件有个初步的认识,在后续的文章中我们会按照场景为数据集进行分类,对处理过程做具体介绍,方便大家的理解和使用自己的数据集训练模型。做数据转换的过程主要是为了尽可能统一各个数据的格式,从而简化训练的过程,整个数据预处理的部分是相对独立的。

4. 模块抽象

和 MMDetection 一脉相承,整个 MMDetection3D 的模块内部抽象流程也主要包括 Pipeline、DataParallel、Model、Runner 和 Hooks。如果对DataParallel、Runner 和 Hooks这三个抽象模块不熟悉的同学,我们非常推荐大家先参考轻松掌握 MMDetection 整体构建流程(二)这篇文章中的【第二层模块抽象】部分了解这些抽象概念。在这里我们重点介绍不同的 Pipeline 和 Model。

4.1 Pipeline

具体在 Pipeline 方面由于数据模态的不同,所以在数据处理过程中包含不同的信息。

上图展示了三个比较典型的 3D 检测 pipeline, 流程自上而下分别是点云 3D 检测、多模态 3D 检测和单目 3D 检测,从上述的流程可以,pipeline 其实是由一系列的按照插入顺序运行的数据处理模块组成,接受数据字典,输出经过处理后的数据字典,MMDetection3D 对于点云 3D 检测提供了很多常用的 pipeline 模块,比如 GlobalRotScaleTrans(点云的旋转缩放)、PointsRangeFilter / ObjectRangeFilter(限定了点云和物体的范围)、PointShuffle(打乱点云数据);而对于单目 3D 检测基本就是直接调用 MMDetection 的数据处理模块,比如 Resize (图片缩放)、Normalize (正则化)、Pad (图片填充);多模态检测则兼用两者。我们可以看到其实这些任务共享了部分的 pipeline 模块,比如 LoadAnnotations3D (标签载入)、RandomFlip3D(会对点云和图片同时进行翻转)、DefaultFormatBundle3D(数据格式化)、Collect3D (选取需要用于训练的数据和标签),这些代码都在 mmdet3d/datasets/pipeline 目录下。

4.2 Model

在该部分我们按照任务类型分类,对于整个模型内部做抽象介绍。和 2D 检测类似, 3D 检测器通常也包含了几个核心组件:Backbone 用于提取特征、Neck 进行特征融合和增强、Head 用于输出需要的结果。

1)点云 3D 检测模型

目前点云目标检测按照对点云数据的处理方式,可以分为体素处理方法 (Voxel-based) 原始点云处理方法 (Point-based),这两种方法其实在构建模型的时候会有一定的区别,整体的模型构建按照下图流程所示:

- 基于体素的模型通常需要 Encoder 来对点云体素化,如 HardVFE 和 PointPillarScatter等,采用的稀疏卷积或者 Pillars 的方法从点云中生成 2D 特征图,然后基本可以套用 2D 检测流程进行 3D 检测。

- 基于原始点云模型通常直接采用 3D Backbone (Pointnet / Pointnet++ 等) 提取点的特征,再针对提取到的点云特征采用 RoI 或者 Group 等方式回归 3D bounding box。有关的具体内容我们会在后续的文章中针对典型的方法进行分析介绍。

2)单目 3D 检测模型

由于单目 3D 检测的输入是图片,输出是 3D bounding box, 所以整体的检测流程和模型组成来说基本和 2D 检测保持一致,具体检测方法同样也会在后续文章中进行解析。

3)多模态 3D 检测模型

多模态的检测模型从组成来看可以看成 2D 检测模型和点云检测模型的拼接。

4) 点云 3D 语义分割模型

MMDetection3D 内部支持的 3D 分割模型都是符合 EncoderDecoder 结构的,需要 backbone 来 encode feature, decode_head 用来预测每个点云的类别的进行分割,目前主要只支持室内场景的 3D 语义分割,具体的分割模型方法同样会在后续文章中进行解析。

5. 训练和测试流程

轻松掌握 MMDetection 整体构建流程(二)中的 【第三层代码抽象】部分中,按照训练和测试整体代码抽象流程-> Runner 训练和验证代码抽象 -> Model 训练和测试代码抽象的方式给大家进行了介绍,在这里我们简要概括前两部分:

首先我们训练和验证调用的是 tools/train.py 脚本,先进行 Dataset、Model 等相关类初始化,然后我们构建了一个 runner,最终模型的训练和验证过程是发生在 runner 内部的,而训练和验证的时候实际上是 runner 调用了 model 内部的 train_step 和 val_step 函数。

对如何从 tools/train.py 脚本开始到调用 train_step 和 val_step 函数的细节过程可以参考前述文章的【第三层代码抽象】部分,而理解了这两个函数调用流程就理解了 MMDetection3D 训练和测试流程。笔者在这一部分主要以 PointPillars 为例分析 MMDetection3D 中 Model 的训练和测试代码:

5.1 train 和 val 流程

1) 调用 runner 中的 train_step 或者 val_step 【该部分内容来自前序文章】

在 runner 中调用 train_step 或者 val_step,代码如下:

代码语言:javascript
复制
#=================== mmcv/runner/epoch_based_runner.py ==================
if train_mode:
    outputs = self.model.train_step(data_batch,...)
else:
    outputs = self.model.val_step(data_batch,...)

实际上,首先会调用 DataParallel 中的 train_step 或者 val_step ,其具体调用流程为:

代码语言:javascript
复制
# 非分布式训练
#=================== mmcv/parallel/data_parallel.py/MMDataParallel ==================
def train_step(self, *inputs, **kwargs):
    if not self.device_ids:
        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        # 此时才是调用 model 本身的 train_step
        return self.module.train_step(*inputs, **kwargs)
    # 单 gpu 模式
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 train_step
    return self.module.train_step(*inputs[0], **kwargs[0])

# val_step 也是的一样逻辑
def val_step(self, *inputs, **kwargs):
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    # 此时才是调用 model 本身的 val_step
    return self.module.val_step(*inputs[0], **kwargs[0])

可以发现,在调用 model 本身的 train_step 前,需要额外调用 scatter 函数,前面说过该函数的作用是处理 DataContainer 格式数据,使其能够组成 batch,否则程序会报错。

如果是分布式训练,则调用的实际上是 mmcv/parallel/distributed.py/MMDistributedDataParallel,最终调用的依然是 model 本身的 train_step 或者 val_step。

2) 调用 model 中的 train_step 或者 val_step

训练流程:

代码语言:javascript
复制
#=================== mmdet/models/detectors/base.py/BaseDetector =============
def train_step(self, data, optimizer):
    # 调用本类自身的 forward 方法
    losses = self(**data)
    # 解析 loss
    loss, log_vars = self._parse_losses(losses)
    # 返回字典对象
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
    return outputs
    
#=================== mmdet/models/detectors/base.py/Base3DDetector ===========
# Base3DDetector 主要是重写了 forward,改变了模型输入数据的类型,可同时传入点云数据和图片数据,从而满足多模态检测的需求
@auto_fp16(apply_to=('img', 'points'))
def forward(self, return_loss=True, **kwargs):
    if return_loss:
        # 训练模式
        return self.forward_train(**kwargs)
    else:
        # 测试模式
        return self.forward_test(**kwargs)

forward_train 和 forward_test 需要在不同的算法子类中实现,输出是 Loss 或者 预测结果。

3) 调用子类中的 forward_train 方法

PointPillars 采用的是 VoxelNet 检测器,核心逻辑还是比较通用的。

代码语言:javascript
复制
#============= mmdet/models/detectors/voxelnet.py/VoxelNet ============
def forward_train(self,
                  points,
                  img_metas,
                  gt_bboxes_3d,
                  gt_labels_3d,
                  gt_bboxes_ignore=None):
    # 先进行点云的特征提取  
    x = self.extract_feat(points, img_metas)
    # 主要是调用 bbox_head 内部的 forward_train 方法,得到 head 输出
    outs = self.bbox_head(x)
    loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
    # 将 head 部分的输出和数据的 label 送入计算 loss
    losses = self.bbox_head.loss(
        *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
    return losses

4) 调用 model 中的 _parse_losses 方法

代码语言:javascript
复制
#=================== mmdet/models/detectors/base.py/BaseDetector ==================
def _parse_losses(self, losses):
    
    # 返回来的 losses 是一个dict, 我们需要对 loss 进行求和
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                f'{loss_name} is not a tensor or list of tensors')

    loss = sum(_value for _key, _value in log_vars.items()
               if 'loss' in _key)

    log_vars['loss'] = loss
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if dist.is_available() and dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars

5.2 test 流程

test 流程如上图所示, 我们可以看见在 test 的时候流程相比 train / val 更为简单,没有调用 runner 对象。

1) 调用 model 中的 forward_test

代码语言:javascript
复制
#=================== mmdet/models/detectors/base.py/Base3DDetector ===========
def forward_test(self, points, img_metas, img=None, **kwargs):
    num_augs = len(points)
    if num_augs != len(img_metas):
        raise ValueError(
            'num of augmentations ({}) != num of image meta ({})'.format(
                len(points), len(img_metas)))
    # 根据 points list 长度判断是 simple_test 还是 aug_test
    if num_augs == 1:
        img = [img] if img is None else img
        return self.simple_test(points[0], img_metas[0], img[0], **kwargs)
    else:
        return self.aug_test(points, img_metas, img, **kwargs)

2) 调用子类 的 simple_test 或 aug_test

代码语言:javascript
复制
#============= mmdet/models/detectors/voxelnet.py/VoxelNet ============
def simple_test(self, points, img_metas, imgs=None, rescale=False):
    # 无数据增强测试
    # 提取特征
    x = self.extract_feat(points, img_metas)
    # 调用 head 
    outs = self.bbox_head(x)
    # 根据 head 输出结果生成 bboxes
    bbox_list = self.bbox_head.get_bboxes(
        *outs, img_metas, rescale=rescale)
    # 对检测结果进行格式调整
    bbox_results = [
        bbox3d2result(bboxes, scores, labels)
        for bboxes, scores, labels in bbox_list
    ]
    return bbox_results

def aug_test(self, points, img_metas, imgs=None, rescale=False):
    # 数据增强测试
    feats = self.extract_feats(points, img_metas)

    # 目前只支持单个 sample 的 aug_test
    aug_bboxes = []
    for x, img_meta in zip(feats, img_metas):
        outs = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_meta, rescale=rescale)
        bbox_list = [
            dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
            for bboxes, scores, labels in bbox_list
        ]
        aug_bboxes.append(bbox_list[0])

    # 将增强后的 bboxes 进行 merge 合并操作
    merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                        self.bbox_head.test_cfg)

    return [merged_bboxes]

以上我们主要分析了整体的框架流程,在下一篇文章中我们会为大家带来 MMDetection3D 中的各种核心组件的分析和介绍,包括 3D 检测中令人困惑的坐标系问题,敬请期待~

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-03-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
图像处理
图像处理基于腾讯云深度学习等人工智能技术,提供综合性的图像优化处理服务,包括图像质量评估、图像清晰度增强、图像智能裁剪等。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档