在现代深度学习框架诞生之前,大概科学计算和机器学习领域一直是numpy的天下,或者至少在数学运算和自动微分方面大家都离不开它。虽然够底层够强大同时也相对够简单,不过也不是没有缺点,比如它就不支持硬件加速。虽然能训练深度学习模型,不过先天的不足和Python本身的速度劣势很容易让人敬而远之。
于是乎,后来深度学习框架来了,比如Tensorflow或者Pytorch。在ML相关领域方面,能力表现上,numpy能做的事情深度学习框架都能做,numpy不能做的事情深度学习框架也能做,比如支持GPU,在硬件加速和反向传播方面的优势恰如其分。
似乎深度学习框架已经填补了numpy的问题,要说“取代”可能也是深度学习框架,那还有JAX什么事。其实,JAX更像一个简化版的AI框架,按照官方在代码托管平台说明页面的介绍,它是Python和numpy的融合,两者的可组合转换,虽然它并不是一个神经网络框架。它更侧重于具有自动微分能力、支持GPU/TPU硬件加速的numpy库实现。在CUDA GPU加速方面,JAX表现比较惊艳。
JAX库,是Autograd和XLA的聚合,为高性能机器学习的研究。机智客看说明文档里介绍,通过更新版本的Autograd,JAX可以自动区分本机Python和NumPy函数。它可以通过循环、分支、递归和闭包进行区分,并且可以得到导数的导数。它支持通过梯度的反向模式微分(又称反向传播)和正向模式微分,两者可以任意组合成任意顺序。
因此,JAX不仅没有大型深度学习框架那么庞大繁琐,而且在自动微分和硬件加速上远超numpy,这才是其定位,取长补短,集百家所长,融会贯通。和numpy一样,支持原生Python代码,而且还支持原生numpy。
正如它支持原生Python脚本,所以非常方便我们下载学习。所以我们在安装它的时候,可以像安装其他第三方库一样,直接用pip安装就行。执行pip install jax jaxlib安装命令即可。而在引用时候则可以用import jax或import jax.numpy as jnp或from jax import random,grad, jit, vmap这样的方式。
有人说JAX可能是Pytorch的竞争对手,是Google对抗Pytorch的还击,是针对TensorFlow一些失败问题和总结的反思之作。所以弃其糟粕取其精华,集大成推出了这么一个更简便更有优势的框架。机智客看网上科技号也有人总结AI框架的时候表示,即将到来的2022年,我们深度学习的小伙伴们,有必要学习这么一款框架了。看来势头正劲,各位有兴趣的小伙伴,不妨关注一下。
领取专属 10元无门槛券
私享最新 技术干货