Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >[1141]基于MODnet无绿幕抠图

[1141]基于MODnet无绿幕抠图

作者头像
周小董
发布于 2022-05-20 00:14:06
发布于 2022-05-20 00:14:06
1.8K10
代码可运行
举报
文章被收录于专栏:python前行者python前行者
运行总次数:0
代码可运行

文章目录

前言

MODNet由香港城市大学和商汤科技于2020年11月首次提出,用于实时抠图任务

MODNet特性:

  • 轻量级(light-weight )
  • 实时性高(real-time)
  • 预测时不需要额外的背景输入(trimap-free)
  • 准确度高(hight performance)
  • 单模型(single model instead of a complex pipeline)
  • 泛化能力强(better generalization ability)

论文地址 : https://arxiv.org/pdf/2011.11961.pdf git地址: https://github.com/ZHKKKe/MODNet https://github.com/xuebinqin/U-2-Net

复现代码

基于onnx推理代码

官方给出了基于torch和onnx推理代码,这里用的是关于onnx模型的推理代码.

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import os
import cv2
import argparse
import numpy as np
from PIL import Image

import onnx
import onnxruntime


if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', default= 'test.jpeg',type=str, help='path of the input image (a file)')
    parser.add_argument('--output-path',default= 'result.png', type=str, help='paht for saving the predicted alpha matte (a file)')
    parser.add_argument('--model-path', default='hrnet.onnx', type=str, help='path of the ONNX model')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.image_path):
        print('Cannot find the input image: {0}'.format(args.image_path))
        exit()
    if not os.path.exists(args.model_path):
        print('Cannot find the ONXX model: {0}'.format(args.model_path))
        exit()

    ref_size = 512

    # Get x_scale_factor & y_scale_factor to resize image
    def get_scale_factor(im_h, im_w, ref_size):

        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32

        x_scale_factor = im_rw / im_w
        y_scale_factor = im_rh / im_h

        return x_scale_factor, y_scale_factor

    ##############################################
    #  Main Inference part
    ##############################################

    # read image
    im = cv2.imread(args.image_path)
    img = im.copy()
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # unify image channels to 3
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # normalize values to scale it between -1 to 1
    im = (im - 127.5) / 127.5   

    im_h, im_w, im_c = im.shape
    x, y = get_scale_factor(im_h, im_w, ref_size) 

    # resize image
    im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA)

    # prepare input shape
    im = np.transpose(im)
    im = np.swapaxes(im, 1, 2)
    im = np.expand_dims(im, axis = 0).astype('float32')

    # Initialize session and get prediction
    session = onnxruntime.InferenceSession(args.model_path, None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: im})

    # refine matte
    matte = (np.squeeze(result[0]) * 255).astype('uint8')
    matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)

    cv2.imwrite(args.output_path, matte)


    # 保存彩色图片
    # b,g,r = cv2.split(img)
    # rbga_img = cv2.merge((b, g, r, matte))
    rbga_img = cv2.merge((img, matte))
    cv2.imwrite('rbga_result.png',rbga_img)

抠图效果

测试图片

测试结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B3JzUG16-1652968229233)(https://upload-images.jianshu.io/upload_images/12504508-d87c7f4b721020d9.png)]

可以发现抠图已经达到了丝发级别,对于清晰的图片抠图还是很准确的.

基于demo.image_matting.colab.inference推理代码

预训练模型在这里 :modnet_photographic_portrait_matting.ckpt

模型百度网盘:在这里 密码:gchf

把模型下载到目录:MODNet/pretrained,下面运行需要加载此模型。 现在,工作目录是MODNet,在其目录下建立输入图片和输出图片的目录: input-img, output-img 把需要抠图的图片放到input-img MODNet目录下,运行

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
python -m demo.image_matting.colab.inference-1   \
                   --input-path input-img  \
                   --output-path output-img  \
                   --ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt

现在可以从output-img中找到已经抠好的图片xxx_fg.png,遮罩图片xxx_matte.png 看看MODNet模型的抠图效果

python程序如下。原作者的程序中只给出遮罩matte,没有抠图结果。鄙人不才,添加了抠出的前景图片,供参考。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import os
import sys
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from src.models.modnet import MODNet

if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help='path of input images')
    parser.add_argument('--output-path', type=str, help='path of output images')
    parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.input_path):
        print('Cannot find input path: {0}'.format(args.input_path))
        exit()
    if not os.path.exists(args.output_path):
        print('Cannot find output path: {0}'.format(args.output_path))
        exit()
    if not os.path.exists(args.ckpt_path):
        print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
        exit()
    # define hyper-parameters
    ref_size = 512
    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet).cuda()
    modnet.load_state_dict(torch.load(args.ckpt_path))
    modnet.eval()
# 注:程序中的数字仅表示某张输入图片尺寸,如1080x1440,此处只为记住其转换过程。
    # inference images
    im_names = os.listdir(args.input_path)
    for im_name in im_names:
        print('Process image: {0}'.format(im_name))
        # read image
        im = Image.open(os.path.join(args.input_path, im_name))
        # unify image channels to 3
        im = np.asarray(im)
        if len(im.shape) == 2:
            im = im[:, :, None]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)
        elif im.shape[2] == 4:
            im = im[:, :, 0:3]
        im_org = im                                # 保存numpy原始数组 (1080,1440,3)
        # convert image to PyTorch tensor
        im = Image.fromarray(im)
        im = im_transform(im)
        # add mini-batch dim
        im = im[None, :, :, :]
        # resize image for input
        im_b, im_c, im_h, im_w = im.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w
        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32
        im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

        # inference
        _, _, matte = modnet(im.cuda(), True)    # 从模型获得的 matte ([1,1,512, 672])

        # resize and save matte,foreground picture
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')  #内插,扩展到([1,1,1080,1440])  范围[0,1]
        matte = matte[0][0].data.cpu().numpy()    # torch 张量转换成numpy (1080, 1440)
        matte_name = im_name.split('.')[0] + '_matte.png'
        Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))
        matte_org = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2)   # 扩展到 (1080, 1440, 3) 以便和im_org计算
        
        foreground = im_org * matte_org + np.full(im_org.shape, 255) * (1 - matte_org)         # 计算前景,获得抠像
        fg_name = im_name.split('.')[0] + '_fg.png'
        Image.fromarray(((foreground).astype('uint8')), mode='RGB').save(os.path.join(args.output_path, fg_name))

参考: https://blog.csdn.net/weixin_44238733/article/details/114457650 https://blog.csdn.net/small_wu/article/details/124041904 https://blog.csdn.net/jacke121/article/details/110774623 https://aijishu.com/a/1060000000162206 https://blog.csdn.net/missyoudaisy/article/details/111085552

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-05-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
1 条评论
热度
最新
如果心动了赶紧自己也去制作一个吧。
如果心动了赶紧自己也去制作一个吧。
回复回复点赞举报
推荐阅读
编辑精选文章
换一批
QT(三).电子相册(3)
需要注意的是,这里面定义了一个 Ui_Pic 类 , 这个类我们之前在 pic.h 中见过
franket
2021/09/14
1.1K0
python程序界面
# -*- coding: utf-8 -*- # Form implementation generated from reading ui file 'Main.ui' # # Created: Thu Jan 29 16:25:31 2015 # by: PyQt4 UI code generator 4.11.3 # # WARNING! All changes made in this file will be lost! from PyQt4 import QtCore, QtGu
py3study
2020/01/06
1.3K0
PyQt中如何结合Qt设计师进行开发
Qt设计师是Qt的所见即所得的界面设计工具,通过拖拉方式设计界面,但它并不能产生任何代码。
bear_fish
2018/09/20
9000
PyQt中如何结合Qt设计师进行开发
【Qt】初始项目代码解释
本文将聚焦与项目创建后的这5个文件的解析,这5个文件分别为: test250225, mywidget.h, main.cpp, mywidget.cpp, mywidget.ui。 现在开始逐个解析
Yui_
2025/02/26
1810
【Qt】初始项目代码解释
1. qt 入门-整体框架[通俗易懂]
总结: 本文先通过一个例子介绍了Qt项目的大致组成,即其一个简单的项目框架,如何定义窗口类,绑定信号和槽,然后初始化窗口界面,显示窗口界面,以及将程序的控制权交给Qt库。
全栈程序员站长
2022/09/20
1.8K0
1. qt 入门-整体框架[通俗易懂]
Qt中中文处理的简单方法
    QT是一套很不错的界面开发库,而且考虑到了跨平台的要求,使用也相对比较容易上手。我也刚刚才学习用QT开发,发现它对中文的处理做的不是很好,或者更贴切的是做的不够智能吧,如果在字符串中输入中文,显示的就会是乱码。     下面就介绍一个简单的方法,让我们的中文正确显示出来,先看一段程序,该程序主要功能就是显示一个窗口,窗口上面的按钮显示中文。
阳光岛主
2019/02/19
1.3K0
QT(二).计算器(4)
void QTextCodec::setCodecForTr ( QTextCodec * c ) [static]
franket
2021/09/14
5970
初识Qt · 实现hello world的N种细节和对象树
继上文我们了解了QT的环境,历史的基本知识,以及了解了如何创建一个项目,项目的内容都包括什么,本文我们学习的是如何在GUI界面上打印Hello world,重要的不是hello world本身,而是在hello world背后牵扯到的N个知识点。
_lazy
2025/03/04
1390
初识Qt · 实现hello world的N种细节和对象树
Qt编写自定义控件23-广告轮播控件
广告轮播这个控件做的比较早,是很早以前定制一个电信客户端时候用到的,该客户端需要在首页展示轮播预先设定好的图片,图片的路径可以自由设定,然后轮播的间隔速度可以自由控制,同时该控件还需要提供两种指示器的风格,一种是迷你型的样式,一种是数字型的样式。
feiyangqingyun
2019/08/27
1K0
Qt编写自定义控件23-广告轮播控件
python+pycharm+pyqt5安装教程「建议收藏」
现在教大家在Windows系统下如何安装Python + PyCharm + PyQt5
全栈程序员站长
2022/09/25
4.3K0
python+pycharm+pyqt5安装教程「建议收藏」
【QT】QT事件处理
QT中,事件作为一个对象,继承自QEvent类,常见的有键盘事件QKeyEvent、鼠标事件QMouseEvent和定时器事件QTimerEvent等。QT中,任何QObject子类示例都可以接收和处理事件。实际编程中通常实现部件的paintEvent()、mousePressEvent()等事件处理函数来处理特定部件的特定事件。
半生瓜的blog
2023/05/13
1.7K0
【QT】QT事件处理
QT(二).计算器(2)
从中可知,这是一个冗长的 XML 文件 内容是在描述窗体与各个控件的参数 Qt 就是通过这些参数来绘制图形的 代码示例 main.cpp #include <QtGui/QApplication> //QApplication 类管理图形用户界面应用程序的控制流和主要设置 #include <QtCore/QTextCodec> //用来进行字符集转化 #include "calc.h" int main(int argc, char** argv) { QApplication app(argc,
franket
2021/09/14
1K0
【QT】QT模型/视图
MVC(Model-View-Controller)包括了3个组件:模型(model)是应用对象,用来表示数据;视图(View)是模型的用户界面,用来显示数据;控制(Controller)定义了用户界面对用户输入的反应方式。
半生瓜的blog
2023/05/13
3.2K0
【QT】QT模型/视图
被QT5 抛弃的函数和用法
注意:当有QT += webkitwidgets的时候,就不再需要QT += widgets
用户3519280
2023/07/06
6430
初识Qt · Qt的基本认识和基本项目代码解释
虽然现在学习了Linux的系统部分,C++,以及部分数据结构,也了解了一下git的相关内容,但是呢,对于向外拓展的方面笔者感觉并不是很充实,对于Qt,对于算法,对于MySQL等都没有具体了解过,所以笔者最近也是突然有了点内驱力了,打算在这个假期更新完Qt,至少咱们更新完能结合数数据库写一个项目,项目的话呢,就是仿QQ音乐的一款播放器吧!其实我是想要仿制酷狗的,后面看吧,其实都一样。
_lazy
2025/03/04
1890
初识Qt · Qt的基本认识和基本项目代码解释
19.QT-事件发送函数sendEvent()、postEvent()
Qt发送事件分为两种 -阻塞型事件发送 需要重写接收对象的event()事件处理函数 当事件发送后,将会立即进入event()事件处理函数进行事件处理 通过sendEvent()静态函数实现阻塞发送: bool QApplication::sendEvent ( QObject * receiver, QEvent * event ) ; // receiver:接收对象, event :要发送的event类型(比如:鼠标双击) //当有事件发送,将会
诺谦
2018/05/28
3.3K0
qt tabwidget切换_标签怎么在新窗口打开
QTabWidget 用来分页显示 重要函数: 1.void setTabText(int, QString); //设置页面的名字. 2.void setTabToolTip(QString); //设置页面的提示信息. 3.void setTabEnabled(bool); //设置页面是否被激活. 4.void setTabPosition(QTabPosition::South); //设置页面名字的位置. 5.void setTabsClosable(bool); //设置页面关闭按钮。 6.int currentIndex(); //返回当前页面的下标,从0开始. 7.int count(); //返回页面的数量. 8.void clear(); //清空所有页面. 9.void removeTab(int); //删除页面. 10.void setMoveable(bool); //设置页面是否可被拖拽移动. 11.void setCurrentIndex(int); //设置当前显示的页面.
全栈程序员站长
2022/11/04
4K0
qt tabwidget切换_标签怎么在新窗口打开
Qt | 安全的udp客户端搭建(代码框架值得学习)
通过网盘分享的文件:secureudpclient 链接: https://pan.baidu.com/s/1txCWIo7-WhM-CjVkp_aDdg?pwd=13j9 提取码: 13j9 【一定要转存】
Qt历险记
2024/12/15
2330
Qt | 安全的udp客户端搭建(代码框架值得学习)
【Qt】:Dialog 对话框
模态对话框 指的是:显示后无法与父窗口进行交互,是⼀种阻塞式的对话框。使用 QDialog:: exec () 函数调用。
IsLand1314
2025/02/28
2790
【Qt】:Dialog 对话框
2.QT-窗口组件(QWidget),QT坐标系统,初探消息处理(信号与槽)
本章主要内容如下: 1) 窗口组件(QWidget) 2) QT坐标系统 3) 初探消息处理(信号与槽) ---- 窗口组件(QWidget) 介绍 Qt以组件对象的方式构建图形用户界面 Qt中没有父
诺谦
2018/04/23
2.2K0
2.QT-窗口组件(QWidget),QT坐标系统,初探消息处理(信号与槽)
相关推荐
QT(三).电子相册(3)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验