首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >OpenAI Triton现状调研

OpenAI Triton现状调研

原创
作者头像
aaronwjzhao
修改2024-11-15 15:55:45
修改2024-11-15 15:55:45
8660
举报
文章被收录于专栏:AI工程落地AI工程落地

接入pytorch方式:静态图编译模式,添加torch.compile装饰器

PyTorch 2.3.1引入了torch.compile功能,允许用户将包含triton内核的PyTorch代码进行本地执行。

这一功能的引入,使得用户能够轻松地将eager PyTorch代码迁移到torch.compile,而无需担心性能回归或图形中断。

torch.compile通过优化代码的执行路径和减少不必要的计算开销,极大地提升了PyTorch代码的执行效率。

代码语言:txt
复制
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def foo(x):
    return torch.sin(x) + torch.cos(x)
torch.compile工作流程
torch.compile工作流程

第一步: TorchDynamo来捕获计算图

第二步:TorchInductor进行图编译优化,产生新的高效计算代码

torch.compile介入第三方工具示意图
torch.compile介入第三方工具示意图

注意:常见的pytorch前向、反向、优化器算子还都是cuda算子,除非用户用上述torch.compile自己替换推理和训练代码,才会使用到Triton。

Triton源码学习

Ops算子

triton支持的都是小算子(如log、exp、cat、reshape等,相当于模拟器算子),算子列表见:triton.language — Triton documentation

Triton小算子列表(不完整)
Triton小算子列表(不完整)

手动实现Kernel(不一定要做,torch.compile可以捕获计算图、编译成Ops小算子)

如果要实现layernorm、embedding、gemm、conv等复杂计算,需要用上面给出例子的方法写新算子。如LayerNorm的实现:Layer Normalization — Triton documentation

LayerNorm实现实例
LayerNorm实现实例

试用情况

优点:无侵入式修改用户代码,添加装饰器即可。@torch.compile

缺点:对于多卡、复杂网络可能支持度欠佳

block

fail/suc

原因分析

Llama3.2 11B整个模型

fail

不支持多个device并行

Llama3.2 11B视觉模型

suc

Llama3.2 11B语言模型

fail

TorchDynamo捕获计算图失败

Llama3.1 8B整个模型

fail

能直接把整网替换成triton,生成几个词之后,推理会报错。感觉是attention中某些计算不支持。

Llama3.1 8B Attention层

fail

同上

Llama3.1 8B RMSNorm

suc

Llama3.1 8B MLP

suc

Resnet

suc

国内外现状分析

企业&研究机构

主要贡献

代码仓库

微软

方便新后端集成,基于triton做的中间层。 整体软件架构: [Triton IR] -> [Middle Layer] -> [HW specific IR] triton-shared做的是Middle Layer

https://github.com/microsoft/triton-shared

寒武纪

和微软的triton-shared类似,开源了基于Linalg编译技术和Triton编程语言的AI编译器前端,可快速集成新的硬件后端。

https://github.com/Cambricon/triton-linalg

智源研究院

基于Triton实现的高性能算子库,通过在PyTorch的ATen后端注册,FlagGems实现了无缝过渡,允许用户切换到Triton函数库,而无需修改其模型代码。算子支持挺全的,基本上涵盖了大模型训练和推理所需要的算子。 全局替换 FlagGems 算子 import flag_gems flag_gems.enable() 局部替换 FlagGems 算子 import torch import flag_gems M, N, K = 1024, 1024, 1024 A = torch.randn((M, K), dtype=torch.float16, device="cuda") B = torch.randn((K, N), dtype=torch.float16, device="cuda") with flag_gems.use_gems(): C = torch.mm(A, B)

https://github.com/FlagOpen/FlagGems

英伟达

当前Triton主要支持的还是英伟达的GPU,可能比cuda要方便用户编程,所以英伟达也在支持Triton

OpenAI

Triton的主要贡献者

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 接入pytorch方式:静态图编译模式,添加torch.compile装饰器
  • Triton源码学习
    • Ops算子
    • 手动实现Kernel(不一定要做,torch.compile可以捕获计算图、编译成Ops小算子)
  • 试用情况
  • 国内外现状分析
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档