前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能!

一文带你使用即时编译(JIT)提高 PyTorch 模型推理性能!

作者头像
OpenMMLab 官方账号
发布2023-08-25 11:55:28
1K0
发布2023-08-25 11:55:28
举报
文章被收录于专栏:OpenMMLabOpenMMLab

在之前的分享中,我们介绍了 torch jit 是如何通过 trace 转换模型使用 subgraph rewriter 优化计算图,以及如何使用 aliasDB 来避免别名造成的优化错误。通过这些步骤,由 Python 描述的模型变成了更适合部署的计算图。这次分享我们将目标转向运行时,看看 PyTorch 如何使用生成的计算图进行推理。

JIT

在正式开始之前,我们先复习一些编译原理的基本知识。编译器的工作是“翻译”,将人类可以看懂的“程序语言”翻译成计算机可以看懂的“机器语言”。

对于 C/C++ 之类的编程语言,我们会将所有的代码提前“翻译”成机器语言,这样在实际运行时就不会有额外的开销,可以全速进行推理。但是这样做也有一些缺点:

  • 编译花费的时间很长,对于需要频繁修改代码的场景,欠缺灵活性
  • 编译时无法感知运行环境,也就无法采取一些针对硬件环境的优化

另一种做法是使用解释器,不对代码进行提前编译,而是在运行时“边解释边执行”,代码更灵活,也可以尝试作一些针对当前环境的优化。这种语言被称为“解释性语言”,比如 Lisp、Pascal 等。当然代价通常就是更差的性能,毕竟“编译”需要占用运行时的时间,而且由于不能得到全部上下文,所以无法进行依赖上下文的优化。

那么如果“我都要”呢?既希望代码可以充分利用当前环境,又不希望牺牲性能的话,可以尝试第三种方案,也就是 JIT。JIT 是 just-in-time 的缩写,它不会将编译的过程一口气完成,而是先对代码进行一些处理,存储成某种序列化表示(比如计算图);然后在实际的运行时环境中,通过 profiling 的方式,进行针对环境的优化并执行代码。

torch jit 的名字就来源于此,PyTorch 使用 trace 或 script 之类的方法将模型转换成计算图,然后在运行时 "just in time" 的优化和执行推理过程。

一种常见的 JIT 实现方案是使用虚拟机来对代码(计算图)进行模拟执行。虚拟机会维护当前运行时状态、函数调用栈,每次函数调用时,就会创建一个帧(frame)来记录调用参数、程序计数器状态等等信息。如果有同学没有相关背景知识,觉得这样讲太抽象的话,可以想象这样一个场景:

当我在手机浏览器中发现一个知乎的链接后,点击会打开知乎APP,发现是一篇关于 OpenMMLab 的有趣的分享,于是我将内容通过分享按钮分享给微信好友。

这个过程牵扯到 3 个 APP,分别是“浏览器”、“知乎”和“微信”。假设每次打开应用都是一次函数调用,那么他们的调用栈就是:

上面每个 APP 的图标就表示一个帧 (Frame),这个帧包含自己当前的执行状态(比如浏览器打开了哪些标签,知乎正在看哪篇帖子)。当有新的函数调用发生时,就会向栈中填充一个新的帧,程序永远会执行栈顶的帧,保证打开 APP 的顺序正确,旧的帧则静静的躺在栈中,等待再次被唤醒。

当我完成分享并通过回退按钮返回浏览器时,调用栈的变化是:

每次返回都会弹出一个帧,弹出后的栈顶的帧就是之前执行的APP,帧中有APP 执行状态,可以恢复成之前执行的状态。

随着函数调用,栈会被不断填充,而返回时栈中对应元素会被弹出,这样就能够保证函数的执行顺序正确。torch jit 中也采用了相同的机制对推理时的状态进行模拟。

从 Python 到 C++

现在我们可以正式开始学习 torch jit 的运行时过程了。首先是要将 Python 的函数调用转换成 C++ 实现的推理实现。

torch jit 生成的计算图为 ScriptFunction 类型,当收到推理请求时,ScriptFunction 会通过 pybind11 将推理请求传递给 torch/csrc/jit/python/pybind_utils.h中的 createStackForSchema 函数。这个函数会把 Python 传入的 Tensor 参数转换成 C++ 使用的 IValue 对象,并且推入数据栈中。然后 runAndInsertCall 函数会将这个数据栈作为输入传递给后续的处理过程

代码语言:javascript
复制
inline py::object runAndInsertCall(
    Function& callee,
    const tuple_slice& args,
    const py::kwargs& kwargs,
    c10::optional<IValue> self,
    // Lambda that tells this function how to insert `callee` into the graph if
    // we're tracing.
    const std::function<Value*(Graph&, const MatchedSchema& match)>&
        callInserter) {
        ...
        // 创建输入参数栈
        auto stack =
            createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
        // 输入被调用者
        callee.run(stack);
        ...
}

从这里开始,这个输入栈会经过一系列函数调用:

  1. runAndInsertCall 会调用 Function 的 run() 。Function 对象是一个 GraphFunction 的实例,描述所有 torch jit 产生的计算图函数。
  2. Function的 run() 函数会调用 get_executor().run(stack) 。get_executor() 返回一个 GraphExecutor 的实例,负责计算图的优化以及 ExecutionPlan 的管理以及计算图推理。
  3. GraphExecutor 的 run() 函数会创建一个 ExecutionPlan 对象,然后调用InterpreterState(plan.code).run(stack) 。ExecutionPlan 会根据计算图生成指令序列,InterpreterState 会执行这个指令序列。
  4. InterpreterState 的 run() 函数会调用自身的 runImpl() 对计算图进行解释执行。
  5. InterpreterState 完成执行后,输出会被塞进数据栈中,一路返回给runAndInsertCall ,再通过 pybind11 成为 python 输出。

上面的过程中,GraphFunction 和 GraphExecutor 仅仅负责数据传递,比较重要的是生成 ExecutionPlan 和使用 InterpreterState 对模型进行推理。我们将分别介绍他们的实现细节。

ExecutionPlan

PyTorch 使用一个虚拟机来执行推理过程,这个虚拟机接收指令序列,并按顺序执行这个指令序列。ExecutionPlan 的功能就是维护计算图以及由计算图生成的指令序列:CodeImpl 对象。

CodeImpl 对象通过一个 Visitor Pattern 的访问器递归访问计算图中的所有节点,并生成对应的指令。很多 AST 解析工具都会采用类似的设计(比如 Python 的 ast.NodeVisitor),如果大家对此没有概念,可以想像一下下面的场景:

假设我在上海网购了北京的特产,那么对于物流公司的领导,他们可以设计这样的一条路线:

代码语言:javascript
复制
// 我没有物流经验,不清楚这样的规划是否合理,这只是 visitor pattern 的例子!
visit北京to上海(){
     visit北京to天津()
     visit天津to山东()
     visit山东to江苏()
     visit江苏to上海()
}

领导通过调用其他的 visit 组合来完成北京到上海的 visit,不用去实现其中的细节。物流公司各个地方负责人会规划其中的一段路线,比如下面这样。

代码语言:javascript
复制
visit天津to山东(){
    visit天津to沧州()
    visit沧州to济南()
}

类似的, Code 中对计算图中节点的代码生成也是这样,它实现了许多 emitXXX 函数来对计算图中的各种元素进行解析,解析的过程会用到其他的 emitXXX 函数。

比如当计算图中存在 If-Else 类型的节点时,会调用 emitIf 函数:

代码语言:javascript
复制
  void emitCodeForBlock(Block* block) {
    // emit block 的外部参数节点
    emitNodeAtBlockLevel(block->param_node());
    // emit block 中的节点
    for (auto node : block->nodes()) {
      emitNodeAtBlockLevel(node);
    }
    // emit 返回节点
    emitNodeAtBlockLevel(block->return_node());
  }

emitNodeAtBlockLevel 使用 emitNodeAtBlockLevel 处理 block 中的各种 node。同样 emitNodeAtBlockLevel 也会嵌套使用其他的 emit 函数处理 node,这里就不继续展开了。

这种设计模式可以保证解析计算图中各种节点的逻辑清晰且简单,如果未来添加新的节点类型的话,只要添加对应的 emit 函数即可。通过对计算图的根元素(通常是一个block)进行一次 emit,就可以遍历整个计算图并生成一系列的 Instruction 指令对象,这些对象会被存储在 CodeImpl.instructions_ 中,供 InterpreterState 实现推理。

InterpreterState

Java 语言通过在虚拟机上执行 bytecode 来运行代码。InterpreterState 采用类似的策略,还记得我们之前复习的虚拟机的调用栈吗?虚拟机的主循环从当前栈顶的帧中提取指令,并根据指令类型不同采取不同的行动。

一个帧(Frame)包含很多当前函数的信息,其中最重要的是:

  • function:一个 CodeImpl 对象,帧所对应的指令序列(函数),由上面的 ExecutionPlan 生成。
  • pc:程序计数器,代表当前正在执行 function 中的第几条指令

当前正在执行的指令就是由 function 和 pc 组合访问得到。

代码语言:javascript
复制
frame.function->instructions_[frame.pc += (X)]

InterpreterState 会启动一个循环,每一轮循环首先提取栈顶的帧(frame.back),然后根据上面的公式提取指令,再根据指令的不同执行不同的动作。

代码语言:javascript
复制
//torch/csrc/jit/runtime/interpreter.cpp

#define INST_FETCH(X) (frame.function->instructions_[frame.pc += (X)])

  // 主循环,在 runImpl 中实现
  while (true) {
    // 提取当前帧
    Frame& frame = frames.back();
    // 从当前帧中,根据帧的 pc 以及 function 提取指令
    Instruction inst = INST_FETCH(0);
    // 不同的指令使用不同的动作
    switch (inst.op) {
        case ....
    }
    // 提取并跳转到下一条指令(pc+=1)
    INST_NEXT;

指令的种类很多,这里介绍一些比较常见的指令的处理方式。

首先介绍 OP 指令。OP 指令是 PyTorch 中绝大多数运算的指令类型,element wise 运算、卷积、矩阵乘都是这种指令类型。这种指令的处理方式很简单,根据指令的偏移量(inst.X)查询 op 表(operator_table_),传入数据栈(stack)执行即可。

代码语言:javascript
复制
  case INST(OP): {
    INST_GUARD;
    frame.function->operator_table_[inst.X](stack);
  }
    INST_NEXT;
  case INST(OPN): {
    // OPN 为输入不定长的 OP
    INST_GUARD;
    stack.push_back(inst.N);
    frame.function->operator_table_[inst.X](stack);
  }
    INST_NEXT;

inst.X 和 operator_table_ 会在 ExecutionPlan 创建指令序列的时候被填充进指令。数据栈中存储着 OP 需要的参数,计算完成后输出也会被写回给数据栈,方便后续的 OP 使用。

通常虚拟机会按顺序执行当前帧中的所有指令,但是也存在指令可以修改执行顺序,比如 JMP 可以跳转到指定位置非顺序执行下一条指令。

代码语言:javascript
复制
  // 跳转到 inst.X 指定的位置
  case INST(JMP): {
    INST_GUARD;
    inst = INST_FETCH(inst.X);
  }
    INST_DISPATCH;

当发生函数调用时,一个新的帧会被推入调用栈中,InterpreterState 的主循环会从这个新的帧中提取指令并执行;函数返回时,这个帧会被推出,重新执行之前的帧。就像从浏览器打开知乎 APP 后,按返回键可以回到浏览器一样,栈的后进先出特性可以保证函数执行的正确顺序。

代码语言:javascript
复制
case INST(CALL): {
    // 函数调用
    INST_GUARD;
    // 查询需要调用的函数
    Function* fn = frame.function->function_table_[inst.X];
    // 在调用栈中创建对应的帧,分配一些帧需要的资源
    callFunction(*fn, stack);
    continue;
}

  case INST(RET): {
    // 函数返回
    if (frames.size() > 1) {
      // 将当前帧从调用栈中弹出
      leaveFrame();
      continue;
    }
    
    // 如果这是最后一个帧,意味着推理结束,可以结束主循环
    leaveFrame();
    return false;
  }

除了上面介绍的指令以外,InterpreterState 还存在大量的其他指令,比如带条件跳转 JF、循环指令 LOOP、加载与存储 LOAD 、STORE 等等,篇幅关系这里就不展开了。感兴趣的同学可以阅读 torch/csrc/jit/runtime/interpreter.cpp 了解它们的实现细节。

总结

通过将计算图转换成 Code 对象,torch jit 摆脱了笨重的 Python 运行时开销,并且给之后的模拟执行提供了指令序列;InterpreterState 解释执行指令序列,给 torch jit 提供了运行时支持。这样的设计在高级语言和 DSL 中很常见,结合代码生成技术可以让程序以非常高效的方式运行,未来我们会给大家带来更多相关技术的分享,敬请期待。

Reference

https://github.com/pytorch/pytorch

https://lernapparat.de/jit-runtime-overview

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-08-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档