MODNet模型ONNX介绍
一键人像抠图,实时支持的模型,整个代码实现是基于Pytorch完成,通过脚本可以一键导出ONNX格式模型,官方提供了ONNXRUNTIME模型部署推理演示的Python版本代码。项目的github地址如下:
https://github.com/ZHKKKe/MODNet
然后可以直接下载官方提供的ONNX格式模型文件,模型文件打开输入与输出格式如下:
输入格式是 NCHW,其中HW支持不定长(动态)输入,但是绝对不能超过512。输出格式是:1x1xHxW,输出的大小跟输入HW一致,单个通道预测值,小于0.5可以看作是背景。
代码实现
01
图像预处理
MODNet模型输入图像数据预处理要求把图像转化0~1之间的浮点数,输入图像格式RGB顺序,转化0~1之间的浮点数是通过减去127.5然后除以127.5获得。然后把图像格式维度转化为NCHW。代码如下:
cv::Mat rgb, gblob;
cv::cvtColor(frame, rgb, cv::COLOR_BGR2RGB);
cv::resize(rgb, gblob, cv::Size(input_w, input_h));
gblob.convertTo(gblob, CV_32F);
cv::subtract(gblob, cv::Scalar(127.5, 127.5, 127.5), gblob);
cv::divide(gblob, cv::Scalar(127.5, 127.5, 127.5), gblob);
cv::Mat blob = cv::dnn::blobFromImage(gblob);
02
预测后处理
得到的推理后数据维度格式与输入相似,但是通道只有单个通道,通过阈值0.5分割为前景与背景,实现人像Mask对象提取,后处理代码如下:
cv::Mat mask = cv::Mat::zeros(cv::Size(input_w, input_h), CV_8UC1);
for (int row = 0; row < input_h; row++) {
for (int col = 0; col < input_w; col++) {
float c1 = mask_data[row*input_w + col] ;
if (c1 > 0.5) {
mask.at<uchar>(row, col) = 255;
}
}
}
cv::Mat result;
cv::imshow("mask", mask);
cv::resize(mask, mask, cv::Size(frame.cols, frame.rows));
cv::bitwise_and(frame, frame, result, mask);
03
测试运行
基于ONNXRUNTIME框架,推理测试结果运行如下:
我只能说扣的真好,然后我叠加一下背景,效果丝滑,显示如下: