前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

作者头像
机器之心
发布于 2023-03-29 09:30:08
发布于 2023-03-29 09:30:08
2.6K00
代码可运行
举报
文章被收录于专栏:机器之心机器之心
运行总次数:0
代码可运行

机器之心报道

机器之心编辑部

JAX 是机器学习 (ML) 领域的新生力量,它有望使 ML 编程更加直观、结构化和简洁。

在机器学习领域,大家可能对 TensorFlow 和 PyTorch 已经耳熟能详,但除了这两个框架,一些新生力量也不容小觑,它就是谷歌推出的 JAX。很对研究者对其寄予厚望,希望它可以取代 TensorFlow 等众多机器学习框架。

JAX 最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起。

目前,JAX 在 GitHub 上已累积 13.7K 星。

项目地址:https://github.com/google/jax

迅速发展的 JAX

JAX 的前身是 Autograd,其借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。

开发 JAX 的出发点是什么?说到这,就不得不提 NumPy。NumPy 是 Python 中的一个基础数值运算库,被广泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也没有对反向传播的内置支持,此外,Python 本身的速度限制阻碍了 NumPy 使用,所以少有研究者在生产环境下直接用 numpy 训练或部署深度学习模型。

在此情况下,出现了众多的深度学习框架,如 PyTorch、TensorFlow 等。但是 numpy 具有灵活、调试方便、API 稳定等独特的优势。而 JAX 的主要出发点就是将 numpy 的以上优势与硬件加速结合。

目前,基于 JAX 已有很多优秀的开源项目,如谷歌的神经网络库团队开发了 Haiku,这是一个面向 Jax 的深度学习代码库,通过 Haiku,用户可以在 Jax 上进行面向对象开发;又比如 RLax,这是一个基于 Jax 的强化学习库,用户使用 RLax 就能进行 Q-learning 模型的搭建和训练;此外还包括基于 JAX 的深度学习库 JAXnet,该库一行代码就能定义计算图、可进行 GPU 加速。可以说,在过去几年中,JAX 掀起了深度学习研究的风暴,推动了科学研究迅速发展。

JAX 的安装

如何使用 JAX 呢?首先你需要在 Python 环境或 Google colab 中安装 JAX,使用 pip 进行安装:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
$ pip install --upgrade jax jaxlib

注意,上述安装方式只是支持在 CPU 上运行,如果你想在 GPU 执行程序,首先你需要有 CUDA、cuDNN ,然后运行以下命令(确保将 jaxlib 版本映射到 CUDA 版本):

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

现在将 JAX 与 Numpy 一起导入:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import jaximport jax.numpy as jnpimport numpy as np

JAX 的一些特性

使用 grad() 函数自动微分:这对深度学习应用非常有用,这样就可以很容易地运行反向传播,下面为一个简单的二次函数并在点 1.0 上求导的示例:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from jax import graddef f(x):  return 3*x**2 + 2*x + 5def f_prime(x):  return 6*x +2grad(f)(1.0)# DeviceArray(8., dtype=float32)f_prime(1.0)# 8.0

jit(Just in time) :为了利用 XLA 的强大功能,必须将代码编译到 XLA 内核中。这就是 jit 发挥作用的地方。要使用 XLA 和 jit,用户可以使用 jit() 函数或 @jit 注释。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from jax import jitx = np.random.rand(1000,1000)y = jnp.array(x)def f(x):  for _ in range(10):      x = 0.5*x + 0.1* jnp.sin(x)  return xg = jit(f)%timeit -n 5 -r 5 f(y).block_until_ready()# 5 loops, best of 5: 10.8 ms per loop%timeit -n 5 -r 5 g(y).block_until_ready()# 5 loops, best of 5: 341 µs per loop

pmap:自动将计算分配到所有当前设备,并处理它们之间的所有通信。JAX 通过 pmap 转换支持大规模的数据并行,从而将单个处理器无法处理的大数据进行处理。要检查可用设备,可以运行 jax.devices():

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from jax import pmapdef f(x):  return jnp.sin(x) + x**2f(np.arange(4))#DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)pmap(f)(np.arange(4))#ShardedDeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

vmap:是一种函数转换,JAX 通过 vmap 变换提供了自动矢量化算法,大大简化了这种类型的计算,这使得研究人员在处理新算法时无需再去处理批量化的问题。示例如下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
from jax import vmapdef f(x):  return jnp.square(x)f(jnp.arange(10))#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)vmap(f)(jnp.arange(10))#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

TensorFlow vs PyTorch vs Jax

在深度学习领域有几家巨头公司,他们所提出的框架被广大研究者使用。比如谷歌的 TensorFlow、Facebook 的 PyTorch、微软的 CNTK、亚马逊 AWS 的 MXnet 等。

每种框架都有其优缺点,选择的时候需要根据自身需求进行选择。

我们以 Python 中的 3 个主要深度学习框架——TensorFlow、PyTorch 和 Jax 为例进行比较。这些框架虽然不同,但有两个共同点:

  • 它们是开源的。这意味着如果库中存在错误,使用者可以在 GitHub 中发布问题(并修复),此外你也可以在库中添加自己的功能;
  • 由于全局解释器锁,Python 在内部运行缓慢。所以这些框架使用 C/C++ 作为后端来处理所有的计算和并行过程。

那么它们的不同体现在哪些方面呢?如下表所示,为 TensorFlow、PyTorch、JAX 三个框架的比较。

TensorFlow

TensorFlow 由谷歌开发,最初版本可追溯到 2015 年开源的 TensorFlow0.1,之后发展稳定,拥有强大的用户群体,成为最受欢迎的深度学习框架。但是用户在使用时,也暴露了 TensorFlow 缺点,例如 API 稳定性不足、静态计算图编程复杂等缺陷。因此在 TensorFlow2.0 版本,谷歌将 Keras 纳入进来,成为 tf.keras。

目前 TensorFlow 主要特点包括以下:

  • 这是一个非常友好的框架,高级 API-Keras 的可用性使得模型层定义、损失函数和模型创建变得非常容易;
  • TensorFlow2.0 带有 Eager Execution(动态图机制),这使得该库更加用户友好,并且是对以前版本的重大升级;
  • Keras 这种高级接口有一定的缺点,由于 TensorFlow 抽象了许多底层机制(只是为了方便最终用户),这让研究人员在处理模型方面的自由度更小;
  • Tensorflow 提供了 TensorBoard,它实际上是 Tensorflow 可视化工具包。它允许研究者可视化损失函数、模型图、模型分析等。

PyTorch

PyTorch(Python-Torch) 是来自 Facebook 的机器学习库。用 TensorFlow 还是 PyTorch?在一年前,这个问题毫无争议,研究者大部分会选择 TensorFlow。但现在的情况大不一样了,使用 PyTorch 的研究者越来越多。PyTorch 的一些最重要的特性包括:

  • 与 TensorFlow 不同,PyTorch 使用动态类型图,这意味着执行图是在运行中创建的。它允许我们随时修改和检查图的内部结构;
  • 除了用户友好的高级 API 之外,PyTorch 还包括精心构建的低级 API,允许对机器学习模型进行越来越多的控制。我们可以在训练期间对模型的前向和后向传递进行检查和修改输出。这被证明对于梯度裁剪和神经风格迁移非常有效;
  • PyTorch 允许用户扩展代码,可以轻松添加新的损失函数和用户定义的层。PyTorch 的 Autograd 模块实现了深度学习算法中的反向传播求导数,在 Tensor 类上的所有操作, Autograd 都能自动提供微分,简化了手动计算导数的复杂过程;
  • PyTorch 对数据并行和 GPU 的使用具有广泛的支持;
  • PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常适合 Python 生态系统,它允许使用 Python 类调试器工具来调试 PyTorch 代码。

JAX 

JAX 是来自 Google 的一个相对较新的机器学习库。它更像是一个 autograd 库,可以区分原生的 python 和 NumPy 代码。JAX 的一些特性主要包括:

  • 正如官方网站所描述的那样,JAX 能够执行 Python+NumPy 程序的可组合转换:向量化、JIT 到 GPU/TPU 等等;
  • 与 PyTorch 相比,JAX 最重要的方面是如何计算梯度。在 Torch 中,图是在前向传递期间创建的,梯度在后向传递期间计算, 另一方面,在 JAX 中,计算表示为函数。在函数上使用 grad() 返回一个梯度函数,该函数直接计算给定输入的函数梯度;
  • JAX 是一个 autograd 工具,不建议单独使用。有各种基于 JAX 的机器学习库,其中值得注意的是 ObJax、Flax 和 Elegy。由于它们都使用相同的核心并且接口只是 JAX 库的 wrapper,因此可以将它们放在同一个 bracket 下;
  • Flax 最初是在 PyTorch 生态系统下开发的,更注重使用的灵活性。另一方面,Elegy 受 Keras 启发。ObJAX 主要是为以研究为导向的目的而设计的,它更注重简单性和可理解性。 

参考链接:

https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://www.zhihu.com/question/306496943/answer/557876584

NVIDIA对话式AI开发工具NeMo的应用

开源工具包 NeMo 是一个集成自动语音识别(ASR)、自然语言处理(NLP)和语音合成(TTS)的对话式 AI 工具包,便于开发者开箱即用,仅用几行代码便可以方便快速的完成对话式 AI 场景中的相关任务。

8月12日开始,英伟达专家将带来三期直播分享,通过理论解读和实战演示,展示如何使用 NeMo 快速完成文本分类任务、快速构建智能问答系统、构建智能对话机器人

直播链接:https://jmq.h5.xeknow.com/s/how4w(点击阅读原文直达)

报名方式:进入直播间——移动端点击底部「观看直播」、PC端点击「立即学习」——填写报名表单后即可进入直播间观看。

交流答疑群:直播间详情页扫码即可加入。

© THE END 

转载请联系本公众号获得授权

投稿或寻求报道:content@jiqizhixin.com

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-08-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器之心 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
只知道TF和PyTorch还不够,快来看看怎么从PyTorch转向自动微分神器JAX
Jax 是谷歌开发的一个 Python 库,用于机器学习和数学计算。一经推出,Jax 便将其定义为一个 Python+NumPy 的程序包。它有着可以进行微分、向量化,在 TPU 和 GPU 上采用 JIT 语言等特性。简而言之,这就是 GPU 版本的 numpy,还可以进行自动微分。甚至一些研究者,如 Skye Wanderman-Milne,在去年的 NeurlPS 2019 大会上就介绍了 Jax。
机器之心
2020/05/19
1.5K0
只知道TF和PyTorch还不够,快来看看怎么从PyTorch转向自动微分神器JAX
​Jax 生态再添新库:DeepMind 开源 Haiku、RLax
Jax 是谷歌开源的一个科学计算库,能对 Python 程序与 NumPy 运算执行自动微分,而且能够在 GPU 和 TPU 上运行,具有很高的性能。基于 Jax 已有很多优秀的开源项目,如 Trax 等。近日,DeepMind 开源了两个基于 Jax 的新机器学习库,分别是 Haiku 和 RLax,它们都有着各自的特色,对于丰富深度学习社区框架、提升研究者和开发者的使用体验有着不小的意义。
机器之心
2020/02/25
1.1K0
PyTorch攻势凶猛,程序员正在抛弃TensorFlow?
来源 | The Gradient 译者 | 夕颜 出品 | AI科技大本营(ID:rgznai100)
AI科技大本营
2019/11/13
6200
『JAX中文文档』JAX快速入门
简单的说就是GPU加速、支持自动微分(autodiff)的numpy。众所周知,numpy是Python下的基础数值运算库,得到广泛应用。用Python搞科学计算或机器学习,没人离得开它。但是numpy不支持GPU或其他硬件加速器,也没有对backpropagation的内置支持,再加上Python本身的速度限制,所以很少有人会在生产环境下直接用numpy训练或部署深度学习模型。这也是为什么会出现Theano, TensorFlow, Caffe等深度学习框架的原因。但是numpy有其独特的优势:底层、灵活、调试方便、API稳定且为大家所熟悉(与MATLAB一脉相承),深受研究者的青睐。JAX的主要出发点就是将numpy的以上优势与硬件加速结合。现在已经开源的JAX ( https://github.com/google/jax) 就是通过GPU (CUDA)来实现硬件加速。出自:https://www.zhihu.com/question/306496943/answer/557876584
小宋是呢
2021/04/29
2.5K0
前端如何开始深度学习,那不妨试试JAX
在深度学习方面,TensorFlow 和 PyTorch是绝对的王者。但是,但除了这两个框架之外,一些新生的框架也不容小觑,比如谷歌推出的 JAX深度学习框架。
xiangzhihong
2022/07/30
1.9K0
前端如何开始深度学习,那不妨试试JAX
JAX介绍和快速入门示例
来源:DeepHub IMBA本文约3300字,建议阅读10+分钟本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念。 JAX 是一个由 Google 开发的用于优化科学计算Python 库: 它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JI
数据派THU
2022/06/16
2K0
JAX介绍和快速入门示例
JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学
JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。
数据科学工厂
2024/01/02
1.7K0
JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学
PyTorch称霸学界,TensorFlow固守业界,ML框架之争将走向何方?
自 2012 年深度学习再度成为焦点以来,很多机器学习框架成为研究者和业界工作者的新宠。从早期的学术框架 Caffe、Theano 到如今有业界背景的大规模框架 Pytorch 和 TensorFlow,层出不穷的新成果使得跟踪当前最流行的框架变得越发困难。
机器之心
2019/10/15
6690
PyTorch称霸学界,TensorFlow固守业界,ML框架之争将走向何方?
被PyTorch打爆!谷歌抛弃TensorFlow,押宝JAX
---- 新智元报道   编辑:拉燕 如願 好困 【新智元导读】谷歌Meta之争看来还没完!TensorFlow干不过还有JAX,二番战能否战胜PyTorch? 很喜欢有些网友的一句话: 「这孩子实在不行,咱再要一个吧。」 谷歌还真这么干了。 养了七年的TensorFlow终于还是被Meta的PyTorch干趴下了,在一定程度上。 谷歌眼见不对,赶紧又要了一个——「JAX」,一款全新的机器学习框架。 最近超级火爆的DALL·E Mini都知道吧,它的模型就是基于JAX进行编程的,从而充分地利用了谷
新智元
2022/06/16
4850
被PyTorch打爆!谷歌抛弃TensorFlow,押宝JAX
Jax:有望取代Tensorflow,谷歌出品的又一超高性能机器学习框架
在机器学习框架方面,JAX是一个新生事物——尽管Tensorflow的竞争对手从技术上讲已经在2018年后已经很完备,但直到最近JAX才开始在更广泛的机器学习研究社区中获得吸引力。
HuangWeiAI
2020/03/04
1.8K0
JAX 中文文档(十二)
我们将 JAX 发布为两个独立的 Python 轮子,即纯 Python 轮子 jax 和主要由 C++ 组成的轮子 jaxlib,后者包含库,例如:
ApacheCN_飞龙
2024/06/22
4960
JAX 中文文档(十二)
新星JAX :双挑TensorFlow和PyTorch!有望担纲Google主要科学计算库和神经网络库
JAX是机器学习框架领域的新生力量,尽管这个Tensorflow的竞争对手从2018年末开就已经出现,但直到最近,JAX才开始在更广泛的机器学习研究领域中获得关注。
新智元
2020/03/03
1.5K0
新星JAX :双挑TensorFlow和PyTorch!有望担纲Google主要科学计算库和神经网络库
TensorFlow和PyTorch的实际应用比较
TensorFlow和PyTorch是两个最受欢迎的开源深度学习框架,这两个框架都为构建和训练深度学习模型提供了广泛的功能,并已被研发社区广泛采用。但是作为用户,我们一直想知道哪种框架最适合我们自己特定项目,所以在本文与其他文章的特性的对比不同,我们将以实际应用出发,从性能、可伸缩性和其他高级特性方面比较TensorFlow和PyTorch。
deephub
2023/02/01
4.7K0
JAX 中文文档(二)
JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它也有一些重要的区别。
ApacheCN_飞龙
2024/06/22
4980
2019机器学习框架之争:与Tensorflow竞争白热化,进击的PyTorch赢在哪里?
2019年,机器学习框架之争进入了新阶段:PyTorch与TensorFlow成为最后两大玩家,PyTorch占据学术界领军地位,TensorFlow在工业界力量依然强大,两个框架都在向对方借鉴,但是都不太理想。
大数据文摘
2019/10/15
7470
2019 年机器学习框架之争:PyTorch 和 TensorFlow 谁更有胜算?
对于机器学习科研工作者和工业界从业人员来说,熟练掌握一种机器学习框架是必备技能之一。随着深度学习技术发展的突飞猛进,机器学习框架市场也渐渐度过了初期野蛮生长的阶段。大浪淘沙,目前仍然活跃的机器学习框架主要是 PyTorch 和 TensorFlow。本文从学术界和工业界两个方面深度盘点了 2019 年机器学习框架的发展趋势。
AI科技评论
2019/11/26
4800
TensorFlow,危!抛弃者正是谷歌自己
萧箫 丰色 发自 凹非寺 量子位 | 公众号 QbitAI 收获接近16.6万个Star、见证深度学习崛起的TensorFlow,地位已岌岌可危。 并且这次,冲击不是来自老对手PyTorch,而是自家新秀JAX。 最新一波AI圈热议中,连fast.ai创始人Jeremy Howard都下场表示: JAX正逐渐取代TensorFlow这件事,早已广为人知了。现在它就在发生(至少在谷歌内部是这样)。 LeCun更是认为,深度学习框架之间的激烈竞争,已经进入了一个新的阶段。 LeCun表示,当初谷歌的Tens
量子位
2022/06/24
3880
TensorFlow,危!抛弃者正是谷歌自己
深度学习长文|使用 JAX 进行 AI 模型训练
在人工智能模型的开发旅程中,选择正确的机器学习开发框架是一项至关重要的决策。历史上,众多库都曾竞相争夺“人工智能开发者首选框架”这一令人垂涎的称号。(你是否还记得 Caffe 和 Theano?)在过去的几年里,TensorFlow 以其对高效率、基于图的计算的重视,似乎已经成为了领头羊(这是根据作者对学术论文提及次数和社区支持力度的观察得出的结论)。而在近十年的转折点上,PyTorch 以其对用户友好的 Python 风格接口的强调,似乎已经稳坐了霸主之位。但是,近年来,一个新兴的竞争者迅速崛起,其受欢迎程度已经到了不容忽视的地步。JAX 以其对提升人工智能模型训练和推理性能的追求,同时不牺牲用户体验,正逐步向顶尖位置发起挑战。
数据科学工厂
2024/06/18
4070
深度学习长文|使用 JAX 进行 AI 模型训练
让你捷足先登的深度学习框架
大数据文摘授权转载自数据派THU 作者:陈之炎 对于据科学的初学者来说,利用开源的深度学习框架,可以大幅度简化复杂的大规模度学习模型的实现过程。在深度学习框架下构建模型,无需花费几天或几周的时间从头开始编写代码,便可以轻松实现诸如卷积神经网络这样复杂的模型。在本文中,将介绍几种非常有用的深度学习框架、它们的优点以及应用,通过对每个框架进行比较,研发人员了解如何有选择地使用它们,高效快捷完成项目任务。 深度学习框架概述 深度学习框架是一种界面、库或工具,它使编程人员在无需深入了解底层算法的细节的情况下,能够更
大数据文摘
2023/02/23
7140
让你捷足先登的深度学习框架
JAX 中文文档(五)
当使用 JIT 模式的 JAX 时,函数将被跟踪、降级到 StableHLO,并针对每种输入类型和形状组合进行编译。在导出函数并在另一个系统上反序列化后,我们就无法再使用 Python 源代码,因此无法重新跟踪和重新降级它。形状多态性是 JAX 导出的一个特性,允许一些导出函数用于整个输入形状家族。这些函数在导出时只被跟踪和降级一次,并且Exported对象包含编译和执行该函数所需的信息,可以在许多具体输入形状上进行编译和执行。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,例如下面的示例:
ApacheCN_飞龙
2024/06/22
6340
JAX 中文文档(五)
推荐阅读
相关推荐
只知道TF和PyTorch还不够,快来看看怎么从PyTorch转向自动微分神器JAX
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验