前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Jax:有望取代Tensorflow,谷歌出品的又一超高性能机器学习框架

Jax:有望取代Tensorflow,谷歌出品的又一超高性能机器学习框架

作者头像
HuangWeiAI
发布2020-03-04 11:31:35
1.7K0
发布2020-03-04 11:31:35
举报
文章被收录于专栏:浊酒清味
前言

在机器学习框架方面,JAX是一个新生事物——尽管Tensorflow的竞争对手从技术上讲已经在2018年后已经很完备,但直到最近JAX才开始在更广泛的机器学习研究社区中获得吸引力。

JAX到底是什么?根据JAX官方介绍:

JAX是NumPy在CPU、GPU和TPU上的版本,具有高性能机器学习研究的强大自动微分(automatic differentiation)能力。

接下来,我们会具体认识JAX。

基础介绍

就像上面说的,JAX是加速器支持的numpy以及大部分scipy功能,带有一些通用机器学习操作的便利函数。

我们举个例子

代码语言:javascript
复制
import jax
import jax.numpy as np

def gpu_backed_hidden_layer(x):
    return jax.nn.relu(np.dot(W, x) + b)

您可以得到numpy精心设计的API,它从2006年就开始使用了,具有Tensorflow和PyTorch等现代ML工具的性能特征。

JAX还包括通过jax.scipy来支持相当大一部分scipy项目:

代码语言:javascript
复制
from jax.scipy.linalg import svd

singular_vectors, singular_values = svd(x)

尽管有加速器支持的numpy + scipy版本已经非常有用,但JAX还有一些其他的妙招。首先让我们看看JAX对自动微分的广泛支持。

自动微分·Autograd

Autograd是一个用于在numpy和原生python代码上高效计算梯度的库。Autograd恰好也是JAX的前身。尽管最初的autograd存储库不再被积极开发,但是在autograd上工作的大部分核心团队已经开始全职从事JAX项目。

就像autograd, JAX允许对一个python函数的输出求导,只需调用grad:

代码语言:javascript
复制
from jax import grad

def hidden_layer(x):
    return jax.nn.relu(np.dot(W, x) + b)

grad_hidden_layer = grad(hidden_layer)

您还可以通过本机的python控制结构进行区分——而不需要使用tf.cond:

代码语言:javascript
复制
def absolute_value(x)
    if x >= 0:
        return x
    else:
        return -x

grad_absolute_value = grad(absolute_value)

JAX还支持获取高阶导数——grad函数可以任意连接:

代码语言:javascript
复制
from jax.nn import tanh

# grads all the way down
print(grad(grad(grad(tanh)))(1.0))

默认情况下,grad 为您提供了逆向模式梯度——这是计算梯度最常用的模式,它依赖于缓存激活来提高向后传递的效率。反模式差分是计算参数更新最有效的方法。但是,特别是在实现依赖于高阶派生的优化方法时,它并不总是最佳选择。JAX通过jacfwd和jacrev为逆向模式自动差分和正向模式自动差分提供了一流的支持:

代码语言:javascript
复制
from jax import jacfwd, jacrev

hessian_fn = jacfwd(jacrev(fn))

除了grad、jacfwd和jacrev之外,JAX还提供了一些实用程序,用于计算函数的线性逼近、定义自定义梯度操作,以及作为其自动微分支持的一部分。

加速线性代数·XLA

XLA (Accelerated Linear Algebra)是一个特定域的线性代数代码编译器,它是JAX将python和numpy表达式转换成加速器支持的操作的基础。

除了允许JAX将python + numpy代码转换为可以在加速器上运行的操作之外(就像我们在第一个示例中看到的那样),XLA支持还允许JAX将多个操作融合到一个内核中。它在计算图中寻找节点簇,这些节点簇可以被重写以减少计算或中间变量的存储。Tensorflow关于XLA的文档使用以下示例来解释问题可以从XLA编译中受益的实例类型。

代码语言:javascript
复制
def unoptimized_fn(x, y, z):
  return np.sum(x + y * z)

在没有XLA的情况下运行,这将作为3个独立的内核运行——一个乘法、一个加法和一个加法减法。使用XLA运行时,这变成了一个负责所有这三个方面的内核,不需要存储中间变量,从而节省了时间和内存。

向量化和并行性

虽然Autograd和XLA构成了JAX库的核心,但是还有两个JAX函数脱颖而出。你可以使用jax.vmap和jax.pmap用于向量化和基于spmd(单程序多数据)并行的pmap。

为了说明vmap的优点,我们将返回到我们的简单稠密层的示例,它操作一个由向量x表示的示例。

代码语言:javascript
复制
# convention to distinguish between 
# jax.numpy and numpy
import numpy as onp

def hidden_layer(x):
    return jax.nn.relu(np.dot(W, x + b)
   
print(hidden_layer(np.random.randn(128)).shape)
# (128,)

我们已经编写了隐含层来获取单个向量输入,但实际上我们几乎总是批量处理输入以利用向量化计算。使用JAX,您可以使用任何接受单个输入的函数,并允许它使用JAX .vmap接受一批输入:

代码语言:javascript
复制
batch_hidden_layer = vmap(hidden_layer)
print(batch_hidden_layer(onp.random.randn(32, 128)).shape)
# (32, 128)

它的美妙之处在于,它意味着你或多或少地忽略了模型函数中的批处理维数,并且在你构造模型的时候,在你的头脑中少了一个张量维数。

如果您有几个输入都应该向量化,或者您想沿着轴向量化而不是沿着轴0,您可以使用in_axes参数来指定。

代码语言:javascript
复制
batch_hidden_layer = vmap(hidden_layer, in_axes=(0,))

JAX用于SPMD paralellism的实用程序,遵循非常类似的API。如果你有一台4-gpu机器和4个例子,你可以使用pmap在每个设备上运行一个例子。

代码语言:javascript
复制
# first dimension must align with number of XLA-enabled devices
spmd_hidden_layer = pmap(hidden_layer)

和往常一样,你可以随心所欲地编写函数:

代码语言:javascript
复制
# hypothetical setup for high-throughput inference
outputs = pmap(vmap(hidden_layer))(onp.random.randn(4, 32, 128))
print(outputs.shape)
# (4, 32, 128)

为什么是JAX?

JAX不是因为它都比现有的机器学习框架更加干净,或者因为它是比Tensorflow PyTorch更好地设计的东西,而是因为它能让我们更容易尝试更多的想法以及探索更广泛的空间。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-02-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python与机器学习之路 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档