Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能!

一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能!

作者头像
OpenMMLab 官方账号
发布于 2023-08-25 03:55:28
发布于 2023-08-25 03:55:28
2.4K01
代码可运行
举报
文章被收录于专栏:OpenMMLabOpenMMLab
运行总次数:1
代码可运行

在之前的分享中,我们介绍了 torch jit 是如何通过 trace 转换模型使用 subgraph rewriter 优化计算图,以及如何使用 aliasDB 来避免别名造成的优化错误。通过这些步骤,由 Python 描述的模型变成了更适合部署的计算图。这次分享我们将目标转向运行时,看看 PyTorch 如何使用生成的计算图进行推理。

JIT

在正式开始之前,我们先复习一些编译原理的基本知识。编译器的工作是“翻译”,将人类可以看懂的“程序语言”翻译成计算机可以看懂的“机器语言”。

对于 C/C++ 之类的编程语言,我们会将所有的代码提前“翻译”成机器语言,这样在实际运行时就不会有额外的开销,可以全速进行推理。但是这样做也有一些缺点:

  • 编译花费的时间很长,对于需要频繁修改代码的场景,欠缺灵活性
  • 编译时无法感知运行环境,也就无法采取一些针对硬件环境的优化

另一种做法是使用解释器,不对代码进行提前编译,而是在运行时“边解释边执行”,代码更灵活,也可以尝试作一些针对当前环境的优化。这种语言被称为“解释性语言”,比如 Lisp、Pascal 等。当然代价通常就是更差的性能,毕竟“编译”需要占用运行时的时间,而且由于不能得到全部上下文,所以无法进行依赖上下文的优化。

那么如果“我都要”呢?既希望代码可以充分利用当前环境,又不希望牺牲性能的话,可以尝试第三种方案,也就是 JIT。JIT 是 just-in-time 的缩写,它不会将编译的过程一口气完成,而是先对代码进行一些处理,存储成某种序列化表示(比如计算图);然后在实际的运行时环境中,通过 profiling 的方式,进行针对环境的优化并执行代码。

torch jit 的名字就来源于此,PyTorch 使用 trace 或 script 之类的方法将模型转换成计算图,然后在运行时 "just in time" 的优化和执行推理过程。

一种常见的 JIT 实现方案是使用虚拟机来对代码(计算图)进行模拟执行。虚拟机会维护当前运行时状态、函数调用栈,每次函数调用时,就会创建一个帧(frame)来记录调用参数、程序计数器状态等等信息。如果有同学没有相关背景知识,觉得这样讲太抽象的话,可以想象这样一个场景:

当我在手机浏览器中发现一个知乎的链接后,点击会打开知乎APP,发现是一篇关于 OpenMMLab 的有趣的分享,于是我将内容通过分享按钮分享给微信好友。

这个过程牵扯到 3 个 APP,分别是“浏览器”、“知乎”和“微信”。假设每次打开应用都是一次函数调用,那么他们的调用栈就是:

上面每个 APP 的图标就表示一个帧 (Frame),这个帧包含自己当前的执行状态(比如浏览器打开了哪些标签,知乎正在看哪篇帖子)。当有新的函数调用发生时,就会向栈中填充一个新的帧,程序永远会执行栈顶的帧,保证打开 APP 的顺序正确,旧的帧则静静的躺在栈中,等待再次被唤醒。

当我完成分享并通过回退按钮返回浏览器时,调用栈的变化是:

每次返回都会弹出一个帧,弹出后的栈顶的帧就是之前执行的APP,帧中有APP 执行状态,可以恢复成之前执行的状态。

随着函数调用,栈会被不断填充,而返回时栈中对应元素会被弹出,这样就能够保证函数的执行顺序正确。torch jit 中也采用了相同的机制对推理时的状态进行模拟。

从 Python 到 C++

现在我们可以正式开始学习 torch jit 的运行时过程了。首先是要将 Python 的函数调用转换成 C++ 实现的推理实现。

torch jit 生成的计算图为 ScriptFunction 类型,当收到推理请求时,ScriptFunction 会通过 pybind11 将推理请求传递给 torch/csrc/jit/python/pybind_utils.h中的 createStackForSchema 函数。这个函数会把 Python 传入的 Tensor 参数转换成 C++ 使用的 IValue 对象,并且推入数据栈中。然后 runAndInsertCall 函数会将这个数据栈作为输入传递给后续的处理过程

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
inline py::object runAndInsertCall(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    c10::optional<IValue> self,
    // Lambda that tells this function how to insert `callee` into the graph if
    // we're tracing.
    const std::function<Value*(Graph&, const MatchedSchema& match)>&
        callInserter) {
        ...
        // 创建输入参数栈
        auto stack =
            createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
        // 输入被调用者
        callee.run(stack);
        ...
}

从这里开始,这个输入栈会经过一系列函数调用:

  1. runAndInsertCall 会调用 Function 的 run() 。Function 对象是一个 GraphFunction 的实例,描述所有 torch jit 产生的计算图函数。
  2. Function的 run() 函数会调用 get_executor().run(stack) 。get_executor() 返回一个 GraphExecutor 的实例,负责计算图的优化以及 ExecutionPlan 的管理以及计算图推理。
  3. GraphExecutor 的 run() 函数会创建一个 ExecutionPlan 对象,然后调用InterpreterState(plan.code).run(stack) 。ExecutionPlan 会根据计算图生成指令序列,InterpreterState 会执行这个指令序列。
  4. InterpreterState 的 run() 函数会调用自身的 runImpl() 对计算图进行解释执行。
  5. InterpreterState 完成执行后,输出会被塞进数据栈中,一路返回给runAndInsertCall ,再通过 pybind11 成为 python 输出。

上面的过程中,GraphFunction 和 GraphExecutor 仅仅负责数据传递,比较重要的是生成 ExecutionPlan 和使用 InterpreterState 对模型进行推理。我们将分别介绍他们的实现细节。

ExecutionPlan

PyTorch 使用一个虚拟机来执行推理过程,这个虚拟机接收指令序列,并按顺序执行这个指令序列。ExecutionPlan 的功能就是维护计算图以及由计算图生成的指令序列:CodeImpl 对象。

CodeImpl 对象通过一个 Visitor Pattern 的访问器递归访问计算图中的所有节点,并生成对应的指令。很多 AST 解析工具都会采用类似的设计(比如 Python 的 ast.NodeVisitor),如果大家对此没有概念,可以想像一下下面的场景:

假设我在上海网购了北京的特产,那么对于物流公司的领导,他们可以设计这样的一条路线:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
// 我没有物流经验,不清楚这样的规划是否合理,这只是 visitor pattern 的例子!
visit北京to上海(){
     visit北京to天津()
     visit天津to山东()
     visit山东to江苏()
     visit江苏to上海()
}

领导通过调用其他的 visit 组合来完成北京到上海的 visit,不用去实现其中的细节。物流公司各个地方负责人会规划其中的一段路线,比如下面这样。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
visit天津to山东(){
    visit天津to沧州()
    visit沧州to济南()
}

类似的, Code 中对计算图中节点的代码生成也是这样,它实现了许多 emitXXX 函数来对计算图中的各种元素进行解析,解析的过程会用到其他的 emitXXX 函数。

比如当计算图中存在 If-Else 类型的节点时,会调用 emitIf 函数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
  void emitCodeForBlock(Block* block) {
    // emit block 的外部参数节点
    emitNodeAtBlockLevel(block->param_node());
    // emit block 中的节点
    for (auto node : block->nodes()) {
      emitNodeAtBlockLevel(node);
    }
    // emit 返回节点
    emitNodeAtBlockLevel(block->return_node());
  }

emitNodeAtBlockLevel 使用 emitNodeAtBlockLevel 处理 block 中的各种 node。同样 emitNodeAtBlockLevel 也会嵌套使用其他的 emit 函数处理 node,这里就不继续展开了。

这种设计模式可以保证解析计算图中各种节点的逻辑清晰且简单,如果未来添加新的节点类型的话,只要添加对应的 emit 函数即可。通过对计算图的根元素(通常是一个block)进行一次 emit,就可以遍历整个计算图并生成一系列的 Instruction 指令对象,这些对象会被存储在 CodeImpl.instructions_ 中,供 InterpreterState 实现推理。

InterpreterState

Java 语言通过在虚拟机上执行 bytecode 来运行代码。InterpreterState 采用类似的策略,还记得我们之前复习的虚拟机的调用栈吗?虚拟机的主循环从当前栈顶的帧中提取指令,并根据指令类型不同采取不同的行动。

一个帧(Frame)包含很多当前函数的信息,其中最重要的是:

  • function:一个 CodeImpl 对象,帧所对应的指令序列(函数),由上面的 ExecutionPlan 生成。
  • pc:程序计数器,代表当前正在执行 function 中的第几条指令

当前正在执行的指令就是由 function 和 pc 组合访问得到。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
frame.function->instructions_[frame.pc += (X)]

InterpreterState 会启动一个循环,每一轮循环首先提取栈顶的帧(frame.back),然后根据上面的公式提取指令,再根据指令的不同执行不同的动作。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
//torch/csrc/jit/runtime/interpreter.cpp

#define INST_FETCH(X) (frame.function->instructions_[frame.pc += (X)])

  // 主循环,在 runImpl 中实现
  while (true) {
    // 提取当前帧
    Frame& frame = frames.back();
    // 从当前帧中,根据帧的 pc 以及 function 提取指令
    Instruction inst = INST_FETCH(0);
    // 不同的指令使用不同的动作
    switch (inst.op) {
        case ....
    }
    // 提取并跳转到下一条指令(pc+=1)
    INST_NEXT;

指令的种类很多,这里介绍一些比较常见的指令的处理方式。

首先介绍 OP 指令。OP 指令是 PyTorch 中绝大多数运算的指令类型,element wise 运算、卷积、矩阵乘都是这种指令类型。这种指令的处理方式很简单,根据指令的偏移量(inst.X)查询 op 表(operator_table_),传入数据栈(stack)执行即可。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
  case INST(OP): {
    INST_GUARD;
    frame.function->operator_table_[inst.X](stack);
  }
    INST_NEXT;
  case INST(OPN): {
    // OPN 为输入不定长的 OP
    INST_GUARD;
    stack.push_back(inst.N);
    frame.function->operator_table_[inst.X](stack);
  }
    INST_NEXT;

inst.X 和 operator_table_ 会在 ExecutionPlan 创建指令序列的时候被填充进指令。数据栈中存储着 OP 需要的参数,计算完成后输出也会被写回给数据栈,方便后续的 OP 使用。

通常虚拟机会按顺序执行当前帧中的所有指令,但是也存在指令可以修改执行顺序,比如 JMP 可以跳转到指定位置非顺序执行下一条指令。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
  // 跳转到 inst.X 指定的位置
  case INST(JMP): {
    INST_GUARD;
    inst = INST_FETCH(inst.X);
  }
    INST_DISPATCH;

当发生函数调用时,一个新的帧会被推入调用栈中,InterpreterState 的主循环会从这个新的帧中提取指令并执行;函数返回时,这个帧会被推出,重新执行之前的帧。就像从浏览器打开知乎 APP 后,按返回键可以回到浏览器一样,栈的后进先出特性可以保证函数执行的正确顺序。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
case INST(CALL): {
    // 函数调用
    INST_GUARD;
    // 查询需要调用的函数
    Function* fn = frame.function->function_table_[inst.X];
    // 在调用栈中创建对应的帧,分配一些帧需要的资源
    callFunction(*fn, stack);
    continue;
}

  case INST(RET): {
    // 函数返回
    if (frames.size() > 1) {
      // 将当前帧从调用栈中弹出
      leaveFrame();
      continue;
    }
    
    // 如果这是最后一个帧,意味着推理结束,可以结束主循环
    leaveFrame();
    return false;
  }

除了上面介绍的指令以外,InterpreterState 还存在大量的其他指令,比如带条件跳转 JF、循环指令 LOOP、加载与存储 LOAD 、STORE 等等,篇幅关系这里就不展开了。感兴趣的同学可以阅读 torch/csrc/jit/runtime/interpreter.cpp 了解它们的实现细节。

总结

通过将计算图转换成 Code 对象,torch jit 摆脱了笨重的 Python 运行时开销,并且给之后的模拟执行提供了指令序列;InterpreterState 解释执行指令序列,给 torch jit 提供了运行时支持。这样的设计在高级语言和 DSL 中很常见,结合代码生成技术可以让程序以非常高效的方式运行,未来我们会给大家带来更多相关技术的分享,敬请期待。

Reference

https://github.com/pytorch/pytorch

https://lernapparat.de/jit-runtime-overview

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-08-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
rbpf虚拟机-即时编译器(JIT)
本文记录的是基于 x86-64 架构的 eBPF(Extended Berkeley Packet Filter)即时编译器(JIT)。
盹猫
2025/07/22
700
一文读懂eBPF|即时编译(JIT)实现原理
在《eBPF实现原理》一文中,我们介绍了 eBPF 的实现原理,这篇文章我们主要介绍 eBPF 运行加速器 JIT(Just In Time)的实现原理。
用户7686797
2022/05/17
3.5K0
一文读懂eBPF|即时编译(JIT)实现原理
PyTorch 源码解读之即时编译篇
来源丨https://zhuanlan.zhihu.com/p/361101354
BBuf
2021/07/01
1.3K0
PyTorch 源码解读之即时编译篇
PyTorch 2.0 之 Dynamo: 窥探加速背后的真相
PyTorch 2.0 算是正式官宣了,预计在明年 3 月和大家见面。官方的 blog 宣发了非常多的内容,但是阅读下来不难发现,几乎所有的性能提升、体验优化都源自于 PyTorch 新设计的即时编译工具:Dynamo。
OpenMMLab 官方账号
2022/12/30
2.8K0
PyTorch 2.0 之 Dynamo: 窥探加速背后的真相
Rc-lang开发周记4 函数其一
这里我一开始没想好怎么做的,所以会做的很诡异,最大的原因是静态类型语言和动态类型语言是不同的。由于我只对动态语言有一些了解,这里暂时只提动态语言的一些点
AkemiHomura
2023/04/07
3170
太强了!鹅厂程序员“自研”脚本语言 eben
计算机科学家 David Wheeler 有一句名言,“All problems in computer science can be solved by another level of indirection, except for the problem of too many layers of indirection.”。大意是指,计算机科学中所有问题都可以通过多一个间接层来解决,除了间接层太多这个问题本身。
腾讯云开发者
2023/10/19
1.3K0
太强了!鹅厂程序员“自研”脚本语言 eben
TorchScript 系列解读 (二):Torch jit tracer 实现解析
小伙伴们好呀,TorchScript 解读系列教程更新啦~在上篇文章中,我们带领大家初步了解了 TorchScript。
OpenMMLab 官方账号
2022/04/09
1.8K0
TorchScript 系列解读 (二):Torch jit tracer 实现解析
堆、栈、方法区到底是什么?一文带你搞懂 JVM 运行时数据区内存模型!
在 JVM 的世界中,运行时数据区域是整个虚拟机的基础,它决定了程序的内存管理、线程的执行流以及垃圾回收的核心逻辑。
码哥字节
2024/11/26
8060
堆、栈、方法区到底是什么?一文带你搞懂 JVM 运行时数据区内存模型!
idapython使用笔记
You have chosen to enable IDAPython 2. The IDAPython 3 plugins have been renamed to idapython.3.disabled and idapython64.3.disabled in the plugins subdirectory. If you want to switch to IDAPython 3, proceed as follows:
用户1879329
2023/02/27
1.5K0
idapython使用笔记
eBPF原理介绍与编程实践
注:本文包括了ebpf的原理介绍、流程分析、相关资料链接、工具编写实战等,可以选择感兴趣的部分直接阅读;鉴于作者语文水平有限,很多地方描述可能不清楚,有错误或疑问欢迎指出交流
johnazhang
2022/07/18
2.8K1
用 go 实现 lua 虚拟机
下面依次介绍上面的一些步骤,本文旨在一篇文章写清楚大概流程,具体的细节将会忽略,实际的实现也会尽可能的简化,本文主要参考 自己动手实现 lua,和 gopher-lua
王磊-字节跳动
2020/12/27
2.2K0
Rc-lang开发周记2 VM相关
本周主要先对tac的函数进行了简单的测试,以确保能够正确运行我的vm demo,修正了function的一些问题,之后就是处理对vm指令的生成,处理了一下符号相关的信息,还做了一点函数的相关的以及生成C++的解析代码(都没做完,还是下周吧
AkemiHomura
2023/04/07
5260
用沐神的方法阅读PyTorch FX论文
【GiantPandaCV导语】torch.fx对于PyTorch来说确实是一个比较好的工作,因为它消除了一些动态图和静态图的Gap。比如在图改写方面,torch.fx让PyTorch想做一些其它静态图框架的算子融合优化非常容易。并且torch.fx让后训练量化和感知训练量化以及AMP等的实现难度大大降低,这得益于我们可以直接在Python层操作这个IR,所以我认为这是一个不错的工作。尤其是对使用PyTorch开发的算法工程师来说,现在可以基于这个特性大开脑洞了。我之前围绕FX也做了一个QAT的工作,感兴趣可以阅读:基于OneFlow实现量化感知训练。torch.fx的卖点就是,它使用纯Python语言实现了一个可以捕获PyTorch程序的计算图并转化为一个IR的库,并且非常方便的在这个IR上做Pass,同时提供将变换后的IR Codegen合法的Python代码功能。我觉得算是达到了在Eager下写Pass就像做链表插入删除题目一样顺滑。
BBuf
2021/12/27
9150
用沐神的方法阅读PyTorch FX论文
【精通 JVM 原理】浅析 JavaAgent & Instrumentation 机制
1、JVM的字节码指令,方法调用机制 2、Java类加载器 3、JavaAgent 4、Java Instrumentation
一个会写诗的程序员
2021/04/04
1.5K0
【AI系统】AI编译器前瞻
本文首先会基于 The Deep Learning Compiler: A Comprehensive Survey 中的调研做一个热门 AI 编译器的横向对比,并简要介绍几个当前常用的 AI 编译器。随后会分析当前 AI 编译器面临的诸多挑战,并展望 AI 编译器的未来。
用户11307734
2024/11/28
4750
一文带你学明白java虚拟机:C1编译器,HIR代码优化
为了减少编译时间,C1在抽象解释生成HIR期间,每生成一条SSA指令,都会调用append_with_bci努力尝试若干局部优化。除此之外,HIR构造完成之后,C1还会执行若干轻量级全局优化。本节将详细描述这些优化的执行过程。这些优化都位于build_hir()。
愿天堂没有BUG
2022/10/28
9870
一文带你学明白java虚拟机:C1编译器,HIR代码优化
《Python 源码剖析》一些理解以及勘误笔记(1)
以下是本人阅读此书时理解的一些笔记,包含一些影响文义的笔误修正,当然不一定正确,贴出来一起讨论。 注:此书剖析的源码是2.5版本,在python.org 可以找到源码。纸质书阅读,pdf 贴图。 文章
s1mba
2017/12/26
1.1K0
《Python 源码剖析》一些理解以及勘误笔记(1)
浅谈机器学习模型推理性能优化
在机器学习领域,清晰明了的数据预处理和表现优异的模型往往是数据科学家关注的重点,而实际生产中如何让模型落地、工程化也同样值得关注,工程化机器学习模型避不开的一个难点就是模型的推理(Inference / Serving)性能优化。
ThoughtWorks
2021/01/08
1.3K0
Android虚拟机的JIT编译器
最近参加了华为方舟的Workshop,从编译到Runtime都有了一些体会,并且对于虚拟机的运行也有了一些了解。
None_Ling
2019/06/14
1.6K0
Android虚拟机的JIT编译器
一文教你搞懂 Go 中栈操作
多任务操作系统中的每个进程都在自己的内存沙盒中运行。在32位模式下,它总是4GB内存地址空间,内存分配是分配虚拟内存给进程,当进程真正访问某一虚拟内存地址时,操作系统通过触发缺页中断,在物理内存上分配一段相应的空间再与之建立映射关系,这样进程访问的虚拟内存地址,会被自动转换变成有效物理内存地址,便可以进行数据的存储与访问了。
luozhiyun
2021/04/05
1.1K0
相关推荐
rbpf虚拟机-即时编译器(JIT)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验