在之前的分享中,我们介绍了 torch jit 是如何通过 trace 转换模型,使用 subgraph rewriter 优化计算图,以及如何使用 aliasDB 来避免别名造成的优化错误。通过这些步骤,由 Python 描述的模型变成了更适合部署的计算图。这次分享我们将目标转向运行时,看看 PyTorch 如何使用生成的计算图进行推理。
JIT
在正式开始之前,我们先复习一些编译原理的基本知识。编译器的工作是“翻译”,将人类可以看懂的“程序语言”翻译成计算机可以看懂的“机器语言”。
对于 C/C++ 之类的编程语言,我们会将所有的代码提前“翻译”成机器语言,这样在实际运行时就不会有额外的开销,可以全速进行推理。但是这样做也有一些缺点:
另一种做法是使用解释器,不对代码进行提前编译,而是在运行时“边解释边执行”,代码更灵活,也可以尝试作一些针对当前环境的优化。这种语言被称为“解释性语言”,比如 Lisp、Pascal 等。当然代价通常就是更差的性能,毕竟“编译”需要占用运行时的时间,而且由于不能得到全部上下文,所以无法进行依赖上下文的优化。
那么如果“我都要”呢?既希望代码可以充分利用当前环境,又不希望牺牲性能的话,可以尝试第三种方案,也就是 JIT。JIT 是 just-in-time 的缩写,它不会将编译的过程一口气完成,而是先对代码进行一些处理,存储成某种序列化表示(比如计算图);然后在实际的运行时环境中,通过 profiling 的方式,进行针对环境的优化并执行代码。
torch jit 的名字就来源于此,PyTorch 使用 trace 或 script 之类的方法将模型转换成计算图,然后在运行时 "just in time" 的优化和执行推理过程。
一种常见的 JIT 实现方案是使用虚拟机来对代码(计算图)进行模拟执行。虚拟机会维护当前运行时状态、函数调用栈,每次函数调用时,就会创建一个帧(frame)来记录调用参数、程序计数器状态等等信息。如果有同学没有相关背景知识,觉得这样讲太抽象的话,可以想象这样一个场景:
当我在手机浏览器中发现一个知乎的链接后,点击会打开知乎APP,发现是一篇关于 OpenMMLab 的有趣的分享,于是我将内容通过分享按钮分享给微信好友。
这个过程牵扯到 3 个 APP,分别是“浏览器”、“知乎”和“微信”。假设每次打开应用都是一次函数调用,那么他们的调用栈就是:
上面每个 APP 的图标就表示一个帧 (Frame),这个帧包含自己当前的执行状态(比如浏览器打开了哪些标签,知乎正在看哪篇帖子)。当有新的函数调用发生时,就会向栈中填充一个新的帧,程序永远会执行栈顶的帧,保证打开 APP 的顺序正确,旧的帧则静静的躺在栈中,等待再次被唤醒。
当我完成分享并通过回退按钮返回浏览器时,调用栈的变化是:
每次返回都会弹出一个帧,弹出后的栈顶的帧就是之前执行的APP,帧中有APP 执行状态,可以恢复成之前执行的状态。
随着函数调用,栈会被不断填充,而返回时栈中对应元素会被弹出,这样就能够保证函数的执行顺序正确。torch jit 中也采用了相同的机制对推理时的状态进行模拟。
从 Python 到 C++
现在我们可以正式开始学习 torch jit 的运行时过程了。首先是要将 Python 的函数调用转换成 C++ 实现的推理实现。
torch jit 生成的计算图为 ScriptFunction 类型,当收到推理请求时,ScriptFunction 会通过 pybind11 将推理请求传递给 torch/csrc/jit/python/pybind_utils.h中的 createStackForSchema 函数。这个函数会把 Python 传入的 Tensor 参数转换成 C++ 使用的 IValue 对象,并且推入数据栈中。然后 runAndInsertCall 函数会将这个数据栈作为输入传递给后续的处理过程
inline py::object runAndInsertCall(
Function& callee,
const tuple_slice& args,
const py::kwargs& kwargs,
c10::optional<IValue> self,
// Lambda that tells this function how to insert `callee` into the graph if
// we're tracing.
const std::function<Value*(Graph&, const MatchedSchema& match)>&
callInserter) {
...
// 创建输入参数栈
auto stack =
createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
// 输入被调用者
callee.run(stack);
...
}
从这里开始,这个输入栈会经过一系列函数调用:
上面的过程中,GraphFunction 和 GraphExecutor 仅仅负责数据传递,比较重要的是生成 ExecutionPlan 和使用 InterpreterState 对模型进行推理。我们将分别介绍他们的实现细节。
ExecutionPlan
PyTorch 使用一个虚拟机来执行推理过程,这个虚拟机接收指令序列,并按顺序执行这个指令序列。ExecutionPlan 的功能就是维护计算图以及由计算图生成的指令序列:CodeImpl 对象。
CodeImpl 对象通过一个 Visitor Pattern 的访问器递归访问计算图中的所有节点,并生成对应的指令。很多 AST 解析工具都会采用类似的设计(比如 Python 的 ast.NodeVisitor),如果大家对此没有概念,可以想像一下下面的场景:
假设我在上海网购了北京的特产,那么对于物流公司的领导,他们可以设计这样的一条路线:
// 我没有物流经验,不清楚这样的规划是否合理,这只是 visitor pattern 的例子!
visit北京to上海(){
visit北京to天津()
visit天津to山东()
visit山东to江苏()
visit江苏to上海()
}
领导通过调用其他的 visit 组合来完成北京到上海的 visit,不用去实现其中的细节。物流公司各个地方负责人会规划其中的一段路线,比如下面这样。
visit天津to山东(){
visit天津to沧州()
visit沧州to济南()
}
类似的, Code 中对计算图中节点的代码生成也是这样,它实现了许多 emitXXX 函数来对计算图中的各种元素进行解析,解析的过程会用到其他的 emitXXX 函数。
比如当计算图中存在 If-Else 类型的节点时,会调用 emitIf 函数:
void emitCodeForBlock(Block* block) {
// emit block 的外部参数节点
emitNodeAtBlockLevel(block->param_node());
// emit block 中的节点
for (auto node : block->nodes()) {
emitNodeAtBlockLevel(node);
}
// emit 返回节点
emitNodeAtBlockLevel(block->return_node());
}
emitNodeAtBlockLevel 使用 emitNodeAtBlockLevel 处理 block 中的各种 node。同样 emitNodeAtBlockLevel 也会嵌套使用其他的 emit 函数处理 node,这里就不继续展开了。
这种设计模式可以保证解析计算图中各种节点的逻辑清晰且简单,如果未来添加新的节点类型的话,只要添加对应的 emit 函数即可。通过对计算图的根元素(通常是一个block)进行一次 emit,就可以遍历整个计算图并生成一系列的 Instruction 指令对象,这些对象会被存储在 CodeImpl.instructions_ 中,供 InterpreterState 实现推理。
InterpreterState
Java 语言通过在虚拟机上执行 bytecode 来运行代码。InterpreterState 采用类似的策略,还记得我们之前复习的虚拟机的调用栈吗?虚拟机的主循环从当前栈顶的帧中提取指令,并根据指令类型不同采取不同的行动。
一个帧(Frame)包含很多当前函数的信息,其中最重要的是:
当前正在执行的指令就是由 function 和 pc 组合访问得到。
frame.function->instructions_[frame.pc += (X)]
InterpreterState 会启动一个循环,每一轮循环首先提取栈顶的帧(frame.back),然后根据上面的公式提取指令,再根据指令的不同执行不同的动作。
//torch/csrc/jit/runtime/interpreter.cpp
#define INST_FETCH(X) (frame.function->instructions_[frame.pc += (X)])
// 主循环,在 runImpl 中实现
while (true) {
// 提取当前帧
Frame& frame = frames.back();
// 从当前帧中,根据帧的 pc 以及 function 提取指令
Instruction inst = INST_FETCH(0);
// 不同的指令使用不同的动作
switch (inst.op) {
case ....
}
// 提取并跳转到下一条指令(pc+=1)
INST_NEXT;
指令的种类很多,这里介绍一些比较常见的指令的处理方式。
首先介绍 OP 指令。OP 指令是 PyTorch 中绝大多数运算的指令类型,element wise 运算、卷积、矩阵乘都是这种指令类型。这种指令的处理方式很简单,根据指令的偏移量(inst.X)查询 op 表(operator_table_),传入数据栈(stack)执行即可。
case INST(OP): {
INST_GUARD;
frame.function->operator_table_[inst.X](stack);
}
INST_NEXT;
case INST(OPN): {
// OPN 为输入不定长的 OP
INST_GUARD;
stack.push_back(inst.N);
frame.function->operator_table_[inst.X](stack);
}
INST_NEXT;
inst.X 和 operator_table_ 会在 ExecutionPlan 创建指令序列的时候被填充进指令。数据栈中存储着 OP 需要的参数,计算完成后输出也会被写回给数据栈,方便后续的 OP 使用。
通常虚拟机会按顺序执行当前帧中的所有指令,但是也存在指令可以修改执行顺序,比如 JMP 可以跳转到指定位置非顺序执行下一条指令。
// 跳转到 inst.X 指定的位置
case INST(JMP): {
INST_GUARD;
inst = INST_FETCH(inst.X);
}
INST_DISPATCH;
当发生函数调用时,一个新的帧会被推入调用栈中,InterpreterState 的主循环会从这个新的帧中提取指令并执行;函数返回时,这个帧会被推出,重新执行之前的帧。就像从浏览器打开知乎 APP 后,按返回键可以回到浏览器一样,栈的后进先出特性可以保证函数执行的正确顺序。
case INST(CALL): {
// 函数调用
INST_GUARD;
// 查询需要调用的函数
Function* fn = frame.function->function_table_[inst.X];
// 在调用栈中创建对应的帧,分配一些帧需要的资源
callFunction(*fn, stack);
continue;
}
case INST(RET): {
// 函数返回
if (frames.size() > 1) {
// 将当前帧从调用栈中弹出
leaveFrame();
continue;
}
// 如果这是最后一个帧,意味着推理结束,可以结束主循环
leaveFrame();
return false;
}
除了上面介绍的指令以外,InterpreterState 还存在大量的其他指令,比如带条件跳转 JF、循环指令 LOOP、加载与存储 LOAD 、STORE 等等,篇幅关系这里就不展开了。感兴趣的同学可以阅读 torch/csrc/jit/runtime/interpreter.cpp 了解它们的实现细节。
总结
通过将计算图转换成 Code 对象,torch jit 摆脱了笨重的 Python 运行时开销,并且给之后的模拟执行提供了指令序列;InterpreterState 解释执行指令序列,给 torch jit 提供了运行时支持。这样的设计在高级语言和 DSL 中很常见,结合代码生成技术可以让程序以非常高效的方式运行,未来我们会给大家带来更多相关技术的分享,敬请期待。
Reference:
https://github.com/pytorch/pytorch
https://lernapparat.de/jit-runtime-overview