前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TorchScript 系列解读 (二):Torch jit tracer 实现解析

TorchScript 系列解读 (二):Torch jit tracer 实现解析

作者头像
OpenMMLab 官方账号
发布2022-04-09 16:31:21
1.6K0
发布2022-04-09 16:31:21
举报
文章被收录于专栏:OpenMMLab

小伙伴们好呀,TorchScript 解读系列教程更新啦~在上篇文章中,我们带领大家初步了解了 TorchScript。

TorchScript 是 PyTorch 提供的模型序列化以及部署方案,可以弥补 PyTorch 难于部署的缺点,也可以轻松实现图优化或后端对接。TorchScript 支持通过 trace 来记录数据流的生成方式;也支持解析 AST 直接生成图的 script 方式。

今天我们将介绍 TorchScript 通过 trace 来记录数据流的生成方式,同时还将分享使用该机制实现的 ONNX 导出过程。接下来,就让我们进入今天的正题吧~

基本概念

首先来看一下同一个模型的三种不同表述,为了方便展示各种 jit 的组件,这里会使用 script 方式创建图:

代码

代码语言:javascript
复制
def forward(self, x):
    x = x * 2
    x.add_(0)
    x = x.view(-1)
    if x[0] > 1:
        return x[0]
    else:
        return x[-1]

TorchScript Graph

代码语言:javascript
复制
graph(%self : __torch__.TestModel,
      %x.1 : Tensor):
  %12 : int = prim::Constant[value=-1]() # graph_example.py:12:19
  %3 : int = prim::Constant[value=2]() # graph_example.py:10:16
  %6 : int = prim::Constant[value=0]() # graph_example.py:11:15
  %10 : int = prim::Constant[value=1]() # graph_example.py:12:20
  %x.3 : Tensor = aten::mul(%x.1, %3) # graph_example.py:10:12
  %8 : Tensor = aten::add_(%x.3, %6, %10) # graph_example.py:11:8
  %13 : int[] = prim::ListConstruct(%12)
  %x.6 : Tensor = aten::view(%x.3, %13) # graph_example.py:12:12
  %17 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:13:11
  %18 : Tensor = aten::gt(%17, %10) # graph_example.py:13:11
  %20 : bool = aten::Bool(%18) # graph_example.py:13:11
  %41 : Tensor = prim::If(%20) # graph_example.py:13:8
    block0():
      %23 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:14:19
      -> (%23)
    block1():
      %32 : Tensor = aten::select(%x.6, %6, %12) # graph_example.py:16:19
      -> (%32)
  return (%41)

netron

以上中间的部分就是 TorchScript 模型的可视化结果,其中包含如下一些元素:

Graph

表格中 Graph 列整体用来表示一个 Graph,它有如下性质:

· Graph 用来表示一个“函数”,一个 Module 中的不同函数(比如 forward 等)会被转换成不同的 Graph。

· Graph 拥有许多的 Node,这些 Node 由一个 Block 管理。所有 Node 组织成双向链表的形式,方便插入删除,其中返回值节点 “Return Node” 会作为这个双向链表的“哨兵”。双向链表通常会被拓扑排序,保证执行的正确性。

Node

表格中 Graph 列里 3~14 行,以及 16 和 19 行表示各个Node,一个 Node 对应一个操作。操作的输入为 Value,少数情况下还会有一些 static attribute。Node 中包含很多信息,包括:

· kind() 表示 Node 的操作类型,上图中的 aten::mul 和 prim::ListConstruct 等都是对应 Node 的 kind。注意它只是个字符串,因此修改这个字符串也就意味着修改了操作。

· FunctionSchema 指对这个函数的接口的描述,格式看起来就类似 ops 函数的声明,另外可以添加一些标记表示某个 Tensor 是否是另一个 Tensor 的 Alias 等等(别名分析是保证优化结果正确的依据),可以作为 peelhole-optimize 的时候的检索依据。以 Tensor.add_ 函数为例:

代码语言:javascript
复制
// add_是一个inplace运算,因此输出和self共享相同的内存空间
// FunctionSchema中标注了这种别名关系,保证了输出的正确性
// netron的可视化似乎不会进行alias analysis?因此上面右图的可视化中,add_的部分存在错误
"add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"

· 常用的函数的 schema 可以在aten/src/ATen/native/native_functions.yaml 中查看。

Block

Block 表示一个 Node 的有序列表,代表输入的 Node 的 kind=Param,代表输出的 Node 的 kind=Return。

实际上 Graph 本身隐含一个 root Block 对象,用来管理所有的 Node。部分 Node 可能还会存在 sub Block。比如表中的 Graph 就有 3 个 Block,一个是 Graph 隐含的 root Block,另两个是 prim::If Node 的 sub Block。

Block 的概念可能源于编译原理中的基本块。所谓基本块就是一系列不包含任何跳转指令的指令序列,由于基本块内的内容可以保证是顺序执行的,因此很多的优化都会以基本块作为前提。实际上 PyTorch 中对中间表示(IR)的优化有非常多是 Block 级别的。

Value

Value 是 Node 的输入输出,可以是 Tensor 也可以是容器或其他类型,可以通过 type() 判断。

Value 对象维护了一个 use_list,只要这个 Value 成为某个 Node 的输入,那么这个 Node 就要加入到它的 use_list 中。通过这个 use_list,可以很方便地解决新加入的 Node 与其他 Node 的输入输出关系。

注意:Value 是用来表述 Graph 的结构的,与 Runtime 无关!真正在推理时用到的是 IValue 对象,IValue 中有运行时的真实数据。

Pass

严格地说这不是 Graph 的一部分,pass 是一个来源于编译原理的概念,它会接收一种中间表示(IR),遍历它并且进行一些变换,生成满足某种条件的新 IR。

TorchScript 中定义了许多 pass 来优化 Graph。比如对于常规编译器很常见的 DeadCodeElimination(DCE),CommonSubgraphElimination (CSE) 等等;也有一些针对深度学习的融合优化,比如 FuseConvBN 等;还有针对特殊任务的 pass,ONNX 的导出就是其中一类 pass。

JIT Trace

Jit trace 在 python 侧的接口为 torch.jit.trace,输入的参数会经过层层传递,最终会进入torch/jit/frontend/trace.cpp 中的 trace 函数中。这个函数是 Jit trace 的核心,大致执行了下面几个步骤:

1)创建新的 TracingState 对象,该对象会维护 trace 的 Graph 以及一些必要的环境参数。

2)根据 trace 时的模型输入参数,生成 Graph 的输入节点。

3)进行模型推理,同时生成 Graph 中的各个元素。

4)生成 Graph 的输出节点。

5)进行一些简单的优化。

torch/jit/frontend/trace.cpp 链接:

https://github.com/pytorch/pytorch/blob/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/csrc/jit/frontend/tracer.cpp#L457

下面会一一介绍这些步骤的细节:

1. 创建TracingState对象

TracingState 对象包含了 Graph 的指针、函数名映射、栈帧信息等,trace 的过程就是不断更新 TracingState 的过程。

TracingState 网址:

https://github.com/pytorch/pytorch/blob/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/csrc/jit/frontend/tracer.h#L43

代码语言:javascript
复制
struct TORCH_API TracingState
    : public std::enable_shared_from_this<TracingState> {
  // 部分接口,可以帮助Graph的构建
  std::shared_ptr<Graph> graph;

  void enterFrame();
  void leaveFrame();
  
  void setValue(const IValue& v, Value* value);
  void delValue(const IValue& var);
  Value* getValue(const IValue& var);
  Value* getOutput(const IValue& var, size_t i);
  bool hasValue(const IValue& var) const;

  Node* createNode(c10::Symbol op_name, size_t num_outputs);
  void insertNode(Node* node);
};

2. 生成 Graph 输入

这个步骤会根据输入的 IValue 的类型,在 graph 中插入新的输入 Value。还记得在基本概念章节中我们提到的 IValue 与 Value 的区别吗?

代码语言:javascript
复制
for (IValue& input : inputs) {
    // addInput这个函数会unpack一些容器类型的IValue,创建对应的Node
    input = addInput(state, input, input.type(), state->graph->addInput());
}

3. 进行 Tracing

Tracing 的过程就是使用样本数据进行一次推理的过程,但是实际在 github 的源码中,并不能找到关于推理时如何更新 TracingState 的代码。

那么 PyTorch 到底是如何做到在推理时更新 TracingState 的呢?我们首先介绍关于 PyTorch 源码编译的一些小细节。

PyTorch 要适配各种硬件以及环境,为所有这些情况定制代码工作量大得可怕,也不方便后续的维护更新。因此 PyTorch 中许多代码是根据 build 时的参数生成出来,更新 TracingState 的代码就是其中之一。生成 Tracing 代码的脚本如下:

代码语言:javascript
复制
python -m tools.autograd.gen_autograd \
    aten/src/ATen/native/native_functions.yaml \
    ${OUTPUT_DIR} \
    tools/autograd
    
# derivatives.yaml和native_functions.yaml中包含
# 许多FunctionSchema以及生成代码需要的信息

大家可以跑一下看看都生成了些什么。生成的代码中 TraceTypeEverything.cpp 包含了许多关于更新 TracingState 的内容,我们还是以 add 算子举例如下:

yaml

代码语言:javascript
复制
- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  structured_delegate: scatter_add.out
  variants: function, method

- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
  structured_delegate: scatter_add.out
  variants: method

- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  variants: function
  dispatch:
    CPU, CUDA: scatter_add
    
 # func的内容是一个FunctionSchema,定义了函数的输入输出、别名信息等。

cpp

代码语言:javascript
复制
at::Tensor scatter_add(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
  // 步骤1:如果tracing时,使用TracingState创建ops对应的Node并插入Graph
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = c10::Symbol::fromQualString("aten::scatter_add");
    node = tracer_state->createNode(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    jit::tracer::addInputs(node, "dim", dim);
    jit::tracer::addInputs(node, "index", index);
    jit::tracer::addInputs(node, "src", src);
    tracer_state->insertNode(node);
  
    jit::tracer::setTracingState(nullptr);
  }
  // 步骤2:ops计算,不管是否进行Tracing都会执行
  auto result =at::_ops::scatter_add::redispatch(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer), self, dim, index, src);
  if (tracer_state) {
  // 步骤3:在TracingState中设置ops输出
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

以上上方是 FunctionSchema,下方为生成的代码。代码会根据是否 isTracing 来选择是否记录 Graph 的结构信息。

实际在 Tracing 时,每经过一个 ops,都会调用一个类似上面生成的函数,执行如下步骤:

1)在推理前根据解析的 FunctionSchema 生成 Node 以及各个输入 Value;

2)然后进行 ops 的正常计算;

3)最后根据 ops 的输出生成 Node 的输出 Value。

4. 注册 Graph 输出

这部分没有太多值得说的,就是挨个把推理的输出注册成 Graph 的输出 Value。由于输出在一个栈中,因此输出的编号要逆序。

代码语言:javascript
复制
    size_t i = 0;
    for (auto& output : out_stack) {
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
      // number we need to flip it based on size.
      state->graph->registerOutput(
          state->getOutput(output, out_stack.size() - i));
      i++;
    }

5. Graph 优化

完成 Tracing 后,会对 Graph 进行一些简单的优化,包括如下数个 passes:

· Inline (Optional):网络定义经常会包含很多嵌套结构,比如 Resnet 会由很多 BottleNeck 组成。这就会涉及到对 sub module 的调用,这种调用会生成 prim::CallMethod 等 Node。Inline 优化会将 sub module 的 Graph 内联到当前的 Graph 中,消除 CallMethod、CallFunction 等节点。

· FixupTraceScopeBlock:处理一些与 scope 相关的 node,比如将诸如prim::TracedAttr[scope="__module.f.param"]()这样的 Node 拆成数个 prim::GetAttr 的组合。

· NormalizeOps:有些不同名 Node 可能有相同的功能,比如 aten::absolute 和 aten::abs,N ormalizeOps 会把这些 Node 的类型名字统一(通常为较短的那个)。

对 pass 更详细的分析会在后续的分享中介绍。

Inline 网址:

https://github.com/pytorch/pytorch/blob/f883ed9095b26ba042509785f14076188a452c01/torch/csrc/jit/passes/inliner.cpp

FixupTraceScopeBlock 网址:

https://github.com/pytorch/pytorch/blob/f883ed9095b26ba042509785f14076188a452c01/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp

NormalizeOps 网址:

https://github.com/pytorch/pytorch/blob/f883ed9095b26ba042509785f14076188a452c01/torch/csrc/jit/passes/normalize_ops.cpp

经过上述步骤,就可以得到经过 trace 的结果。

ONNX Export

Onnx 模型的导出同样要用到 jit trace 的过程,大致的步骤如下:

1)加载 ops 的 symbolic 函数,主要是 torch 中预定义的 symbolic。

2)设置环境,包括 opset_version,是否折叠常量等等。

3)使用 jit trace 生成 Graph。

4)将 Graph 中的 Node 映射成 ONNX 的 Node,并进行必要的优化。

5)将模型导出成 ONNX 的序列化格式。

接下来,我们将按照顺序介绍以上几个步骤:

1. 加载 Symbolic

严格地说这一步在 export 之前就已经完成。在 symbolic_registry.py 中,会维护一个 _symbolic_versions 对象,在导入这个模块时会使用 importlib 将预先定义的 symbolic(torch.onnx.symbolic_opset) 加载到其中。

symbolic_registry.py 网址:

https://github.com/pytorch/pytorch/blob/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/onnx/symbolic_registry.py

代码语言:javascript
复制
_symbolic_versions: Dict[Union[int, str], Any] = {}
from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
for opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
    module = importlib.import_module("torch.onnx.symbolic_opset{}".format(opset_version))
    _symbolic_versions[opset_version] = module

_symbolic_versions中 key 为 opset_version,value 为对应的 symbolic 集合。symbolic 是一种映射函数,用来把对应的 aten/prim Node 映射成 onnx 的 Node。可以阅读 torch/onnx/symbolic_opset.py 了解更多细节。

torch/onnx/symbolic_opset.py 网址:

https://github.com/pytorch/pytorch/tree/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/onnx

2. 设置环境

根据 export 的输入参数调整环境信息,比如 opset 的版本、是否将 init 导出成 Input、是否进行常量折叠等等。后续的优化会根据这些环境运行特定的 passes。

3. Graph Tracing

这一步实际执行的就是上面介绍过的 Jit Tracing 过程,如果遗忘的话可以再复习一下哦。

ToONNX

Graph 在实际使用之前会经过很多的 pass,每个 pass 都会对 Graph 进行一些变换,可以在 torch/csrc/jit/passes 中查看实现细节。这些 pass 很多功能与常见的编译器中的类似,篇幅关系就不在这里展开介绍了。对于 torchscript->ONNX 而言,最重要的 pass 当属 ToONNX。

ToONNX 的 python 接口为torch._C._jit_pass_onnx,对应的实现为 onnx.cpp。它会遍历 Graph 中所有的 Node,生成对应的 ONNX Node,插入新的 Graph 中:

torch/csrc/jit/passes 网址:

https://github.com/pytorch/pytorch/tree/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/csrc/jit/passes

onnx.cpp 网址:

https://github.com/pytorch/pytorch/blob/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/csrc/jit/passes/onnx.cpp#L163

代码语言:javascript
复制
  auto k = old_node->kind();    // 取得Node的ops类型
  if (k.is_caffe2()) {
    // ToONNX之前的会有一些对caffe2算子的pass
    // 因此这里只要直接clone到新的graph中即可
    cloneNode(old_node);
  } else if (k == prim::PythonOp) {
    // 如果是Python自定义的函数,比如继承自torch.autograd.Function的函数
    // 就会查找并调用对应的symbolic函数进行转换
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
  } else {
    // 如果是其他情况(通常是aten的算子)调用步骤1加载的symbolic进行转换
    callPySymbolicFunction(old_node);
  }

cloneNode 的功能就和名字一样,就是简单的拷贝 old_node,然后塞进新的 Graph 中。

callPySymbolicFunction

当 Node 的类型为 PyTorch 的内置类型时,会调用这个函数来处理。

该函数会调用 python 侧的 torch.onnx.utils._run_symbolic_function 函数,将 Node 进行转换,并插入新的 Graph,我们可以尝试如下 python 代码:

torch.onnx.utils._run_symbolic_function 链接:

https://github.com/pytorch/pytorch/blob/9f541aa3aca768e7fbfa4a9d648b554f22b261f7/torch/onnx/utils.py#L1045

代码语言:javascript
复制
graph = torch._C.Graph()  # 创建Graph
[graph.addInput() for _ in range(2)]  # 插入两个输入
node = graph.create('aten::add', list(graph.inputs()))  # 创建节点
node = graph.insertNode(node)  # 插入节点
graph.registerOutput(node.output())  # 注册输出
print(f'old graph:\n {graph}')

new_graph = torch._C.Graph()  # 创建新的Graph用于ONNX
[new_graph.addInput() for _ in range(2)]  # 插入两个输入
_run_symbolic_function(
    new_graph, node, inputs=list(new_graph.inputs()),
    env={})  # 将aten Node转换为onnx Node, 插入新的Graph
# 如果是torch>1.8,那么可能还要传入block
print(f'new graph:\n {new_graph}')

然后看一下可视化的结果:

Old graph

代码语言:javascript
复制
 graph(%0 : Tensor,
      %1 : Tensor):
  %2 : Tensor = aten::add(%0, %1)
  return (%2)

New graph

代码语言:javascript
复制
 graph(%0 : Tensor,
      %1 : Tensor):
  %2 : Tensor = onnx::Add(%0, %1)
  return ()

可以看见,原来的 aten::add 节点已经被替换为了onnx::Add。那么这个映射是如何完成的呢?还记得第一步记录的 _symbolic_versions 吗?_run_symbolic_function 会调用 torch.onnx.symbolic_registry 中的 _find_symbolic_in_registry 函数,查找 _symbolic_versions 中是否存在满足条件的映射,如果存在,就会进行如上的转换。

注意:转换的新 Graph 中没有输出 Value,这是因为这部分是在 ToONNX 的 c++ 代码中实现,_run_symbolic_function 仅负责 Node 的映射。

callPySymbolicMethod

一些非 pytorch 原生的计算会被标记为 PythonOp。碰到这种 Node 时,会有三种可能的处理方式:

1)如果这个 PythonOp 带有名为 symbolic 的属性,那么就会尝试使用这个 symbolic 当作映射函数,生成 ONNX 节点。

2)如果没有 symbolic 属性,但是在步骤 1 的时候注册了 prim::PythonOp 的 symbolic 函数,那么就会使用这个函数生成节点。

3)如果都没有,则直接 clone PythonOp 节点到新的 Graph。

symbolic 函数的写法很简单,基本上就是调用 python bind 的 Graph 接口创建新节点,比如:

代码语言:javascript
复制
class CustomAdd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, val):
        return x + val

    @staticmethod
    def symbolic(g, x, val):
        # g.op 可以创建新的Node
        # Node的名字 为 <domain>::<node_name>,如果domain为onnx,可以只写node_name
        # Node可以有很多属性,这些属性名必须有_<type>后缀,比如val如果为float类型,则必须有_f后缀
        return g.op("custom_domain::add", x, val_f=val)

实际在使用上面的函数时,就会生成 custom_domain::add 这个 Node。当然,能否被用于推理这就要看推理引擎的支持情况了。

通过 callPySymbolicFunction 和 callPySymbolicMethod,就可以生成一个由 ONNX(或自定义的 domain 下的 Node)组成的新 Graph。这之后还会执行一些优化 ONNX Graph 的 pass,这里不详细展开了。

5. 序列化

到这里为止建图算是完成了,但是要给其他后端使用的话,需要将这个 Grap 序列化并导出。序列化的过程比较简单,基本上只是调用 ONNX 的 proto 接口,将 Graph 中的各个元素映射到 ONNX 的 GraphProto 上。没有太多值得展开的内容,可以阅读 export.cpp 中的 EncodeGraph,EncodeBlock,EncodeNode 函数了解更多细节。

之后只要根据具体的 export_type,将序列化后的 proto 写入文件即可。

至此,ONNX export 完成,可以开始享受各种推理引擎带来的速度提升了。

通过上面的内容分享,我们应该对如何使用 trace 方式生成 jit 模型,以及 trace 模型如何影响 ONNX 导出有了一个初步的认识。为了让模型更好地为部署服务,我们可以考虑对模型进行优化,后续的分享中将介绍一种常用的优化范式,敬请期待哦。

MMDeploy 已添加对 torchscript 模型的支持,其中也采用 trace 的方式构建 jit 模型,欢迎大家访问 MMDeploy GitHub 主页体验

如果我们的分享给你带来一定的帮助,欢迎多多 Star,Fork 和 PR 呀,比心!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-03-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
容器服务
腾讯云容器服务(Tencent Kubernetes Engine, TKE)基于原生 kubernetes 提供以容器为核心的、高度可扩展的高性能容器管理服务,覆盖 Serverless、边缘计算、分布式云等多种业务部署场景,业内首创单个集群兼容多种计算节点的容器资源管理模式。同时产品作为云原生 Finops 领先布道者,主导开源项目Crane,全面助力客户实现资源优化、成本控制。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档