PyTorch 2.3.1引入了torch.compile功能,允许用户将包含triton内核的PyTorch代码进行本地执行。
这一功能的引入,使得用户能够轻松地将eager PyTorch代码迁移到torch.compile,而无需担心性能回归或图形中断。
torch.compile通过优化代码的执行路径和减少不必要的计算开销,极大地提升了PyTorch代码的执行效率。
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def foo(x):
return torch.sin(x) + torch.cos(x)
第一步: TorchDynamo来捕获计算图
第二步:TorchInductor进行图编译优化,产生新的高效计算代码

注意:常见的pytorch前向、反向、优化器算子还都是cuda算子,除非用户用上述torch.compile自己替换推理和训练代码,才会使用到Triton。
triton支持的都是小算子(如log、exp、cat、reshape等,相当于模拟器算子),算子列表见:triton.language — Triton documentation

如果要实现layernorm、embedding、gemm、conv等复杂计算,需要用上面给出例子的方法写新算子。如LayerNorm的实现:Layer Normalization — Triton documentation

优点:无侵入式修改用户代码,添加装饰器即可。@torch.compile
缺点:对于多卡、复杂网络可能支持度欠佳
block | fail/suc | 原因分析 |
|---|---|---|
Llama3.2 11B整个模型 | fail | 不支持多个device并行 |
Llama3.2 11B视觉模型 | suc | |
Llama3.2 11B语言模型 | fail | TorchDynamo捕获计算图失败 |
Llama3.1 8B整个模型 | fail | 能直接把整网替换成triton,生成几个词之后,推理会报错。感觉是attention中某些计算不支持。 |
Llama3.1 8B Attention层 | fail | 同上 |
Llama3.1 8B RMSNorm | suc | |
Llama3.1 8B MLP | suc | |
Resnet | suc |
企业&研究机构 | 主要贡献 | 代码仓库 |
|---|---|---|
微软 | 方便新后端集成,基于triton做的中间层。 整体软件架构: [Triton IR] -> [Middle Layer] -> [HW specific IR] triton-shared做的是Middle Layer | https://github.com/microsoft/triton-shared |
寒武纪 | 和微软的triton-shared类似,开源了基于Linalg编译技术和Triton编程语言的AI编译器前端,可快速集成新的硬件后端。 | https://github.com/Cambricon/triton-linalg |
智源研究院 | 基于Triton实现的高性能算子库,通过在PyTorch的ATen后端注册,FlagGems实现了无缝过渡,允许用户切换到Triton函数库,而无需修改其模型代码。算子支持挺全的,基本上涵盖了大模型训练和推理所需要的算子。
全局替换 FlagGems 算子
| https://github.com/FlagOpen/FlagGems |
英伟达 | 当前Triton主要支持的还是英伟达的GPU,可能比cuda要方便用户编程,所以英伟达也在支持Triton | |
OpenAI | Triton的主要贡献者 |
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。