在机器学习框架方面,JAX是一个新生事物——尽管Tensorflow的竞争对手从技术上讲已经在2018年后已经很完备,但直到最近JAX才开始在更广泛的机器学习研究社区中获得吸引力。
JAX到底是什么?根据JAX官方介绍:
JAX是NumPy在CPU、GPU和TPU上的版本,具有高性能机器学习研究的强大自动微分(automatic differentiation)能力。
接下来,我们会具体认识JAX。
基础介绍
就像上面说的,JAX是加速器支持的numpy以及大部分scipy功能,带有一些通用机器学习操作的便利函数。
我们举个例子
import jax
import jax.numpy as np
def gpu_backed_hidden_layer(x):
return jax.nn.relu(np.dot(W, x) + b)
您可以得到numpy精心设计的API,它从2006年就开始使用了,具有Tensorflow和PyTorch等现代ML工具的性能特征。
JAX还包括通过jax.scipy来支持相当大一部分scipy项目:
from jax.scipy.linalg import svd
singular_vectors, singular_values = svd(x)
尽管有加速器支持的numpy + scipy版本已经非常有用,但JAX还有一些其他的妙招。首先让我们看看JAX对自动微分的广泛支持。
自动微分·Autograd
Autograd是一个用于在numpy和原生python代码上高效计算梯度的库。Autograd恰好也是JAX的前身。尽管最初的autograd存储库不再被积极开发,但是在autograd上工作的大部分核心团队已经开始全职从事JAX项目。
就像autograd, JAX允许对一个python函数的输出求导,只需调用grad:
from jax import grad
def hidden_layer(x):
return jax.nn.relu(np.dot(W, x) + b)
grad_hidden_layer = grad(hidden_layer)
您还可以通过本机的python控制结构进行区分——而不需要使用tf.cond:
def absolute_value(x)
if x >= 0:
return x
else:
return -x
grad_absolute_value = grad(absolute_value)
JAX还支持获取高阶导数——grad函数可以任意连接:
from jax.nn import tanh
# grads all the way down
print(grad(grad(grad(tanh)))(1.0))
默认情况下,grad 为您提供了逆向模式梯度——这是计算梯度最常用的模式,它依赖于缓存激活来提高向后传递的效率。反模式差分是计算参数更新最有效的方法。但是,特别是在实现依赖于高阶派生的优化方法时,它并不总是最佳选择。JAX通过jacfwd和jacrev为逆向模式自动差分和正向模式自动差分提供了一流的支持:
from jax import jacfwd, jacrev
hessian_fn = jacfwd(jacrev(fn))
除了grad、jacfwd和jacrev之外,JAX还提供了一些实用程序,用于计算函数的线性逼近、定义自定义梯度操作,以及作为其自动微分支持的一部分。
加速线性代数·XLA
XLA (Accelerated Linear Algebra)是一个特定域的线性代数代码编译器,它是JAX将python和numpy表达式转换成加速器支持的操作的基础。
除了允许JAX将python + numpy代码转换为可以在加速器上运行的操作之外(就像我们在第一个示例中看到的那样),XLA支持还允许JAX将多个操作融合到一个内核中。它在计算图中寻找节点簇,这些节点簇可以被重写以减少计算或中间变量的存储。Tensorflow关于XLA的文档使用以下示例来解释问题可以从XLA编译中受益的实例类型。
def unoptimized_fn(x, y, z):
return np.sum(x + y * z)
在没有XLA的情况下运行,这将作为3个独立的内核运行——一个乘法、一个加法和一个加法减法。使用XLA运行时,这变成了一个负责所有这三个方面的内核,不需要存储中间变量,从而节省了时间和内存。
向量化和并行性
虽然Autograd和XLA构成了JAX库的核心,但是还有两个JAX函数脱颖而出。你可以使用jax.vmap和jax.pmap用于向量化和基于spmd(单程序多数据)并行的pmap。
为了说明vmap的优点,我们将返回到我们的简单稠密层的示例,它操作一个由向量x表示的示例。
# convention to distinguish between
# jax.numpy and numpy
import numpy as onp
def hidden_layer(x):
return jax.nn.relu(np.dot(W, x + b)
print(hidden_layer(np.random.randn(128)).shape)
# (128,)
我们已经编写了隐含层来获取单个向量输入,但实际上我们几乎总是批量处理输入以利用向量化计算。使用JAX,您可以使用任何接受单个输入的函数,并允许它使用JAX .vmap接受一批输入:
batch_hidden_layer = vmap(hidden_layer)
print(batch_hidden_layer(onp.random.randn(32, 128)).shape)
# (32, 128)
它的美妙之处在于,它意味着你或多或少地忽略了模型函数中的批处理维数,并且在你构造模型的时候,在你的头脑中少了一个张量维数。
如果您有几个输入都应该向量化,或者您想沿着轴向量化而不是沿着轴0,您可以使用in_axes参数来指定。
batch_hidden_layer = vmap(hidden_layer, in_axes=(0,))
JAX用于SPMD paralellism的实用程序,遵循非常类似的API。如果你有一台4-gpu机器和4个例子,你可以使用pmap在每个设备上运行一个例子。
# first dimension must align with number of XLA-enabled devices
spmd_hidden_layer = pmap(hidden_layer)
和往常一样,你可以随心所欲地编写函数:
# hypothetical setup for high-throughput inference
outputs = pmap(vmap(hidden_layer))(onp.random.randn(4, 32, 128))
print(outputs.shape)
# (4, 32, 128)
为什么是JAX?
JAX不是因为它都比现有的机器学习框架更加干净,或者因为它是比Tensorflow PyTorch更好地设计的东西,而是因为它能让我们更容易尝试更多的想法以及探索更广泛的空间。
本文分享自 Python与机器学习之路 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!