前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python ONNX 模型转换、加载、简化、推断

Python ONNX 模型转换、加载、简化、推断

作者头像
为为为什么
发布2024-02-05 08:31:47
1.3K0
发布2024-02-05 08:31:47
举报
文章被收录于专栏:又见苍岚

Microsoft 和合作伙伴社区创建了 ONNX 作为表示机器学习模型的开放标准。 本文记录 Python 下 pytorch 模型转换 ONNX 的相关内容。

简介

ONNX Runtime是一个跨平台的推理和训练机器学习加速器。

在 Pytorch 框架中训练好模型后,在部署时可以转成 onnx,再进行下一步部署。

模型转换

核心代码:

  • 生成 onnx 模型: torch.onnx.export
  • 简化 onnx 模型: onnxsim.simplify
代码语言:text
复制
import torch
import onnxsim
import onnx

def export_to_onnx(model, output_path, input_shape, input_name, output_names):
    dummy_input = torch.rand(1, *input_shape)

    model.eval()

    temp_dict = dict()

    temp_onnx_path = output_path.replace('.onnx', '_temp.onnx')

    torch.onnx.export(model, 						# pytorch 模型
                    (dummy_input, 'ALL'),  			# 可以输入 tuple 
                    temp_onnx_path, 				# 输出 onnx 模型路径
                    verbose=False, 					# 聒噪
                    opset_version=11,				# onnx 版本
                    export_params=True, 			# 一个指示是否导出模型参数(权重)以及模型架构的标志。
                    do_constant_folding=True,   	# 一个指示是否在导出过程中折叠常量节点的标志
                    input_names=[input_name],		# 输入节点名称列表(可选)
                    output_names=output_names		# 输出节点名称列表(可选)
                    )

    input_data = {'image': dummy_input.cpu().numpy()}
    model_sim, flag = onnxsim.simplify(temp_onnx_path, input_data=input_data) # 简化 onnx 

    if flag:
        onnx.save(model_sim, output_path)
        print(f"simplify onnx model successfully !")
    else:
        print(f"simplify onnx model failed !!!")

  • 注意: torch.onnx.export 输入伪数据可以支持字符串,但是在 onnx 模型中仅会记录张量流转的路径,字符串、分支逻辑一般不会保存。
模型检查

onnx 加载模型后可以检测是否合法。

代码语言:text
复制
# onnx check
onnx_model = onnx.load(onnx_model_path)
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print('The model is invalid: %s' % e)
else:
    print('The model is valid!')

加载、运行 ONNX 模型

ONNXruntime 安装:

代码语言:text
复制
pip install onnxruntime       # CPU build
pip install onnxruntime-gpu   # GPU build

推理代码:

代码语言:text
复制
import onnxruntime

session = onnxruntime.InferenceSession("path to model")
session.get_modelmeta()
results = session.run(["output1", "output2"], {"input1": indata1, "input2": indata2})
results = session.run([], {"input1": indata1, "input2": indata2}) 

可以对比 onnx 模型结果与 pytorch 模型结果的差异来对转换结果进行验证。

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-2-2,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 简介
  • 模型转换
    • 模型检查
    • 加载、运行 ONNX 模型
    • 参考资料
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档