原文:
jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html
jax
和 jaxlib
是独立的包?我们将 JAX 发布为两个独立的 Python 轮子,即纯 Python 轮子 jax
和主要由 C++ 组成的轮子 jaxlib
,后者包含库,例如:
我们发布 jax
作为两个独立的 Python 轮子,即纯 Python 轮子 jax
和主要由 C++ 组成的轮子 jaxlib
,后者包含如下库:
此外,构建 jaxlib
不是廉价的,但我们希望能够在没有大量 CPU 的环境中迭代并运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建都简单地使用预构建的 jaxlib
,而不是在每个 PR 上重新构建 JAX 的 C++ 组件。
如我们将看到的,将 jax
和 jaxlib
分开发布也有一定成本,因为需要确保 jaxlib
的变更保持向后兼容的 API。然而,我们认为总体上,使得 Python 的变更变得简单是可取的,即使这会稍微增加 C++ 变更的难度。
jax
和 jaxlib
的版本如何确定?概要:jax
和 jaxlib
在 JAX 源代码树中共享相同的版本号,但作为单独的 Python 包发布。安装时,jax
包版本必须大于或等于 jaxlib
的版本,并且 jaxlib
的版本必须大于或等于 jax
指定的最小 jaxlib
版本。
jax
和 jaxlib
发布版本号均为 x.y.z
,其中 x
是主版本号,y
是次版本号,z
是可选的补丁版本号。版本号必须遵循PEP 440。版本号比较是对整数元组的词典排序比较。
每个 jax
发布版本都有一个关联的最小 jaxlib
版本 mx.my.mz
。对于 jax
版本 x.y.z
,其最小 jaxlib
版本必须不大于 x.y.z
。
对于 jax
版本 x.y.z
和 jaxlib
版本 lx.ly.lz
兼容性要求如下:
jaxlib
版本(lx.ly.lz
)必须大于或等于最小的 jaxlib
版本(mx.my.mz
)。
jax
版本(x.y.z
)必须大于或等于 jaxlib
版本(lx.ly.lz
)。
这些约束意味着发布需遵循以下规则:
jax
而不更新 jaxlib
。
jaxlib
,必须同时发布一个 jax
版本。
当前 jax
在导入时检查这些版本约束,而不是作为 Python 包版本约束来表达。 jax
在运行时检查 jaxlib
版本,而不是使用 pip
包版本约束,因为我们为各种硬件和软件版本(如 GPU、TPU 等)提供单独的 jaxlib
轮子。由于我们不知道哪种选择对任何给定用户来说是正确的,我们不希望 pip
自动为我们安装 jaxlib
包。
将来,我们希望将 jaxlib
的硬件特定部分分离成单独的插件,届时最低版本可以表达为 Python 包依赖性。目前,我们确实提供特定平台的额外要求,以安装兼容的 jaxlib
版本,例如 jax[cuda]
。
jaxlib
的 API 进行更改?jax
可能随时停止与旧版本的 jaxlib
兼容,只要将最低 jaxlib
版本升级到兼容版本即可。但请注意,即使是对于尚未发布的 jax
版本,最低 jaxlib
版本也必须是一个已发布的版本!这允许我们在持续集成构建中使用已发布的 jaxlib
轮子,并允许 Python 开发者在不需要构建 jaxlib
的情况下在 HEAD 上工作。
例如,要在 jax
Python 代码中移除旧的向后兼容路径,只需提高最低 jaxlib
版本然后删除兼容路径即可。
jaxlib
可能会停止与低于其自身发布版本号的旧 jax
发行版的兼容性。 jax
强制执行的版本约束将禁止使用不兼容的 jaxlib
。
例如,要使 jaxlib
放弃一个旧的 jax
版本使用的 Python 绑定 API,必须增加 jaxlib
的次要或主要版本号。
jaxlib
进行更改。
通常,jaxlib
可以自由更改其 API,只要遵循 jax
必须与至少两个 jaxlib
版本兼容的规则。这意味着 jax
必须始终与至少两个 jaxlib
版本兼容,即最后一个发布版本和最新版本(实际上是下一个发布版本)。如果保持兼容性,这将更容易实现,尽管可以通过 jax
的版本测试进行不兼容的更改;请参见下文。
例如,通常可以安全地向 jaxlib
添加新功能,但是如果当前的 jax
仍在使用它,删除现有功能或更改其签名则是不安全的。对 jax
的更改必须在所有大于最低版本的 jaxlib
发行版上运行或逐渐退化。
请注意,此处的兼容性规则仅适用于发布的jax
和jaxlib
版本。它们不适用于未发布的版本;也就是说,如果从未发布或没有发布的jax
版本使用该 API,则可以引入并删除jaxlib
中的 API。
jaxlib
的源代码布局是怎样的?jaxlib
被分为两个主要的存储库,即jaxlib/
主 JAX 存储库的子目录和XLA 源代码树,位于 XLA 存储库内部。XLA 内部的 JAX 特定部分主要位于xla/python
子目录。
JAX 的 C++ 组件,如 Python 绑定和运行时组件,位于 XLA 树内部的原因部分是历史原因,部分是技术原因。
历史原因是最初xla/python
绑定被构想为通用 Python 绑定,可能与其他框架共享。实际上,这种情况越来越少见,xla/python
包含了许多特定于 JAX 的部分,并且可能会包含更多。因此,最好将xla/python
简单地视为 JAX 的一部分。
技术原因在于 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,可以将它们的 C++ 实现与 XLA 的 C++ API 进行原子更新。在 Python API 层面上,维护 Python API 的向后和向前兼容性要比维护 C++ API 更容易,因此xla/python
公开了 Python API 并负责在 Python 层面上维护向后兼容性。
jaxlib
使用 Bazel 从jax
存储库构建。来自 XLA 存储库的jaxlib
部分被合并到构建中 作为 Bazel 子模块。要在构建过程中更新使用的 XLA 版本,必须在 Bazel 的WORKSPACE
中更新固定的版本。这是根据需要手动完成的,但可以根据构建的需求进行覆盖。
jax
和jaxlib
发布之间如何跨界修改?jaxlib
版本是一个粗糙的工具:它只能让我们推断发布版本。
然而,由于jax
和jaxlib
代码分布在无法在单个更改中原子更新的存储库中,我们需要在比我们的发布周期更精细的粒度上管理兼容性。为了管理细粒度兼容性,我们有额外的版本控制,这与jaxlib
发布版本号独立。
我们在XLA 存储库中的xla_client.py
中维护了一个额外的版本号(_version
)。其理念是,这个版本号在xla/python
中与 JAX 的 C++部分一起定义,也可以作为jax._src.lib.xla_extension_version
被 JAX Python 访问,并且在每次对 XLA/Python 代码进行更改且这些更改对jax
的向后兼容性有影响时,都必须递增。JAX Python 代码可以利用这个版本号来维护向后兼容性,例如:
from jax._src.lib import xla_extension_version
# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
请注意,这个版本号是为了帮助管理开发中未发布代码的兼容性而存在的,与已发布版本号的约束额外。发布版本也必须遵循上述兼容性规则。
原文:
jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html
sharadmv@ May 9 2022
当我们编写 JAX 代码时,通常可以假装我们在编写单线程、即时执行的 Python 代码,尽管在底层,JAX 及其运行时可能在后台异步执行。只要我们编写纯净(无副作用)的代码,这些性能优化通常对我们是不可见的,不会干扰我们的单线程心理模型。异步执行非常棒 — 我们可以获得高效、并行的代码,而无需考虑任何问题!
然而,在存在副作用的情况下,这种幻象开始破裂,我们心理模型的裂缝开始显现。具体来说,当我们考虑副作用发生的顺序时,这些差异就会显现出来。
在这篇设计说明中,我们探讨了 JAX 执行模型与副作用顺序之间的交互。我们还提供了一种强制执行“单线程”副作用顺序的方法。
当我们编写以下 Python 代码时
def f():
print("hello")
return 2
def g():
print("world")
return 3
f()
g()
我们期望 "hello"
在 "world"
之前被打印出来。这似乎是显而易见的,但考虑以下 JAX 代码:
@partial(jax.jit, device=<device 0>)
def f():
return 2
@partial(jax.jit, device=<device 1>)
def g():
return 3
f()
g()
在许多情况下,JAX 将并行执行 f
和 g
,将计算分发到不同的线程 —— g
可能会在 f
之前执行。并行执行是一种很好的性能优化,特别是在设备间的复制成本昂贵时(详见异步调度说明了解更多详情)。然而,在实践中,我们通常不需要考虑异步调度,因为我们编写的是纯函数,只关心函数的输入和输出 —— 我们自然会在未来的值上阻塞。
但是,现在想象一下,我们有一个 jax.print
函数,可以在 JIT 编译的 JAX 函数内部工作(例如 host_callback.id_print
就是一个例子)。让我们回到之前的例子,但现在加入了打印输出。
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return 2
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return 3
f()
g()
由于异步调度的存在,我们实际上可以看到 "world"
在 "hello"
之前被打印出来。打印输出副作用的重新排序破坏了单线程执行模型的幻象。
另一个副作用可以“揭示”无序执行的示例是当我们编译 JAX 程序时。考虑以下 JAX 代码:
@jax.jit
def f(x):
jax.print("hello")
jax.print("world")
return x
尽管在 Python 中,我们先写了 "hello"
的打印,然后是 "world"
的打印,但是像 XLA 这样的编译器可以自由地重新排序它们,因为这两个打印之间没有显式的数据依赖关系。
我们希望支持“有序”效果。所谓有序,意味着效果发生的顺序与我们在执行单线程 Python 程序时的顺序相同。这是我们的主要愿望。在存在显式并行性(如pmap
或用户线程)的情况下,我们不需要保持这种行为,但至少如果用户没有显式请求并行性,我们希望保持单线程顺序。
在深入讨论之前,让我们先退后一步,问问自己,如果我们为了性能而重新排序效果,这样做是否可以接受?反之,我们是否需要完全强制效果的顺序?在某些情况下,我们不需要排序。也许某些副作用不应该影响 JAX 程序的性能。然而,对于其他副作用,我们可能希望强制单线程程序顺序,以防止用户得到反直觉的行为。考虑一个日志效果。
@jax.jit
def f(x, y):
log_value(x)
log_value(y)
f(1, 2)
如果log
正在改变全局列表,我们可能期望在添加y
之前添加x
。为了更严格的效果,我们可能希望能够对效果进行排序。
我们用来强制计算顺序的主要工具是数据依赖性。简单来说,如果函数g
的输入是函数f
的输出,那么必须先执行f
,再执行g
。
然而,我们可能会有像打印这样的副作用,这些副作用根本没有任何输入,因此我们无法简单地对它们进行排序。因此,我们使用令牌作为向计算中注入人为数据依赖性的手段。
什么是令牌?令牌只是可以在计算中穿插的虚拟值。通过在多个计算中穿插相同的令牌,我们强制它们按照特定顺序进行。让我们看看前面的打印示例,加入令牌后会是什么样子:
@jax.jit
def f(token, x):
token = jax.print(token, "hello")
token = jax.print(token, "world")
return token, x
如果我们重写jax.print
以接受并返回一个令牌,我们现在已经按顺序序列化了两个打印,因为第二个打印的输入依赖于第一个打印的输出。实际上,token
的实际值可以是任何东西,但我们会看到,这些令牌对用户来说是不可见的。
现在我们将开始讨论实现细节。实际上,我们需要两种不同类型的令牌来序列化效果:一种用于上述重新排序的每种源,我们需要运行时令牌来序列化异步调度的有副作用的计算,我们还需要编译器令牌来序列化计算内部的效果。
实际上,我们的计算将重写为以下形式:
@jax.jit
def f(runtime_token, x):
compiler_token = new_compiler_token()
compiler_token = jax.print(compiler_token, "hello")
compiler_token = jax.print(compiler_token, "world")
return runtime_token, x
注意运行时令牌仅在 JIT 边界使用,而编译器令牌仅在编译后的代码中使用。编译器令牌是在“降级”过程中创建的(我们将 Python 代码转换为类似 HLO 或 StableHLO 的低级表示),但运行时令牌需要在 Python 中进行管理,因为它们在 JIT 化的函数中穿插输入和输出。
此外,请注意运行时令牌与编译器令牌之间是“断开”的,这意味着它们之间没有数据依赖关系。这可能是危险的,因为我们会失去两个调度函数调用体之间的数据依赖性。然而,如果我们假设“严格执行”——即一个调度函数只有在其所有输入准备就绪且所有输出同时准备就绪时才会开始执行——我们可以安全地创建一个新的编译器令牌,并返回一个不依赖于输出的运行时令牌。
为了代表用户管理运行时令牌,我们需要插入到 JAX 的调度机制中。每当我们调用 JIT 编译的函数时,我们最终会得到一个看起来像这样的函数:
def _execute(compiled_computation, *args):
outputs = compiled_computation.execute(*args)
return outputs
此时我们需要"注入"运行时令牌到计算中,并从计算的输出中"提取"它们:
def _execute(compiled_computation, *args):
runtime_token = get_runtime_token() # Grab global token
runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_runtime_token(runtime_token) # Update global token
return outputs
runtime_token
究竟是什么?嗯,我们需要能够将其传递给compiled_computation
,这意味着它需要是某种数组(目前来说,由于在编译的 JAX 代码内外没有共享的令牌表示,我们可以使用一个(0,)
形状的数组来最小化开销)。
我们还需要考虑多设备使用情况,例如第一个示例中,我们首先在设备 0 上调用 JIT 编译的函数,然后在设备 1 上调用另一个函数。在这种情况下,我们还需要将第一个计算返回的运行时令牌(位于设备 0 上)复制到设备 1,以便将其传递给第二个计算。如果两个后续计算共享相同的设备,则此复制是不必要的。
当我们将 Python 代码降级为 HLO 或 StableHLO 时,我们需要在计算开始时创建一个令牌,并确保在需要对顺序进行排序的副作用计算时可用。副作用计算将该令牌作为输入,并将其作为输出返回。
实现此令牌线程涉及升级 JAX 降级机制以自动进行此类记账。主要挑战涉及处理像调用原语和控制流原语这样的高阶原语。在本设计说明中,我们不会详细讨论如何处理这些挑战。
为运行时和编译器令牌增加支持以进行副作用计算序列化是很重要的,但令牌还有另一个微妙的用例,即在副作用计算上阻塞。即使我们不希望副作用计算是有序的,我们可能仍然希望等待其完成。目前我们有jax.block_until_ready
,它会等待直到未来的值准备就绪。然而,对于副作用计算,我们可能有一些没有返回值但仍在执行副作用的函数。以这里的简单示例为例:
@jax.jit
def f():
jax.print("hello world")
return
f() # Executed asynchronously
这个编译后的计算不接受任何显式输入,也没有显式输出。如果它是一个有序的打印效果,我们可以阻塞返回的运行时令牌,但是当这是一个无序计算时,我们不执行任何令牌线程。当我们没有输出值来调用block_until_ready
时,我们如何等待f()
执行结束呢?嗯,我们可以应用相同的令牌策略,除了我们只返回运行时令牌而不将它们作为输入。这将给我们一个可以阻塞的值,该值仅在f()
执行完成后才会准备好。我们将这些令牌称为输出令牌。我们最终得到了如下所示的函数:
@jax.jit
def f():
jax.print("hello world")
return new_runtime_token()
f() # Executed asynchronously
在幕后,我们将以与管理运行时令牌相同的方式来管理输出令牌,但提供一种方法让用户在当前一组输出令牌上阻塞。与运行时令牌不同,输出令牌需要是特定于设备的。考虑单设备使用情况:
@jax.jit
def f():
jax.print("hello")
@jax.jit
def g():
jax.print("world")
f()
g()
由于f()
和g()
在同一设备上执行,阻塞g()
的输出令牌有效地阻塞了f()
,因为(目前为止!),JAX 运行时不会交错执行在同一设备上执行的计算。当然,如果情况改变,我们将不得不重新审视整个设计。
然而,考虑两个设备使用情况:
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
f()
g()
这里我们不想显式地序列f()
和g()
,但是希望等待它们都完成。我们需要一个f()
的输出令牌和一个g()
的输出令牌,并且我们将阻塞在这两个令牌上:
@partial(jax.jit, device=<device 0>)
def f():
jax.print("hello")
return new_runtime_token()
@partial(jax.jit, device=<device 1>)
def g():
jax.print("world")
return new_runtime_token()
t0 = f()
t1 = g()
block_until_ready((t0, t1))
因此,我们需要每个设备的输出令牌,这样我们就可以避免在不同设备上对计算进行排序,同时可以阻塞具有副作用的计算。我们最终得到了以下(大致)对 JAX 调度机制的更改:
def _execute(compiled_computation, *args):
output_token, *outputs = compiled_computation.execute(runtime_token, *args)
update_output_token(output_token, compiled_computation.device)
return outputs
我们还需要暴露一个函数来阻塞输出令牌:
def effects_barrier():
output_token.block_until_ready()
注意,阻塞输出令牌可能不太常见,因为大多数 JAX 计算将返回一个值来阻塞。然而,输出令牌对于测试和分析非常有用,并且支持它们是很好的,这样我们就有了一个一致且有条理的效果系统。
原文:
jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html
自 #11830 开始,我们正在启用新的 jax.checkpoint()
实现,也称为 jax.remat()
(两个名称是互为别名)。对于大多数代码,不会有任何更改。 但在边缘情况下可能会有一些可观察的差异;参见升级后可能出现的问题有哪些?
如果您对此更改有问题,截至 jax==0.3.16
版本,可以通过将 jax_new_checkpoint
配置选项设置为 False
关闭新实现,以下是任何一种方法:
JAX_NEW_CHECKPOINT=0
;
jax.config.update('jax_new_checkpoint', False)
;
absl
解析标志,请传递 --jax_new_checkpoint=False
选项。
如果您需要恢复到旧版本,请在 GitHub 问题上联系我们,以便我们为您使新版本正常工作。
从 jax==0.3.17
版本开始,不再提供 jax_new_checkpoint
配置选项。如果您遇到问题,请在问题跟踪器上联系我们以帮助解决!
截至撰写时,JAX 有两个并行实现的 jax.checkpoint
。新版本已经在几个月内(例如 Pax 和 Flaxformer/T5X)按选择使用。但默认情况下尚未启用。
我们希望将新实现设置为默认启用,并删除旧实现。使用新实现并删除旧实现将为用户带来多种好处。
新实现的主要优势是与 policy
参数对应的新功能。其目的是在自动微分的前向传递过程中,精确控制哪些中间结果保存(而不是重新计算)。通过控制内存使用与重新计算之间的权衡,用户可以获得显著的性能优势,尤其是在大型模型和我们的 LLM MLPerf 提交中!
此功能的完整文档尚未发布,但以下是一个快速示例:
from functools import partial
import jax
def apply_layer(W, x):
return jnp.sin(jnp.dot(W, x))
@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
for W in params[:-1]:
x = apply_layer(W, x)
return jnp.dot(params[-1], x)
通过在这里应用jax.checkpoint
与policy=jax.checkpoint_policies.checkpoint_dots
,我们确保只有矩阵乘法的结果在正向传播期间被保存。从cos
应用中得到的雅可比系数值以及计算它们所需的sin
应用的值在正向传播期间不会被保存,而是在反向传播期间重新计算。(像这样的策略在 TPU 上非常有效,其中逐元素计算实际上是免费的,但来自矩阵单元的结果值是值得保存的。)
旧的jax.checkpoint
实现实际上不能在没有对装饰函数参数的数据依赖时重新生成计算。考虑这个玩具示例:
@jax.checkpoint
def f(x):
a = some_function(jnp.arange(10_000_000)) # `a` does not depend on `x`
return a * x
旧的jax.checkpoint
实现被迫保存a
的值,这可能需要大量内存。新的jax.checkpoint
实现可以重新生成而不是保存a
的值。
在某些情况下,新的jax.checkpoint
在 Python 开销方面显著减少。简单的开销基准测试变快了 10 倍。这些开销只在急切的逐步执行中出现,因此在通常情况下,在jax.jit
或类似方法下使用jax.checkpoint
时,这些加速并不相关。但仍然很不错!
这个改变也为未来用户带来了很大的好处,比如自定义批处理规则(vmap
的类比custom_vjp
)以及custom_vjp
的向前可微升级。它还显著减少了 JAX 代码库中某些部分的复杂性,这对于一般的可维护性和错误修复都是有好处的。
因为新的实现可以重新生成更多的计算,包括那些可能很大的常数,所以一些代码可能会看到小的数值变化。任何数值变化的幅度应该在我们预期的编译器优化变化范围内,例如浮点操作的重新排序。但某些过于严格的测试容差可能需要略微放宽。
concrete=True
被移除了。旧的jax.checkpoint
实现有一个布尔选项concrete
,允许跟踪具体的 Python 值(而不是延迟所有计算,并仅在抽象值上进行跟踪)。该选项很少被使用,而在使用它的情况下,有更简单的替代方案。因此,在新的jax.checkpoint
中我们移除了这个选项。
例如,在 Google 代码中,使用concrete=True
的压倒性常见用法是支持传递像is_training
这样的参数:
@partial(jax.checkpoint, concrete=True) # OLD jax.checkpoint API
def foo(x, is_training):
if is_training:
return g(x)
else:
return h(x)
使用新的jax.checkpoint
实现,我们可以使用static\_argnums
选项完成相同的功能:
@partial(jax.checkpoint, static_argnums=(1,)) # NEW jax.checkpoint API
def foo(x, is_training):
if is_training:
...
如果需要在静态参数上执行jax.numpy
操作,并且它们的数值结果在 Python 追踪期间计算而不是延迟计算,我们可以使用jax.ensure_compile_time_eval()
与static_argnums
。但似乎你不太可能需要这样做!
原文:
jax.readthedocs.io/en/latest/jep/12049-type-annotations.html
Python 3.0 引入了可选的函数注释(PEP 3107),这些注释后来在 Python 3.5 发布时被规范为静态类型检查的一部分(PEP 484)。在很大程度上,类型注释和静态类型检查已经成为许多 Python 开发工作流程的一个重要组成部分,为此我们在 JAX API 的许多地方添加了注释。目前在 JAX 中的类型注释有些零散,增加更多注释的努力受到了更基本的设计问题的阻碍。本文试图总结这些问题,并为 JAX 中类型注释的目标和非目标制定路线图。
为什么我们需要这样的路线图?更好/更全面的类型注释是用户(包括内部和外部用户)经常提出的请求。此外,我们经常收到来自外部用户的拉取请求(例如,PR#9917,PR#10322),试图改进 JAX 的类型注释:对于 JAX 团队成员来说,审查此类贡献是否有益并不总是清楚,特别是当它们引入复杂的协议来解决 JAX 对 Python 的完全注释所固有的挑战时。本文详细介绍了 JAX 对包中类型注释的目标和建议。
有许多原因使得 Python 项目希望对其代码库进行注释;我们将在本文档中总结为 Level 1、Level 2 和 Level 3。
最初在 PEP 3107 中引入时,类型注释部分是由于可以将其用作函数参数类型和返回类型的简洁内联文档。JAX 长期以来一直以这种方式使用注释;一个例子是常见的创建类型名称并将其别名为 Any
的模式。可以在 lax/slicing.py
中找到一个例子[source]:
Array = Any
Shape = core.Shape
def slice(operand: Array, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
...
出于静态类型检查的目的,这种使用 Array = Any
用于数组类型注释对参数值没有任何限制(Any
等同于没有注释),但它确实作为开发人员在代码中有用的形式化文档。
为了生成文档,别名的名称会丢失(jax.lax.slice
的HTML 文档将操作数报告为类型Any
),因此文档的好处并未超出源代码(尽管我们可以启用一些sphinx-autodoc
选项来改进此功能:参见autodoc_type_aliases)。
这种类型注解的一个好处是,用Any
注释一个值永远不会错,因此它将以文档的形式为开发者和用户提供实际的好处,而无需满足任何特定静态类型检查器更严格的需求的复杂性。
许多现代 IDE 利用类型注解作为智能代码补全系统的输入。其中一个例子是 VSCode 的Pylance扩展,它使用微软的pyright静态类型检查器作为 VSCode IntelliSense完成的信息源。
这种类型检查的使用需要比上述简单的别名更深入的了解;例如,知道slice
函数返回一个名为Array
的Any
别名并不会为代码完成引擎增添任何有用的信息。然而,如果我们用DeviceArray
标注函数的返回类型,自动完成将了解如何填充结果的命名空间,因此在开发过程中能够提供更相关的自动完成建议。
JAX 已经在几个地方开始添加这种类型注解的级别;一个例子是jax.random
包中的jnp.ndarray
返回类型 [来源]:
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
在这种情况下,jnp.ndarray
是一个抽象基类,用于预先声明 JAX 数组的属性和方法(见源代码),因此 VSCode 中的 Pylance 可以为该函数的结果提供完整的自动完成集合。这里是显示结果的屏幕截图:
在自动完成字段中列出了抽象ndarray
类声明的所有方法和属性。我们将在下面进一步讨论为什么需要创建这个抽象类,而不是直接用DeviceArray
进行注释。
当今,静态类型检查通常是人们在考虑 Python 代码中类型注解目的时首先考虑的事情。虽然 Python 不会对类型进行任何运行时检查,但存在几种成熟的静态类型检查工具,可以作为 CI 测试套件的一部分进行此类检查。对于 JAX 来说,最重要的工具如下:
完全的静态类型检查是所有类型注解应用中最严格的,因为它会在您的类型注解不精确时立即出现错误。一方面,这很好,因为您的静态类型分析可能会捕获到错误的类型注解(例如,DeviceArray
方法在 jnp.ndarray
抽象类中缺失的情况)。
另一方面,这种严格性可能会使得依赖鸭子类型而不是严格类型安全 API 的软件包在类型检查过程中变得非常脆弱。你会在 JAX 代码库中当前发现大量像 #type: ignore
(对于 mypy)或 #pytype: disable
(对于 pytype)这样的代码注释。这些通常代表了出现类型问题的情况;它们可能是 JAX 类型注解中的不准确之处,或者是静态类型检查器在正确跟踪代码控制流时的不准确之处。偶尔,它们可能是由于 pytype 或 mypy 行为中真正而微妙的错误造成的。在罕见的情况下,它们可能是由于 JAX 使用了在 Python 的静态类型注解语法中难以甚至不可能表达的 Python 模式。
JAX 目前的类型注解是不同风格的混合,并针对上述所有三个类型注解层级。部分原因是因为 JAX 的源代码对 Python 的类型注解系统提出了许多独特的挑战。我们将在这里概述它们。
JAX 目前面临的一个挑战是,软件包开发必须满足两种不同静态类型检查系统的约束,即 pytype
(用于内部 CI 和 Google 内部项目)和 mypy
(用于外部 CI 和外部依赖)。尽管这两种类型检查器在行为上有广泛的重叠,但每种都展示了其独特的特例情况,这可以从 JAX 代码库中遍布的众多 #type: ignore
和 #pytype: disable
语句中看出。
这给开发带来了摩擦:内部贡献者可能会迭代直到测试通过,然后发现在导出时他们通过 pytype 验证的代码在 mypy 中不符合要求。对于外部贡献者来说,情况通常相反:一个最近的例子是#9596,在未能通过 Google 内部的 pytype 检查后不得不回滚。每次我们将类型注释从第 1 级(到处都是Any
)移动到第 2 或第 3 级(更严格的注释),都会增加这种令人沮丧的开发体验的可能性。
注释 JAX 代码的一个特殊挑战是其广泛使用的鸭子类型。一般情况下标记为Array
的函数的输入可能是许多不同类型之一:JAX 的DeviceArray
、NumPy 的np.ndarray
、NumPy 标量、Python 标量、Python 序列、带有__array__
属性的对象、带有__jax_array__
属性的对象或任何jax.Tracer
的变体。因此,简单的注释如def func(x: DeviceArray)
将不足以满足要求,并且会导致许多有效用法的误报。这意味着对于 JAX 函数的类型注释不会简短或琐碎,但我们必须有效地开发一组类似于numpy.typing
包中的 JAX 特定类型扩展。
JAX 的 Python API 严重依赖于函数转换(jit()
、vmap()
、grad()
等),这种类型的 API 对静态类型分析提出了特殊挑战。装饰器的灵活注释一直是 mypy 包的长期问题,最近才通过引入ParamSpec
(详见PEP 612,并在 Python 3.10 中添加)解决。因为 JAX 遵循NEP 29,在 2024 年中期之后才能依赖 Python 3.10 的功能。与此同时,Protocols 可作为部分解决方案使用(JAX 在#9950中为 jit 和其他方法添加了此功能),而 ParamSpec 可以通过typing_extensions
包使用(原型在#9999中),尽管这目前揭示了 mypy 中的基本错误(见python/mypy#12593)。总之:目前尚不清楚 JAX 函数转换的 API 是否能在当前 Python 类型注释工具的限制下得到适当注释。
另一个挑战是 Python 所有面向数组的 API 共同面临的问题,多年来一直是 JAX 讨论的一部分(见#943)。类型注解涉及对象的 Python 类或类型,而在基于数组的语言中,类的属性通常更为重要。在 NumPy、JAX 及类似包中,我们经常希望注释特定的数组形状和数据类型。
例如,jnp.linspace
函数的参数必须是标量值,但在 JAX 中,标量由零维数组表示。因此,为了使注释不引发误报,我们必须允许这些参数是任意数组。另一个例子是jax.random.choice
的第二个参数,在shape=()
时必须具有dtype=int
。Python 计划通过可变类型泛型(参见PEP 646,计划用于 Python 3.11)来实现类型注解的这种粒度,但像ParamSpec
一样,支持这一功能还需要一段时间来稳定。
在此期间,有一些第三方项目可能会有所帮助,特别是google/jaxtyping,但这些使用非标准注解,可能不适用于对核心 JAX 库本身进行注释。总的来说,数组类型粒度挑战的问题不如其他挑战那么严重,因为主要影响是数组类似的注释将不如其本应该的那样具体。
JAX 用户界面 API 的大部分内容都继承自jax.numpy
子模块中的 NumPy。NumPy 的 API 在 Python 语言静态类型检查成为一部分之前就已经开发多年,遵循 Python 的历史建议使用一种鸭子类型/EAFP编码风格,其中不鼓励在运行时进行严格的类型检查。作为具体例子,考虑numpy.tile()
函数,它的定义如下:
def tile(A, reps):
try:
tup = tuple(reps)
except TypeError:
tup = (reps,)
d = len(tup)
...
这里的意图是reps
应该包含一个int
或者一个int
值的序列,但实现允许tup
是任何可迭代的对象。在对这种鸭子类型的代码添加注释时,我们可以采取两种路线之一:
reps: Union[int, Sequence[int]]
的内容。
reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]]
,其中 ConvertibleToInt
是一个特殊的协议,涵盖了我们的函数将输入转换为整数的确切机制(即通过 __int__
、通过 __index__
、通过 __array__
等)。此外,请注意,从严格意义上讲,Iterable
在这里是不足够的,因为在 Python 中有些对象虽然通过 __getitem__
是可迭代的,但不能满足静态类型检查的 Iterable
(比如,用于 __iter__
而不是 __getitem__
的对象)。
#1 的优势,在于注释意图,是注释在传达 API 合约时对用户更有用;而对于开发者来说,灵活性则为在必要时重构留下了余地。缺点(特别是对于像 JAX 这样的渐进式类型 API 来说)是,现有用户代码很可能是运行正确的,但在类型检查器中会被标记为不正确。现有鸭子类型 API 的渐进类型化意味着当前的注释隐式是 Any
,因此将其更改为更严格的类型可能会对用户产生破坏性的改变。
总体而言,在 IDE 注释中,更好地服务于 Level 1 类型检查的是注释意图,而更好地服务于 Level 3 的是注释实现,而 Level 2 则是一种混合体(在 IDE 注释中,注释意图和实现都很重要)。
在这种(Level 1/2/3)和 JAX 特定挑战的框架下,我们可以开始制定我们在 JAX 项目中实施一致类型注释的路线图。
对于 JAX 类型注释,我们将遵循以下原则:
尽可能地,我们希望支持完整的Level 1、2 和 3类型注释。特别是这意味着我们应该对公共 API 函数的输入和输出都进行严格的类型注释。
JAX 类型注释通常应该指示 API 的意图,而不是实现,以便注释在传达 API 合约时变得有用。这意味着有时在运行时有效的输入,在静态类型检查器中可能不被识别为有效(一个例子可能是将任意迭代器传递到标注为 Shape = Sequence[int]
的形状位置)。
JAX 函数和方法的输入应尽可能宽松地进行类型标注:例如,虽然形状通常是元组,但接受形状的函数应接受任意序列。类似地,接受数据类型的函数不必要求是 np.dtype
的实例,而是任何可转换为数据类型的对象。这可能包括字符串、内置标量类型或标量对象构造函数,如 np.float64
和 jnp.float64
。为了使整个包尽可能统一,我们将添加一个 jax.typing
模块,其中包含常见的类型规范,从广义类别开始,例如:
ArrayLike
将是可以隐式转换为数组的任何内容的联合:例如,jax 数组、numpy 数组、JAX 追踪器以及 Python 或 numpy 标量。
DTypeLike
将是可以隐式转换为数据类型的任何内容的联合:例如,numpy 数据类型、numpy 数据类型对象、jax 数据类型对象、字符串和内置类型。
ShapeLike
将是可以转换为形状的任何内容的联合:例如,整数或类整数对象的序列。
注意,这些通常比 numpy.typing
中使用的等效协议要简单。例如,在 DTypeLike
的情况下,JAX 不支持结构化数据类型,因此 JAX 可以使用更简单的实现。同样地,在 ArrayLike
中,JAX 通常不支持列表或元组输入来代替数组,因此类型定义将比 NumPy 的类似物简单。
相反,函数和方法的输出应尽可能严格地进行类型标注:例如,对于返回数组的 JAX 函数,输出应该用类似 jnp.ndarray
的方式进行注释,而不是 ArrayLike
。返回数据类型的函数应始终注释为 np.dtype
,返回形状的函数应始终为 Tuple[int]
或严格类型的 NamedShape 等效物。为此,我们将在 jax.typing
中实现几个严格类型化的类似于上述宽松类型的模拟,即:
Array
或 NDArray
(见下文)实际上等效于 Union[Tracer, jnp.ndarray]
,应用于数组输出的标注。
DType
是 np.dtype
的别名,可能还具有表示 JAX 中使用的关键类型和其他泛化类型的能力。
Shape
本质上是 Tuple[int, ...]
,可能具有一些额外的灵活性以适应动态形状的情况。
NamedShape
是 Shape
的扩展,允许在 JAX 内部使用的命名形状。
我们还将探讨是否可以放弃当前的 jax.numpy.ndarray
实现,以支持将 ndarray
作为 Array
或类似物的别名。
除了在jax.typing
中收集的常见类型协议之外,我们应该偏向简单。在传递给 API 函数的参数的类型规范无法简洁指定的情况下,我们应避免构建过于复杂的联合,而是使用简单的联合,如Union[simple_type, Any]
。这是一个妥协,旨在实现 Level 1 和 Level 2 的注解目标,同时避免不必要的复杂性,暂时放弃 Level 3。
为了不给开发带来不必要的摩擦(由于内部/外部 CI 差异),我们希望在使用类型注解构造时保守一些:特别是在涉及最近引入的机制如ParamSpec
(PEP 612)和可变类型泛型(PEP 646)时,我们希望在 mypy 和其他工具支持成熟且稳定之前等待。
其中一个影响是,目前在函数被 JAX 转换(如jit
、vmap
、grad
等)装饰时,JAX 将有效地剥离所有注解。尽管这很不幸,但在撰写本文时,mypy 与ParamSpec
提供的潜在解决方案存在一长串的不兼容性(见ParamSpec
mypy bug tracker),因此我们认为目前尚不适合在 JAX 中全面采用。在未来,一旦对此类特性的支持稳定下来,我们将重新审视这个问题。
同样地,目前我们将避免添加由jaxtyping项目提供的更复杂和更精细的数组类型注解。这是我们可以在未来重新审视的决定。
Array
类型设计考虑因素如上所述,对于 JAX 中数组的类型注解,由于 JAX 广泛使用鸭子类型,即在 jax 转换中传递和返回Tracer
对象而不是实际的数组,这带来了独特的挑战。这变得越来越令人困惑,因为用于类型注解的对象通常与用于运行时实例检查的对象重叠,可能与所讨论对象的实际类型层次结构相对应也可能不相对应。对于 JAX,我们需要为两个上下文提供鸭子类型对象的解决方案:静态类型注解和运行时实例检查。
以下讨论将假设jax.Array
是运行时设备上数组的类型,尽管目前尚未实现,但一旦在#12016中完成工作,将会实现。
我们需要提供一个可以用于鸭子类型注解的对象。假设我们暂时称此对象为ArrayAnnotation
,我们需要一个解决方案,能够满足像下面这样的案例在mypy
和pytype
中的要求:
@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
assert isinstance(x, core.Tracer)
return x
这可以通过多种方法实现,例如:
ArrayAnnotation = Union[Array, Tracer]
Tracer
和Array
应被视为ArrayAnnotation
的子类。
Array
和Tracer
,使ArrayAnnotation
成为两者的真实基类。
同时,我们必须提供一个可用于鸭子类型运行时isinstance
检查的对象。假设我们暂时称之为ArrayInstance
,我们需要一个能通过以下运行时检查的解决方案:
def f(x):
return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x) # x will be an array
assert jit(f)(x) # x will be a tracer
再次强调,可以使用几种机制来实现这一点:
type(ArrayInstance).__instancecheck__
以使得对Array
和Tracer
对象返回True
;这就是jnp.ndarray
当前实现的方式(来源)。
ArrayInstance
定义为一个抽象基类,并动态注册到Array
和Tracer
。
Array
和Tracer
,使ArrayInstance
成为两者的真实基类。
我们需要做出的决定是ArrayAnnotation
和ArrayInstance
应该是相同的还是不同的对象。这里有一些先例;例如,在核心 Python 语言规范中,typing.Dict
和 typing.List
存在于注解的缘故,而内置的 dict
和 list
用于实例检查的缘故。然而,在较新的 Python 版本中,Dict
和 List
已被弃用,推荐使用dict
和list
用于注解和实例检查。
在 NumPy 的情况下,np.typing.NDArray
用于类型注解,而 np.ndarray
用于实例检查(以及数组类型识别)。基于此,遵循 NumPy 的先例并实现以下操作可能是合理的:
jax.Array
是在设备上数组的实际类型。
jax.typing.NDArray
是用于鸭子类型数组注解的对象。
jax.numpy.ndarray
是用于鸭子类型数组实例检查的对象。
对于 NumPy 的高级用户来说,这可能会感觉有些自然,然而这种三分法可能会导致混淆:在选择用于实例检查和注解的对象时并不明显。
另一种方法是通过上述覆盖机制统一类型检查和注解。
部分统一可能如下所示:
jax.Array
是在设备上数组的实际类型。
jax.typing.Array
是用于鸭子类型数组注解的对象(通过.pyi
接口在Array
和Tracer
上)。
jax.typing.Array
同样用于鸭子类型实例检查(通过其元类中的__isinstance__
覆盖)
在这种方法中,jax.numpy.ndarray
将成为向后兼容的简单别名jax.typing.Array
。
或者,我们可以通过覆盖选择完全统一:
jax.Array
是设备上数组的实际类型。
jax.Array
也是用于鸭子类型数组注释的对象(通过 Tracer
上的 .pyi
接口)。
jax.Array
也是用于鸭子类型实例检查的对象(通过其元类中的 __isinstance__
覆盖)。
在这里,jax.numpy.ndarray
将成为向后兼容的简单别名 jax.Array
。
最终,我们可以通过重组类层次结构并将鸭子类型替换为面向对象的对象层次结构来选择完全统一:
jax.Array
是设备上数组的实际类型。
jax.Array
也是用于数组类型注释的对象,通过确保 Tracer
继承自 jax.Array
来实现。
jax.Array
也是通过相同机制进行实例检查的对象。
在这里,jnp.ndarray
可以是 jax.Array
的一个别名。从面向对象设计的角度来看,这种最终方法在某些方面可能是最纯粹的,但从面向对象设计的角度来看,它有些强行(Tracer
是一个 Array
?)。
我们可以通过使 Tracer
和设备上数组的类继承自一个共同的基类来使类层次结构更合理。因此,例如:
jax.Array
同时也是 Tracer
的基类以及设备上数组的实际类型,可能是 jax._src.ArrayImpl
或类似的。
jax.Array
也是用于数组类型注释的对象。
jax.Array
也是用于实例检查的对象。
在这里,jnp.ndarray
将是 Array
的一个别名。从面向对象编程的角度来看,这可能更加纯粹,但与选项 2 和 3 相比,它取消了 type(x) is jax.Array
为 True 的概念。
综合考虑每种潜在方法的优势和劣势:
jax.Array
是你需要知道的全部。
Tracer
成为数组的子类。这打破了继承模型,因为它要求 Tracer
对象承载 Array
对象的所有负担(数据缓冲区、分片、设备等)。
jax._src.ArrayImpl
)。但绝大多数用户永远不需要直接触及这个私有实现。
这里有不同的权衡,但经过讨论,我们决定采用 Option 4 作为我们的前进方式。
为了推进类型注释,我们将执行以下操作:
jax._src.typing
(目前不提供任何公共 API),并将上述简单类型的第一级放入其中:
Array = Any
作为别名,因为这需要更多的思考。
ArrayLike
:作为输入传递给常规jax.numpy
函数的类型的联合。
DType
/ DTypeLike
(注意:numpy 使用驼峰式DType
;我们应该遵循这个惯例以便使用)。
Shape
/ NamedShape
/ ShapeLike
这些工作的开端在#12300已经完成。
jax.Array
基类上进行工作,该类遵循前一节中的第 4 个选项。最初,这将在 Python 中定义,并使用目前在jnp.ndarray
实现中找到的动态注册机制,以确保isinstance
检查的正确行为。为每个 tracer 和类似数组的类创建一个pyi
覆盖,以确保类型注释的正确行为。然后,jnp.ndarray
可以成为jax.Array
的别名。
jax.lax
中的函数。
jax.Array
基类,以便ArrayImpl
和Tracer
可以继承它。使用pyi
定义确保静态类型检查器识别类的适当属性。
jax.Array
和jax._src.ArrayImpl
完全完成,就删除这些临时的 Python 实现。
jax.typing
模块,使上述类型对用户可用,并提供使用 JAX 的代码注释最佳实践的文档。
我们将在#12049中跟踪这项工作,从中获取本 JEP 的编号。
shmap
(shard_map)用于简单的每个设备代码sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@
2023 年 1 月
JAX 支持两种多设备编程的思路:
我们需要既出色的 API,而不是互斥的替代方案,它们需要相互组合。
通过pjit
(现在是jit
),我们拥有了下一代 API来支持第一种思路。但是我们还没有完全升级第二种思路。pmap
遵循第二种思路,但随着时间的推移,我们发现它存在致命缺陷。xmap
解决了这些问题,但它并没有完全给我们提供每个设备的形状,并且还包含了其他几个重大的想法。同时,对于像在高效扩展 Transformer 推理中的每个设备显式集合编程的新需求也在不断涌现。
我们可以通过shmap
升级第二种思路。shmap
是:
xmap
的特化,具有简化的功能和一些调整;
shard_map
、shpecialized_xmap
、sholto_map
或sharad_map
。
对于pjit
用户,shmap
是一个补充工具。它可以在pjit
计算中使用,暂时切换到“手动集合”模式,就像是从编译器的自动分区中逃脱一样。这样,用户可以在大部分代码中享受到pjit
的便利和熟悉的 NumPy 编程模型,同时在需要时使用shmap
来手动优化集合通信。这是两全其美的解决方案!
对于pmap
用户,shmap
是一个严格的升级。它更加表达力强,性能更好,并且与其他 JAX API 可以良好组合,而不会使基本的批量数据并行化变得更加困难。
对于更多的实际使用情况,你可以跳转到何时使用shmap
和何时使用pjit
?如果你想知道我们为什么需要一个全新的东西,或者pmap
存在什么问题,可以跳到为什么pmap
或xmap
不能解决这个问题?或者继续阅读下一节,查看一些shmap
示例和 API 规范。
shmap
!Sho shick:
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', None))
def matmul_basic(a_block, b_block):
# a_block: f32[2, 8]
# b_block: f32[8, 32]
z_partialsum = jnp.dot(a_block, b_block)
z_block = jax.lax.psum(z_partialsum, 'j')
return z_block
c = matmul_basic(a, b) # c: f32[8, 32]
注意:
axis_index_groups
)来处理多个轴的并行性,不像pmap
;
pmap
和 hard-xmap
,逻辑形状对应于每个设备的物理形状,不像(非硬)xmap
;
mesh
实现精确的设备放置控制,不像 pmap
;
xmap
;
pjit
的 jax.Array
,不像 pmap
;
pjit
/jit
内部有效地工作,不像 pmap
;
pdb
并打印值,不像 xmap
的当前实现(尽管设计上 xmap
没有顺序安排也可以急切地工作)。
这里是另一种具有完全分片结果的矩阵乘法变体:
@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
# c_partialsum: f32[8/X, 32]
c_partialsum = jnp.matmul(a_block, b_block)
# c_block: f32[8/X, 32/Y]
c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
return c_block
c = matmul_reduce_scatter(a, b)
我们可以将 pmap
(和 vmap
和 xmap
)视为沿轴解堆叠每个数组输入(例如,将 2D 矩阵解包成其 1D 行),对每个片段应用其体函数,并将结果堆叠在一起,至少在不涉及集合时是这样的:
pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])
例如,如果 xs
的形状为 f32[8,5]
,那么每个 x
的形状为 f32[5]
,如果每个 f(x)
的形状为 f32[3,7]
,那么最终堆叠的结果 pmap(f)(xs)
的形状为 f32[8,3,7]
。也就是说,每次对体函数 f
的应用都比 pmap(f)
对应的参数少一个轴。我们可以说这些是降秩映射,输入/输出的解堆叠/堆叠。
f
的逻辑应用次数由被映射的输入轴的大小确定:例如,如果我们在大小为 8 的输入轴上进行映射,从语义上讲,我们得到函数的 8 次逻辑应用,这对于 pmap
总是对应于 8 个物理设备计算。
相反,shmap
没有这种降秩行为。相反,我们可以将其视为沿输入轴切片(或“非连接”)为块,应用体函数,并将结果再次连接在一起(在没有涉及集合时):
devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',)) # mesh.shape['i'] = 4
shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])
请注意,jnp.split
将其输入切片成相同大小的块,因此如果在上述示例中 y
的形状为 f32[8,5]
,那么每个 y_blk
的形状为 f32[2,5]
,如果每个 f(y_blk)
的形状为 f32[3,7]
,那么最终连接的结果 shard_map(f, ...)(y)
的形状为 f32[12,7]
。因此 shmap
(shard_map
)映射其输入的分片或块。我们可以说它是一个保持秩映射,其输入/输出的解连接/连接。
f
的逻辑应用次数由网格大小确定,而不是任何输入轴大小:例如,如果我们有总大小为 4 的网格(即超过 4 个设备),那么从语义上讲,我们得到函数的 4 次逻辑应用,对应于 4 个物理设备计算它们。
in_specs
控制每个输入的切分(解连接)和平铺;每个in_specs
通过PartitionSpec
标识了一些相应输入数组的轴,通过网格轴名称表示如何将该输入拆分(或取消连接)成应用主体函数的块。该标识确定了分片大小;当一个输入轴标识为一个网格轴时,输入沿该逻辑轴分割(取消连接)为与相应网格轴大小相等的多个部分。(如果相应网格轴大小不能整除输入数组轴大小,则会产生错误。)如果输入的pspec
未提及网格轴名称,则在该网格轴上不会进行分割。例如:
devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))
@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
print(x_block.shape)
return x_block
x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1) # prints (3,12)
因为输入的pspec
未提及网格轴名'j'
,所以没有任何输入数组轴在该网格轴上进行分割;同样地,因为输入数组的第二轴未与任何网格轴标识(因此未在其上进行分割),f1
的应用将完整查看该轴上的输入。
当输入的pspec
中未提及网格轴时,我们总是可以重写为一个效率较低的程序,其中所有网格轴都被提及,但调用者执行jnp.tile
,例如:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
print(x_block.shape)
return x_block
x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j'])) # x_ has shape (12, 24)
y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
换句话说,因为每个输入的pspec
可以零次或一次提及每个网格轴名,而不必确切一次提及每个名字,所以我们可以说,除了其输入中内置的jnp.split
,shard_map
还具有一个内置的jnp.tile
,至少在逻辑上是如此(尽管根据参数的物理分片布局,不一定需要在物理上执行平铺)。要使用的平铺方法不是唯一的;我们也可以沿着第一个轴平铺,并使用P(('j', 'i'), None)
的pspec
。
输入上的物理数据移动是可能的,因为每个设备都需要具有适当数据的副本。
out_specs
控制每个输出通过连接、块转置和使用untiling
组装。类似于输入端,每个out_specs
通过名称将一些相应输出数组的轴标识为网格轴,表示如何将输出块(每个主体函数应用的一个或等效地每个物理设备的一个)组装回来以形成最终输出值。例如,在上述f1
和f2
示例中,out_specs
表明我们应通过沿两个轴连接块结果来形成最终输出,从而在两种情况下得到形状为(12,24)
的数组y
。(如果主体函数的输出形状,即输出块形状,对应的输出pspec
所描述的连接过程具有过小的秩,则会产生错误。)
当输出 pspec 中未提到网格轴名称时,它表示一个未平铺:当用户编写一个输出 pspec,其中未提到网格轴名称之一时,他们保证输出块在该网格轴上是相等的,因此在输出中仅使用该轴上的一个块(而不是沿该网格轴将所有块连接在一起)。例如,使用与上述相同的网格:
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
注意,闭包在数组值上的主体函数等同于将其作为具有相应输入 pspec P(None, None)
的增广传递。作为另一个例子,更接近前面例子的另一个例子:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
return jax.lax.psum(x_block, 'j')
x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape) # (12,6)
注意,结果的第二个轴大小为 6,是输入第二个轴大小的一半。在这种情况下,通过在输出 pspec 中不提到网格轴名称 'j'
来表达的未平铺是安全的,因为集体 psum
确保每个输出块在相应的网格轴上是相等的。以下是另外两个例子,其中我们变化了在输出 pspec 中提到的网格轴:
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
return jax.lax.psum(x_block, 'i')
x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape) # (3,12)
@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape) # (3,6)
在物理方面,未在输出 pspec 中提到网格轴名称会从输出设备缓冲区组装一个 Array
,在该网格轴上具有复制的布局。
没有运行时检查输出块实际上是否沿网格轴相等以进行未平铺,或者等效地说,相应的物理缓冲区是否具有相等的值,因此可以解释为单个逻辑数组的复制布局。但我们可以提供一个静态检查机制,在所有潜在不正确的程序上引发错误。
因为 out_specs
可以提到网格轴名称零次或一次,并且它们可以以任意顺序提到,所以我们可以说,除了其输出中内置的 jnp.concatenate
外,shard_map
还包含一个未平铺和一个块转置。
在输出上不可能进行物理数据移动,无论输出 pspec 如何。相反,out_specs
只是编码如何将块输出组装成 Array
,或者物理上如何将缓冲区解释为单个逻辑 Array
的物理布局。
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
) -> Callable:
...
其中:
mesh
编码设备按照数组排列,并且具有相关联的轴名称,就像对 xmap
和 sharding.NamedSharding
也是如此;
in_specs
和 out_specs
是 PartitionSpec
,它们可以仿射地提到 mesh
中的轴名称(不像 xmap
中的分开的逻辑名称)来表示输入和输出的切片/非拼接和拼接,分别(不像 pmap
和 xmap
那样的解包和堆叠),未提到的名称对应于复制和未平铺(断言已复制,因此给我一个副本);
f
的参数的形状与传递给shard_map
-of-f
的参数的形状相同(不像pmap
和xmap
,其中形状被降低),而且参数传递给f
的形状是从对应于shard_map
-of-f
的形状shape
和相应的PartitionSpec
spec 计算得到的,大致为tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
;
f
的主体可以使用来自mesh
的名称应用收集操作。
shmap
默认是急切的,这意味着我们逐个原语地调度计算,使用户能够在完全复制的值上使用 Python 控制流和交互式pdb
调试以打印任何值。要将shmap
函数进行阶段输出并进行端到端编译,只需在其周围放置一个jit
。一个结果是,shmap
没有像当前的xmap
和pmap
那样有其自己的调度和编译路径;它只是jit
路径的一部分。
当它被例如封闭的jit
阶段输出时,将shmap
降低到 StableHLO 是微不足道的:它仅涉及切换到输入的“手动 SPMD 模式”,并在输出上切换回来。(我们目前不计划支持部分手动部分自动模式。)
与效果的交互与pmap
的交互相同。
与自动微分的交互也与pmap
类似(而不是尝试xmap
所做的新语义,对应于具有未映射中间变量的grad
的reduce_axes
以及使psum
转置为pbroadcast
而不是psum
)。但是它因此继承了来自pmap
的一个未解决的问题:在某些情况下,将后向传播的psum
移动到后向传播的其他位置,利用线性特性,而不是将psum
转置为psum
,从而执行与前向传播psum
对应的后向传播psum
,这可能是有益的。许多高级的pmap
用户通过使用custom_vjp
来实现psum_idrev
和id_psumrev
函数来解决这一挑战,但由于很容易意外地使其失衡,这种技术是有风险的。我们对如何以更安全的方式提供此功能有一些想法。
shmap
,何时应该使用pjit
?一种哲学是:在jit==pjit
中编写程序通常更简单 —— 但是如果程序的某个部分的优化程度不如编译器可能的话,就使用shmap
!
实际上,我们可以使用 30 行 Python 实现简单版本的“集体矩阵乘法”算法,该算法最近在 XLA 中引入,以重叠通信和计算使用shmap
。算法的基本思想可以通过一个简单的例子掌握。
假设我们想要计算C = A @ B
,其中A
由第 0 维的 1D 网格分片,而B
和C
是复制的。
M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))
mesh = Mesh(np.array(jax.devices()), axis_names=('i'))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))
@jax.jit
def f(lhs, rhs):
return lhs @ rhs
C = f(A_x, B)
配置文件显示了在矩阵乘法开始之前,所有设备上的阻塞全收集。这是次优的,因为A
在非收缩维上被分片,每个A
的分片可以独立地与B
进行矩阵乘法,并且这种分块计算可以与从另一设备获取下一个A
分片重叠。
这种重叠可以通过shmap
和显式集体来实现。
def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
# lhs is the looped operand; rhs is the local operand
axis_size = jax.lax.psum(1, axis_name='i')
axis_index = jax.lax.axis_index(axis_name='i')
chunk_size = lhs.shape[0]
def f(i, carrys):
accum, lhs = carrys
# matmul for a chunk
update = lhs @ rhs
# circular shift to the left
lhs = jax.lax.ppermute(
lhs,
axis_name='i',
perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
)
# device 0 computes chunks 0, 1, ...
# device 1 computes chunks 1, 2, ...
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum, lhs
accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
# fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual()
# accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
for i in range(0, axis_size - 1):
accum, lhs = f(i, (accum, lhs))
# compute the last chunk, without the ppermute
update = lhs @ rhs
i = axis_size - 1
update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
accum = jax.lax.dynamic_update_slice(accum, update, update_index)
return accum
jit_sharded_f = jax.jit(shard_map(
collective_matmul_allgather_lhs_non_contracting, mesh,
in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)
一个配置文件显示,全收集消失了,并且用异步集体置换的重叠矩阵乘法替换。此配置文件与集体矩阵乘法论文结果非常接近。
这种集体矩阵乘法技术可以用于加速变压器层中的前馈块。这通常包括两个矩阵乘法,后跟一个ReduceScatter
(用于解决并行矩阵乘法的部分和)和前导的AllGather
(用于沿某些轴收集分片维度并允许部分和计算)。在一起,一层的ReduceScatter
和下一层的AllGather
相当于一个AllReduce
。
在典型配置文件中,两个矩阵乘法后将跟随一个AllReduce
,它们不会重叠。集体矩阵乘法可以用来实现重叠,但很难触发,具有最小切片大小,并且尚未涵盖所有拓扑结构、张量形状和集体矩阵乘法的变体(即延迟和吞吐量优化的变体)。在最近的一篇论文中,我们发现,在许多情况下,通过手动实现集体矩阵乘法变体,可以获得约 40%的增益,类似于shmap
风格。
但这并不总是更复杂!我们预计这将是一种更自然的管道计算方式,计划很快进行一些演示!
这里展示了shmap
在变换器层传递中的样子,采用了 2D 权重收集模式(论文,第 3.2.3 节,第 5 页):
def matmul_2D_wg_manual(xnorm, q_wi, layer):
'''Calls a custom manual implementation of matmul_reducescatter'''
# [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
# -> (matmul)
# -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
# -> (reducescatter over x into X heads, B batches)
# -> [batch, maxlen, heads.YZX, q_wi_per_head]
with jax.named_scope('q_wi'):
xnorm = intermediate_dtype(xnorm)
q_wi = matmul_reducescatter(
'bte,hed->bthd',
xnorm,
params.q_wi,
scatter_dimension=(0, 2),
axis_name='i',
layer=layer)
return q_wi
import partitioning.logical_to_physical as l2phys
def pjit_transformer_layer(
hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Forward pass through a single layer, returning output, K, V."""
def my_layer(t, axis=0):
"""Gets the parameters corresponding to a given layer."""
return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)
# 2D: [batch.Z, time, embed.XY]
x = _with_sharding_constraint(
x, ('residual_batch', 'residual_time', 'residual_embed'))
xnorm = _layernorm(x)
# 2D: [batch, time, embed.X]
xnorm = _with_sharding_constraint(
xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
# jump into manual mode where you want to optimise
if manual:
q_wi = shard_map(matmul_2D_wg_manual, mesh
in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
else:
q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
# 2D: [batch, time, heads.YZX, None]
q_wi = _with_sharding_constraint(q_wi,
('post_norm_batch', 'time', 'heads', 'qkv'))
q = q_wi[:, :, :, :hparams.qkv]
q = _rope(sin, cos, q)
# unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
# swiGLU with full d_ff dimension, rather than 2/3 scaled
wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
k = kv[:, :, 0, :hparams.qkv]
v = kv[:, :, 0, hparams.qkv:]
k = _rope(sin, cos, k)
y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
y_mlp = special2.swish2(wi0) * wi1
# 2D: [batch, time, heads.YZX, None]
y_mlp = _with_sharding_constraint(y_mlp,
('post_norm_batch', 'time', 'heads', None))
y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
# do the second half of the mlp and the self-attn projection in parallel
y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
# 2D: [batch.Z, time, embed.XY]
y_out = _with_sharding_constraint(
y_out, ('residual_batch', 'residual_time', 'residual_embed'))
z = y_out + x
z = _with_sharding_constraint(
z, ('residual_batch', 'residual_time', 'residual_embed'))
return z, k, v
在下面的配置文件中,第一和第二个矩阵乘法都被手动降低版本替换,计算(融合)完全与通信(ppermute)重叠!一个有趣的提示是,我们使用的是延迟优化变体,因此 ppmerute 像素是抖动的 — 因为同时使用两个重叠的 ppermute,使用相反的 ICI 轴!
全对全的重叠要困难得多,因此被搁置了。
pmap
或xmap
还没有解决这个问题?pmap
是我们的第一个多设备并行性 API。它遵循每设备代码和显式集体的学派。但它存在重大缺陷,使其不适用于今天的程序:
pmap
。 不仅嵌套 pmap
写起来麻烦,而且很难控制(甚至预测)数据和计算的设备放置,也很难保留数据分片(参见接下来的两个子弹)。如今的程序需要多个轴的并行处理。
pmap
不提供如何在硬件上放置映射程序实例的控制;用户只能使用自动设备顺序,无法控制它。(Gopher 使用 axis_index_groups
和单个未嵌套的 pmap
基本上是一种通过将多个并行轴压缩为一个轴来绕过此问题的方法。)
jit
/pjit
可组合性。 jit
-of-pmap
是一个性能陷阱,像是嵌套 pmap
、例如 scan
-of-pmap
一样,因为从内部 pmap
返回时未能保留分片。要保留分片,我们需要在 jaxprs 上进行模式匹配,以确保我们正在处理完全嵌套的 pmaps,或者在 jit
内部只有一个 pmap。此外,pjit
无助于此处,因为 pmap
面向 XLA 副本,而 pjit
则面向 XLA SPMD Partitioner,这两者的组合很困难。
jax.Array
兼容性(因此 pjit
兼容性)。 由于 pmap
输出的分片不能表示为 Shardings
/ OpShardings
,因为 pmap
使用的是堆叠而不是连接语义,所以目前无法将 pmap
计算的输出直接传递给 pjit
计算,而需要经过主机反弹(或调度重塑计算)。
pjit
兼容性)。 多控制器 pmap
在控制器间连接值,这很有效,但与单控制器 pmap
的堆叠语义不同。更实际地说,它排除了与多控制器 pjit
一起使用的非完全可寻址 jax.Array
输入和输出的可能性。
pmap
设计为急切模式,尽管最终(四年多后!)通过 disable_jit()
添加了急切操作,但事实上 pmap
中融入了 jit
意味着它有自己的编译和调度路径(实际上有两个调度路径:Python 处理 Tracer
,以及 C++ 处理原始 Array
输入!),这是一个沉重的实现负担。
pmap
的典型用例可能看起来是从大小为 128 的批处理轴开始,将其重塑为大小为 (8, 16) 的两个轴,然后在第一个轴上进行 pmap
。这些重塑是笨拙的,编译器通常将它们解释为复制而不是视图,增加了内存和时间的使用。
这些缺点在仅进行批量数据并行时并不算太糟糕。但是当涉及更多并行处理时,pmap
就显得力不从心!
xmap
作为pmap
的下一代演进铺平了道路并解决了(几乎)所有这些问题。shmap
则沿着xmap
的步伐前行,并以基本相同的方式解决了这些问题;实际上,shmap
就像是xmap
的一个专门子集(有些人称之为“硬xmap
”子集),只是稍作调整。
对于初始原型,我们选择将shmap
实现为与xmap
分离的单独原语,因为限制它支持的功能集使得更容易专注于核心功能。例如,shmap
不允许未映射的中间值,这样就更容易不用担心命名轴与自动微分之间的交互。此外,不需要考虑所有功能对之间的交互使得可以更容易地添加超出当前xmap
实现的功能,比如支持急切模式。
shmap
和xmap
都共享降低代码的重要部分。未来我们可以考虑合并两者,或者甚至完全专注于shmap
,这取决于使用方式的演变。
@froystig, @sharadmv, @jakevdp, @yashk2810
2023 年 5 月
import jax.extend as jex
多个项目依赖于 JAX 的代码库内部,通常用于使用其核心机制(例如编写其 IR 上的转换)或扩展它(例如定义新的原语)。这些依赖的两个挑战是(a)我们的内部结构并不都是为外部使用而设计的,以及(b)绕过 JAX 的公共 API 是不受支持的。换句话说,我们的内部经常被用作库,但既不像库那样结构化也不像库那样更新。
此提案考虑引入一个jax.extend
模块,定义 JAX 一些内部组件的库视图。我们将其视为第二层 API,仍然基本不保证兼容性政策,但希望在发生更改时更容易发现。
jax.extend
的受众包括与 JAX 相关的 Python 库,如Oryx,jax-triton等,以及进行函数转换、自动微分系统、数值编程编译器前端等实验的项目。
本说明概述了jax.extend
现在和将来可能的样子。它没有详细列出所有细节,而是建议我们开始逐步开发这个模块。
注意,jax.extend
与jax.experimental
不同,后者是新功能和正在进行的想法的一个暂存场所。通常,jax.experimental
中的工作最终会进入另一个 JAX 模块或被完全移除。
为了保持开发的开销低,jax.extend
不会遵循公共API 兼容性政策。它将承诺没有弃用窗口,也没有版本间的向后兼容性。每个发布都可能会破坏现有的调用者,没有简单的回退措施(例如没有重新引入先前行为的标志)。我们将依赖变更日志来指出这些更改。
调用jax.extend
的调用者可能会发现在 JAX 发布时与其代码一起定期升级对他们有用。这是当今依赖 JAX 内部的项目的一个常见习惯。不同之处在于现在它将以更好的意图和更好的库设计和命名帮助中,伴随着变更日志公告的形式出现。
没有兼容性政策使得在实施上更容易上手:第一天,我们可以从内部包(如jax._src
)中移植少量符号到今天的jax.core
和jax.interpreters
。然后我们可以迭代改进。
我们可以设想,最终jax.extend
可能包括以下模块:
core
– 原语,Jaxpr IR 等。
interpreters
– 核心转换(例如自动微分、批处理)和降低。
random
– 随机位生成、关键分割和折叠、关键数组。
sharding
– 关于分布式数组的额外功能。
最初在模块中可能还有其他符号,例如jex.api_util
,随着我们的工作,我们将移除或替换它们。其他的时间会决定。例如,jex.lib
可能在短期内提供访问 jexlib 的入口点,但是目前还不清楚我们是否想长期保留它。
对每个这些内容可能包含的一些初步想法的追踪。
jax.extend.core
这应该至少使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...)
的输出)。支持这一点可能涉及提供:
jax._src.lax.add_p
。
jax._src.core.ShapedArray
。
jax.make_jaxpr
分阶段地阶段 Python 函数(或不阶段化!)。
在初始化时,这个模块将包含比定义原语和规则所需更多的符号,包括在设置“最终风格转换”时使用的各种名称,例如当前的jax._src.core.Trace
和Tracer
类。我们可以重新审视jex.core
是否应该支持初始风格方法以及是否可以通过比完全暴露Trace
和Tracer
更狭窄的 API 来支持最终风格扩展。Oryx可能会帮助指导这些决策。
我们还可以考虑将make_jaxpr
本身迁移到jax.core
中。
jax.extend.interpreters
此模块将提供注册各种原语转换规则的手段 —— 定义它们在自动微分、批处理、降低等方面的行为。
最初将反映jax._src.interpreters
,提供模块ad
、batching
、partial_eval
(用于将 Python 编程转换为 Jaxpr,并用于自动微分中的线性化)、mlir
、pxla
和xla
。前三者可能可以通过jax.core
中的单一原语扩展 API 替换。用于降低的后三者可以简化为一个模块,也许。
今天,为了编写转换规则,例如用于自动微分和批处理的规则,调用者可能需要与跟踪器相关的符号,例如JVPTracer
和BatchTracer
。以后可能可以避免这种情况,并允许我们从jax
中移除跟踪器类型。
这个模块加上jex.core
应该足以复制今天的自定义原语教程(例如我们的教程和dfm 的教程)。例如,定义一个原语及其在jax.jit
下的行为可能如下(在短期内):
from jax.extend import core # Previously: from jax import core
from jax.extend.interpreters import mlir # ... and similarly
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
return core.ShapedArray(x_sa.shape, x_sa.dtype)
def mul_add_mlir(ctx, xc, yc, zc):
add = mlir.hlo.AddOp
mul = mlir.hlo.MulOp
return add(mul(xc, yc), zc).results
mlir.register_lowering(mul_add_p, mul_add_mlir)
import jax
print(mul_add_p.bind(2, 3, 4)) # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
jax.extend.random
这个模块可以暴露出我们定义新的随机数生成器实现的机制,并提供用于处理 PRNG 密钥内部的函数(参见问题#9263),例如当前的jax._src.prng.random_wrap
和random_unwrap
。
它还可以暴露出构成内置随机数生成器实现基础的键控哈希函数,例如jax._src.prng.threefry_2x32
。
jax.extend.sharding
这个模块可以暴露出用于分片分布式数组的低级实用工具。
目前我们只考虑了一项。XLA 编译器的数组分片格式比JAX 提供的那些更具表现力。我们可以将其作为jex.sharding.XlaOpShardingProto
提供,对应于今天内部的jax._src.lib.xla_client.OpSharding
。
mattjj@,dougalm@
2023 年 8 月
我们在自动转置包含某些收集的shmap
中遇到了效率问题。问题出现在psum
和all_gather
,特别是当收集的输出作为未映射的输出返回给调用者时。这并不是一个边缘情况:例如,在应用grad
到基于shmap
的批量数据并行神经网络损失函数时,使用psum
来计算总损失。
我们已经知道这个问题有一段时间了。与pmap
类似的问题存在,尽管通过在pmap
内部而不是外部保留grad
来解决了这个问题。不完全的带有名称的avals-with-names
工作的一个主要目标是解决这个转置效率问题的一个版本。这篇文档借鉴了这些想法,同时对其进行了扩展和修订,以处理更多情况,并且更易于落地。事实上,这里提出的解决方案只影响shmap
的实现。其余系统不需要更改(暂时)。
这篇文档的主要目的是定义这个转置效率问题,并提出一个易于落地的解决方案。
这篇文档不涉及:
shmap
和 OG pmap
中的轴名称一样);
psum
或all_gather
的有效转置取决于共享设备上的余切是否不变考虑这个半真实的例子,旨在类似于一个复制参数批量数据并行损失函数:
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
global_loss = lax.pmean(local_loss, 'batch'))
return global_loss
注意out_specs=P()
,它指示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。
在loss
示例中的大多数细节并不重要。对于我们的目的来说,唯一重要的是我们在最后应用了psum
(或者更确切地说是pmean = lambda x, name: psum(x, name) / psum(1, name)
)。因此,一个精简版本看起来像这样:
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
甚至通过抑制mesh
参数简化了符号。在接下来的例子中,可以从上下文中推断出来。
什么样的转置看起来像?写t
来表示函数转置,我们可以通过应用下面的函数¿f1_transpose?
有效地评估任意ybar
对应的t(f1)(ybar)
:
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但这并不是我们当前获得的转置t(f1)
。
相反,当前的转置配方大致是我们交换in_specs
和out_specs
,对未映射输出进行一些分区重缩放,并转置主体。因为psum
本身是其自身的转置(作为全归约和的总和),我们最终会产生这个转置:
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
这个转置虽然得到了正确的数字,但是很浪费。我们从转置的 in_specs=P()
静态地知道 ybar
对于每个函数实例都具有相同的值,即其值对于沿着被命名为 i
的网格轴的设备是不变的,然而我们还是对它应用了 psum
!这使用了昂贵的通信来将每个设备上的值乘以 8。(这里的 8 指的是轴 i
的大小。除以 8 来自于原始函数的 out_specs=P()
;它和微不足道的 psum
基本上互相抵消了。)
我们做错了什么?我们没有利用 cotangents
ybar
对应于 f1
的未映射输出是设备不变的这一事实;相反,我们像防御性地 psum
它们一样处理它们,就像 psum
的转置不能确定它们一样。有时 psum
是必要的,比如对于关于其第一个参数的 f2
的转置:
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
直观地说,如果我们的转置机制能区分示例 1 和示例 2,我们可以通过尽可能避免在可能的情况下避免 psum
和除法来做得更好。
低效的示例甚至可以更小。考虑转置这个被诅咒的恒等函数:
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...
随着我们的转置越来越多,它变得越来越大。真丢人!
而 psum
并不是唯一的问题。类似的情况也适用于 all_gather
:
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
这个程序有点人为。为什么要做一个 all_gather
并将结果馈送到未映射的输出,而不是跳过主体中的 all_gather
并仅使用 out_specs=P('i')
收集结果?但即使是虚构的,这个例子仍然展示了一个不必要执行通信的转置(我们本可以执行一个非通信的切片),类似于示例 1 中的 psum
。
类似于 psum
示例,防御性的 psum_scatter
在某些情况下是必要的:
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
那么我们如何避免这些低效的转置呢?
这里有两个解决方案的想法。它们并不是互斥的。但是(剧透),第二个更好,并且它是我们所需的全部。
psum
表达到 out_specs
中的能力这个解决方案有点像一个草人,因为它只会提供一个笨拙的编程方式。而且它甚至不能解决所有问题!但是,考虑到激励更完整的解决方案,这也值得一试。
上面的示例 4 是人为的,因为我们本可以在主体中使用 out_specs
而不是一个 all_gather
:
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better
版本没有任何转置问题,因为转置问题源于主体中的集体操作。
类似地,我们可以通过扩展 out_specs
来修复示例 1,以便它们可以表达求和:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,提供内置到 out_specs
的 psum
解决了示例 1 中的转置问题。但它并没有完全解决示例 3 中的被诅咒的恒等转置:
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
尽管程序不会随着我们继续转置而继续增大,这是一个改进,但我们仍在进行浪费的通信。
这个解决方案有两个组成部分:
psum
分解为两步过程,引入一个新的pbroadcast
基元,并引入all_gather
及其转置的新基元。
从道义上讲,追踪设备不变与设备变化信息是一种类型级别的考虑。但为了第一次实现的方便起见,我们不需要在抽象值或者 jaxpr 类型中真正添加这些信息。在实施之前,我们会先使用类型引入这个想法。
同样将讨论如何使用户 API 既方便又向后兼容。但首先介绍这个想法时,我们会忽略方便性,而是尽可能地编写显式的代码。
有时候仅仅通过静态信息,我们就可以断定在shmap
的主体中一些中间变量的值在整个网格轴上是不变的,这意味着沿着网格轴的函数实例(及其对应的设备)必须都在使用相同的值进行计算。我们将这样的值称为设备不变的。对于那些不是设备不变的值,我们将它们称为设备变化的,尽管从类型系统的角度来看,我们其实是指它们可能在设备层面上是变化的。
要在类型中编码设备变化,我们将扩展数组类型的语法。我们会写类似x:f32[3,4]{i}
来表示x
在网格轴i
上(可能)是设备变化的(在shmap
的其他网格轴上是设备不变的)。更一般地说,我们会说数组类型语法的语法是这样的
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
我们还将更新类型规则来处理设备变化类型
mul x:f32[s1]{r1} y:f32[s2][r2]
要求除了s1 == s2
外还要求r1 == r2
cond
的分支,我们会取设备变化类型中轴名称集合的并集)
shmap
的“静态分析”检查,以确定任何未映射的 out_specs
是否与其兼容。
这里是一个总结集体原语设备差异类型的表格:
名称 | 设备差异类型 | 示例 | 降低到 HLO | 转置 |
---|---|---|---|---|
psum2 | 可变 -> 不变 | y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') | AllReduceSum (通讯) | pbroadcast |
pbroadcast | 不变 -> 可变 | y:f32[3]{i} = pbroadcast(x:f32[3], 'i') | no-op(无通讯) | psum |
all_to_all | 可变 -> 可变 | y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通讯) | all_to_all | |
axis_index | () -> 可变 | idx:i32[]{i} = axis_index('i') | ReplicaId 和一些算术运算(无通讯) | n/a |
psum_scatter | 可变 -> 可变 | y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') | ReduceScatterSum (通讯) | all_gather |
all_gather | 可变 -> 可变 | y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') | AllGather (通讯) | psum_scatter |
pscatter | 不变 -> 可变 | y:f32[2]{i} = pscatter(x:f32[16], 'i') | lambda x: x[axis_index('i'), None] (无通讯) | all_gather_invariant |
all_gather_invariant | 可变 -> 不变 | y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') | AllGather (通讯) | pscatter |
这里有一些令人惊讶的事情!
pbroadcast
,有趣的是降低为 no-op
all_gather_invariant
,它降低到与 all_gather
相同的内容,但具有不同的设备差异类型(实质上 all_gather
中融合了 pbroadcast
,而 all_gather_invariant
没有)
pscatter
,它是 all_gather_invariant
的对偶(转置)
all_gather
有一个设备可变的结果
直觉上,引入 pbroadcast
的原因(除了使类型规则生效之外)是为了使 psum
能转置为物理上的 no-op。我们需要 all_gather
有一个设备可变的结果,这样我们就可以将其转置为 psum_scatter
;如果我们将其留在设备不变的结果上,可能需要下游的 pbroadcast
,这种组合将转置为低效的 psum
,然后是切片 / pscatter
。因此,我们将 pbroadcast
“融合到” all_gather
中,从而实现有效的转置为 psum_scatter
。我们提供 all_gather_invariant
及其转置 pscatter
主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,可以使用 out_specs
进行不同写作)。
有趣的是,psum
和 pbroadcast
的转置对应于用户在训练 LLMs 时引入的 pmap
中的 psum_idrev
和 id_psumrev
。
再次考虑简化的激励示例:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
使用这些新规则,转置为:
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
在评估 pbroadcast
应用程序时完全不涉及通信或 FLOP;这是一个无操作。请注意,如果我们保持转置,主体的大小不会增长;确实 t(t(f1)) == f1
。实现了效率!
只要我们在需要时插入 pbroadcast
以进行类型检查,我们就不会搞砸其他示例:
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直观地,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。
对于 all_gather
示例,示例 4 将需要使用 all_reduce_invariant
来实现有效的转置(虽然最好是在主体中使用 out_specs
而不是集体操作):
# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
对于示例 5,使用设备变化的 all_gather
的效果与我们期望的一样:
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
但是,有哪位用户愿意编写pbroadcast
?有哪位开发人员愿意破坏许多现有用户代码,其中包括未输入到未映射输出的 psum
?不包括我!
相反,我们可以自动插入pbroadcast
。这有点类似于我们在 jax.numpy
层执行自动等级提升时的方式,插入广播以避免二元运算符中的等级不匹配错误。但它要简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多元操作,其中操作数在设备方差类型上存在差异时,我们将操作数的设备方差类型的轴名称集合的并集,并插入pbroadcast
以将每个操作数提升到结果设备方差类型。
在需要之前自动插入 pbroadcast
可能意味着我们对相同的操作数多次应用相同的 pbroadcast
,从而创建共同子表达式。当我们转置时,这些可能会变成 psum
的和而不是 psum
的总和。我们将依赖编译器根据需要进行清理。如果这是个问题,我们可以向 pbroadcast
插入通行证添加一些简单的记忆化处理。
all_gather
的用户 API 将默认为 all_gather_p
(而不是 all_gather_invariant_p
),涵盖常见情况,意味着不需要插入 pbroadcast
。
我们可以在 shmap
上提供一个选项来禁用这种自动插入pbroadcast
,在这种情况下,用户需要确保类型正确。这种显式选项可能对一些人很有吸引力,他们希望明确指定向后传递中 psum
出现的位置。
使实现轻量级的关键是我们不会将这些类型添加到 avals 或 jaxprs 中。至少起初不会。这可能很昂贵,因为它需要更新 JAX 的其余部分,例如 avals 和 jaxprs 的所有消费者可能需要处理新类型。我们不会再次上当!
相反,我们将保留这些扩展类型作为shmap
的内部元数据,就像当前的“out_specs
复制检查”机制一样。实际上,这个解决方案相当于对现有机制的相对小的扩展:它已经在跟踪相同的信息;现在我们只是添加了pbroadcast
。
我们至少有两种选择来执行pbroadcast
插入的位置:
shmap
主体中,无论是急切执行还是分阶段输出,都要像当前的“out_specs
复制检查”机制一样。前者可能更容易,因为我们只需要处理 jaxpr 案例,并且只有线性原语。但我们将首先尝试后者,以便此处的实现是对现有复制检查逻辑的严格修订/扩展。
对于具体性,我们将主要关注shmap
,尽管这些想法同样适用于例如pmap
和可能的xmap
。
当对应的in_specs
条目未提及该网格轴的名称时,参数/输入沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例对于参数得到相同的值。对于调用者来说,每个操作数根据其映射的网格轴进行切片,而对于未映射的网格轴,则没有切片。
当对应的out_specs
条目未提及该网格轴的名称时,输出沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例必须返回相同的值。对于调用者来说,shmap
的每个结果由沿着输出映射的每个函数实例的返回值串联而成,而对于未映射的网格轴,则只使用该值的一个副本。
参见《shmap
JEP》,其中展示了未映射输入和输出的示例。作为比较,在vmap
中,未映射的输入/输出通过使用in_axes
/ out_axes
为None
(而不是int
)来指示。
这里是我们喜欢shmap
的未映射输入和输出的原因:
pjit
相同的表达能力。 任何pjit
能做的事情,shmap
逃逸通道也应该能做到。否则我们就会缺少逃逸通道!如果shmap
中没有未映射的输出,那么我们无法表达与pjit
相同的批并行损失函数计算。
因此,未映射的输出既是规范的又是有用的!
原文:
jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html
Jake VanderPlas
2023 年 10 月
到目前为止,jax.numpy
和 jax.scipy
的预期范围相对模糊。本文提出了这些包的明确定义范围,以更好地指导和评估未来的贡献,并促使移除一些超出范围的代码。
从一开始,JAX 的目标是为在 XLA 中执行代码提供类似于 NumPy 的 API,项目的发展的一大部分是建立 jax.numpy
和 jax.scipy
命名空间,作为基于 JAX 的 NumPy 和 SciPy API 实现。一直有一个隐含的认识,即numpy
和 scipy
的某些部分超出了 JAX 的范围,但这一范围并没有明确定义。这可能会导致贡献者困惑和沮丧,因为对于潜在的 jax.numpy
和 jax.scipy
贡献是否会被接受,没有明确的答案。
为了避免遗漏,我们应该明确一点:像 JAX 这样的项目中包含的任何代码都会为开发者带来一定的持续维护负担,虽然小但非零。项目长期成功直接与维护者能否继续为项目的所有部分承担维护工作有关:包括文档功能的记录、回答问题、修复错误等。对于任何软件工具的长期成功和可持续性,维护者必须仔细权衡每一项贡献是否对项目的目标和资源是净正面影响。
本文提出了一个六轴评估标准,用来评判任何特定numpy
或scipy
API 的适用范围,以确定是否适合纳入 JAX。在所有轴上表现强劲的 API 是纳入 JAX 包的极佳候选;在六个轴中的任何一个上表现极差都是不适合纳入 JAX 的充分理由。
我们考虑的第一个方向是建议 API 与本地 XLA 操作的对齐程度。例如,jax.numpy.exp()
函数几乎直接镜像了 jax.lax.exp
。numpy
、scipy.special
、numpy.linalg
、scipy.linalg
等中的大多数函数符合此标准:这类函数在考虑其是否应包含在 JAX 中时通过了 XLA 对齐检查。
另一方面,有些函数如numpy.unique()
,它们不直接对应任何 XLA 操作,在某些情况下甚至与 JAX 的当前计算模型根本不兼容,后者要求静态形状的数组(例如 unique
返回依赖于值的动态数组形状)。这类函数在考虑其是否应包含在 JAX 中时未能通过 XLA 对齐检查。
我们还考虑纯函数语义的必要性。例如,numpy.random
基于一个隐式更新的基于状态的随机数生成器,这与基于 XLA 的 JAX 计算模型根本不兼容。
我们考虑的第二个方向集中在Python 数组 API 标准上:在某些意义上,这是一个社区驱动的大纲,用于定义在各种用户社区中重要的面向数组编程的数组操作。如果numpy
或 scipy
中的 API 列在数组 API 标准中,这表明 JAX 应该包含它。以上述示例为例,数组 API 标准包含了 numpy.unique()
的多个变体(unique_all
、unique_counts
、unique_inverse
、unique_values
),这表明,尽管该函数与 XLA 的精确对齐并不完全,但它对于 Python 用户社区非常重要,因此 JAX 或许应该实现它。
对于不符合 Axis 1 或 2 的功能,是否存在良好支持的下游包供应该功能是纳入 JAX 的一个重要考虑因素。一个很好的例子是 scipy.optimize
:虽然 JAX 包含了对 scipy.optimize
功能的最小包装集,但更完整的实现存在于由 JAX 协作者积极维护的 JAXopt 包中。在这种情况下,我们应倾向于指向用户和贡献者这些专业化的包,而不是在 JAX 自身重新实现这些 API。
对于不符合 XLA 的功能,一个考虑因素是提议实现的复杂程度。这在某种程度上与 Axis 1 一致,但仍然是需要强调的。有许多函数已经贡献给 JAX,它们具有相对复杂的实现,难以验证并引入了过多的维护负担;一个例子是 jax.scipy.special.bessel_jn()
:在撰写本 JEP 时,其当前实现是一个非直观的迭代逼近,存在 某些领域的收敛问题,而 提出的修复方案 则引入了更多的复杂性。如果在接受贡献时更加仔细地权衡了实现的复杂性和健壮性,我们可能选择不接受这个包的贡献。
JAX 最适合使用功能型 API 而不是面向对象的 API。面向对象的 API 通常会隐藏不纯的语义,使其往往难以实现良好。NumPy 和 SciPy 通常坚持使用功能型 API,但有时提供面向对象的便利包装器。
例如 numpy.polynomial.Polynomial
,它包装了像 numpy.polyadd()
,numpy.polydiv()
等低级操作。一般情况下,当既有功能型 API 又有面向对象 API 时,JAX 应避免为面向对象 API 提供包装器,而应为功能型 API 提供包装器。
在只存在面向对象的 API 的情况下,JAX 应避免提供包装器,除非在其他轴上有很强的案例支持。
决定在 JAX 中包含 NumPy/SciPy API 还应考虑到该算法对一般用户社区的重要性。诚然,很难量化谁是“利益相关者”以及如何衡量这种重要性;但我们包括这一点是为了明确说明,在 JAX 的 NumPy 和 SciPy 包装中包含什么的任何决定都将涉及某种不容易量化的自由裁量权。
对于现有 API,通过在 github 中搜索使用情况可能有助于确定其重要性或缺失;例如,我们可以回到上面讨论过的 jax.scipy.special.bessel_jn()
:搜索显示,这个函数在 github 上仅有少数用例,这可能部分原因与先前提到的精度问题有关。
在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy 的 API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的详尽列表,而是一个更一般的子模块和主题讨论,附带相关示例。
numpy
命名空间我们认为主要 numpy
命名空间中的函数基本上都适用于 JAX,因为它与 XLA(轴 1)和 Python 数组 API(轴 2)的一般对齐性以及对 JAX 用户社区的一般重要性(轴 6)保持一致。一些函数可能处于边界地带(例如 numpy.intersect1d()
,np.setdiff1d()
,np.union1d()
可能在某些标准下不完全符合),但为简单起见,我们声明所有主要 numpy 命名空间中的数组函数都适用于 JAX。
numpy.linalg
和 numpy.fft
numpy.linalg
和 numpy.fft
子模块包含许多与 XLA 提供的功能广泛对齐的函数。其他函数具有复杂的特定设备的低级实现,但代表一种情况,其中对利益相关者的重要性(轴 6)超过了复杂性。因此,我们认为这两个子模块都适用于 JAX。
numpy.random
numpy.random
对于 JAX 而言超出范围,因为基于状态的随机数生成器与 JAX 的计算模型基本不兼容。相反,我们将重点放在 jax.random
上,它使用基于计数器的伪随机数生成器提供类似的功能。
numpy.ma
和 numpy.polynomial
numpy.ma
和 numpy.polynomial
子模块主要关注通过其他函数手段表达的计算的面向对象接口(轴 5)。因此,我们认为它们不适用于 JAX。
numpy.testing
NumPy 的测试功能只对主机端计算有意义,因此我们在 JAX 中不包含任何包装器。尽管如此,JAX 数组与 numpy.testing
兼容,并且在整个 JAX 测试套件中频繁使用它。
SciPy 没有顶层命名空间中的函数,但包含多个子模块。我们逐一考虑每个子模块,略过已弃用的模块。
scipy.cluster
scipy.cluster
模块包含用于层次聚类、K 均值和相关算法的工具。这些在多个方面表现不佳,更适合由下游包处理。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()
),但在 github 上没有明显的使用示例,这表明聚类对于 JAX 用户并不广泛重要。
建议:弃用并移除 jax.scipy.cluster.vq()
。
scipy.constants
scipy.constants
模块包含数学和物理常数。这些常数可以直接在 JAX 中使用,因此没有必要在 JAX 中重新实现。
scipy.datasets
scipy.datasets
模块包含获取和加载数据集的工具。这些获取的数据集可以直接在 JAX 中使用,因此没有必要在 JAX 中重新实现。
scipy.fft
scipy.fft
模块包含与 XLA 提供的功能大致对齐的函数,并且在其他方面表现良好。因此,我们认为它们适用于 JAX 的范围内。
scipy.integrate
scipy.integrate
模块包含用于数值积分的函数。其中更复杂的函数(quad
、dblquad
、ode
)基于动态评估的循环算法,根据轴 1 和 4 应视为 JAX 范围之外。jax.experimental.ode.odeint()
相关,但相当有限,未处于任何活跃开发状态。
JAX 当前确实包括 jax.scipy.integrate.trapezoid()
,但这仅因为numpy.trapz()
最近已弃用,推荐使用此功能。对于任何特定输入,其实现可以用一行 jax.numpy
表达式替换,因此它并不是提供的特别有用的 API。
基于轴 1、2、4 和 6,scipy.integrate
应被视为 JAX 范围之外。
建议:移除 jax.scipy.integrate.trapezoid()
,此功能已在 JAX 0.4.14 中添加。
scipy.interpolate
scipy.interpolate
模块提供了在一维或多维中进行插值的低级和面向对象的例程。从多个角度评估,这些 API 表现不佳:它们基于类而非低级,除了最简单的方法外,无法有效地用 XLA 操作表达。
JAX 当前具有 scipy.interpolate.RegularGridInterpolator
的包装器。如果今天考虑此贡献,我们可能会根据以上标准拒绝它。但此代码相当稳定,因此继续维护没有太大的风险。
未来,我们应考虑将 scipy.interpolate
的其他成员视为 JAX 范围之外。
scipy.io
scipy.io
子模块涉及文件输入/输出。在 JAX 中重新实现这一功能没有必要。
scipy.linalg
scipy.linalg
子模块包含与 XLA 提供的功能大致对应的函数,快速线性代数对 JAX 用户社区至关重要。因此,我们认为它适用于 JAX 的范围之内。
scipy.ndimage
scipy.ndimage
子模块包含一组用于处理图像数据的工具。其中许多与 scipy.signal
中的工具重叠(例如卷积和滤波)。JAX 目前在 jax.scipy.ndimage.map_coordinates()
中提供了一个 scipy.ndimage
API。此外,JAX 在 jax.image
模块中提供了一些与图像相关的工具。DeepMind 生态系统包括 dm-pix,一个更全面的用于在 JAX 中进行图像处理的工具集。考虑到所有这些因素,我建议 scipy.ndimage
应被视为 JAX 核心之外的范畴;我们可以将感兴趣的用户和贡献者指向 dm-pix。我们可以考虑将 map_coordinates
移至 dm-pix
或其他适当的包中。
scipy.odr
scipy.odr
模块提供了一个面向对象的 ODRPACK
包装器,用于执行正交距离回归。目前尚不清楚是否可以使用现有的 JAX 原语清晰地表达这一功能,因此我们认为它超出了 JAX 本身的范畴。
scipy.optimize
scipy.optimize
模块提供了用于优化的高级和低级接口。这样的功能对许多 JAX 用户非常重要,在 JAX 创建 jax.scipy.optimize
包装器时非常早就开始。然而,这些程序的开发人员很快意识到 scipy.optimize
API 过于约束,并且不同的团队开始开发 JAXopt 包和 Optimistix 包,每个包都包含了在 JAX 中更全面和经过更好测试的优化程序集。
由于这些受到良好支持的外部包,我们现在认为 scipy.optimize
超出了 JAX 的范围。
建议:弃用 jax.scipy.optimize
或使其成为一个轻量级的包装器,周围包装一个可选的 JAXopt 或 Optimistix 依赖。
scipy.signal
scipy.signal
模块则有所不同:一些函数完全适用于 JAX(例如correlate
和convolve
,这些函数是lax.conv_general_dilated
的更友好的包装),而其他许多函数则完全不适用于 JAX(专门领域的工具没有合适的降低路径到 XLA)。对于jax.scipy.signal
的潜在贡献将需要具体问题具体分析。
scipy.sparse
scipy.sparse
子模块主要包含了多种格式的稀疏矩阵和数组的存储和操作数据结构。此外,scipy.sparse.linalg
还包含了许多无矩阵的求解器,适用于稀疏矩阵、稠密矩阵和线性算子。
scipy.sparse
的数组和矩阵数据结构也超出了 JAX 的范围,因为它们与 JAX 的计算模型不符(例如,许多操作依赖于动态大小的缓冲区)。JAX 已经开发了jax.experimental.sparse
模块作为一组更符合 JAX 计算约束的替代数据结构。因此,我们认为scipy.sparse
中的数据结构超出了 JAX 的范围。
另一方面,scipy.sparse.linalg
已经被证明是一个有趣的领域,jax.scipy.sparse.linalg
包括了bicgstab
、cg
和gmres
求解器。这些对于 JAX 用户社区(轴 6)非常有用,但在其他轴上并不适用。它们非常适合移入一个下游库;一个潜在的选择可能是Lineax,它包括了多个基于 JAX 构建的线性求解器。
建议:考虑将稀疏求解器移入 Lineax,并且将scipy.sparse
视为 JAX 范围外的内容。
scipy.spatial
scipy.spatial
模块主要包含面向对象的空间/距离计算和最近邻搜索接口。这在很大程度上超出了 JAX 的范围。
scipy.spatial.transform
子模块提供了用于操作三维空间旋转的工具。这是一个相对复杂的面向对象接口,也许最好由下游项目更好地服务。JAX 目前在jax.scipy.spatial.transform
中部分实现了Rotation
和Slerp
;这些是对基本函数的面向对象包装器,引入了非常庞大的 API 表面,且使用者非常少。我们认为它们超出了 JAX 本身的范围,用户最好由一个假设的下游项目更好地服务。
scipy.spatial.distance
子模块包含一组有用的距离度量标准,可能会诱人地为这些提供 JAX 包装器。尽管如此,通过jit
和vmap
,用户可以很容易地根据需要从头开始定义大多数这些的高效版本,因此将它们添加到 JAX 中并不特别有益。
建议:考虑废弃和移除Rotation
和Slerp
API,并考虑将scipy.spatial
整体视为不适合未来贡献。
scipy.special
scipy.special
模块包括一些更专业函数的实现。在许多情况下,这些函数完全在范围内:例如,像gammaln
、betainc
、digamma
和许多其他函数直接对应于可用的 XLA 基元,并且明显在轴 1 和其他轴上在范围内。
其他函数需要更复杂的实现;一个上面提到的例子是bessel_jn
。尽管在轴 1 和 2 上不对齐,但这些函数往往在轴 6 上非常强大:scipy.special
提供了在多个领域中进行计算所需的基本函数,因此即使是具有复杂实现的函数,只要实现良好且健壮,也应倾向于在范围内。
有一些现有的函数包装器值得我们更仔细地看一看;例如:
jax.scipy.special.lpmn()
: 这个函数通过一个复杂的fori_loop
生成 Legendre 多项式,其方式与 scipy 的 API 不匹配(例如,对于scipy
,z
必须是标量,而对于 JAX,则z
必须是 1D 数组)。该函数有少数发现的用途,使其成为 Axes 1、2、4 和 6 上的一个薄弱候选者。
jax.scipy.special.lpmn_values()
: 这与上述的lmpn
有类似的弱点。
jax.scipy.special.sph_harm()
:此函数基于 lpmn 构建,其 API 与对应的scipy
函数不同。
jax.scipy.special.bessel_jn()
:如上述第 4 轴中讨论的那样,这在实现的健壮性方面存在弱点,使用较少。我们可能会考虑用一个新的、更健壮的实现替换它(例如 #17038)。
建议:重构并提高bessel_jn
的健壮性和测试覆盖率。如果无法修改以更接近scipy
的 API,则考虑废弃lpmn
、lpmn_values
和sph_harm
。
scipy.stats
scipy.stats
模块包含广泛的统计函数,包括离散和连续分布、汇总统计以及假设检验。JAX 目前在jax.scipy.stats
中包装了其中一些,主要包括大约 20 种统计分布以及一些其他函数(如mode
、rankdata
、gaussian_kde
)。总体来说,这些与 JAX 很好地对齐:分布通常可以用高效的 XLA 操作表达,API 清晰且功能齐全。
目前我们没有任何假设检验函数的包装器,这可能是因为这些对于 JAX 的主要用户群体不太有用。
关于分布,在某些情况下,tensorflow_probability
提供类似的功能,未来我们可能会考虑是否应该废弃 scipy.stats 中的分布以支持这种实现。
建议:未来,我们应将统计分布和汇总统计视为范围内的内容,并考虑假设检验及其相关功能通常不在范围内。
原文:
jax.readthedocs.io/en/latest/investigating_a_regression.html
所以,您更新了 JAX,并且遇到了速度回归?您有一点时间并且准备好调查吗?让我们首先提一个 JAX 问题。但如果您能够确定触发回归的提交,这将确实帮助我们。
本文说明了我们如何确定导致15% 性能回归的提交。
如果复现器足够快,这可以很容易地完成。这是一种蛮力方法而非二分法,但如果复现器足够快,它会很有效。这确保了您始终测试兼容的 XLA 和 JAX 提交。它还限制了 XLA 的重新编译。
这里是建议的调查策略:
这可以通过使用JAX-Toolbox 每夜容器来完成。
这里是用于该问题的真实示例脚本:github.com/google/jax/issues/17686
for m in 7 8 9; do
for d in `seq -w 1 30`; do
docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d}
done
Done
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
然后,您可以对每个输出执行 grep 命令以查看回归发生的时间:grep MLUPS OUT*
。这是我们得到的结果:
OUT-07-06:MLUPS: 587.9240990200157
OUT-07-07:MLUPS: 587.8907972116419
OUT-07-08:MLUPS: 587.3186499464459
OUT-07-09:MLUPS: 587.3130127722537
OUT-07-10:MLUPS: 587.8526619429658
OUT-07-17:MLUPS: 570.1631097290182
OUT-07-18:MLUPS: 570.2819775617064
OUT-07-19:MLUPS: 570.1672213357352
OUT-07-20:MLUPS: 587.437153685251
OUT-07-21:MLUPS: 587.6702557143142
OUT-07-25:MLUPS: 577.3063618431178
OUT-07-26:MLUPS: 577.2362978080912
OUT-07-27:MLUPS: 577.2101850145785
OUT-07-28:MLUPS: 577.0716349809895
OUT-07-29:MLUPS: 577.4223280707176
OUT-07-30:MLUPS: 577.2255967221336
OUT-08-01:MLUPS: 577.277685388252
OUT-08-02:MLUPS: 577.0137874289354
OUT-08-03:MLUPS: 577.1333281553946
OUT-08-04:MLUPS: 577.305012020407
OUT-08-05:MLUPS: 577.2143988866626
OUT-08-06:MLUPS: 577.2409145495443
OUT-08-07:MLUPS: 577.2602819927345
OUT-08-08:MLUPS: 577.2823738293221
OUT-08-09:MLUPS: 577.3453199728248
OUT-08-11:MLUPS: 577.3161423260563
OUT-08-12:MLUPS: 577.1697775786824
OUT-08-13:MLUPS: 577.3049883393633
OUT-08-14:MLUPS: 576.9051978525331
OUT-08-15:MLUPS: 577.5331743016213
OUT-08-16:MLUPS: 577.5117505070573
OUT-08-18:MLUPS: 577.5930698237612
OUT-08-19:MLUPS: 577.3539885757353
OUT-08-20:MLUPS: 577.4190113959127
OUT-08-21:MLUPS: 577.300394253605
OUT-08-22:MLUPS: 577.4263792037783
OUT-08-23:MLUPS: 577.4087536357031
OUT-08-24:MLUPS: 577.1094728438082
OUT-08-25: File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module>
OUT-08-26:MLUPS: 537.0164618489928
OUT-08-27:MLUPS: 536.9545448661609
OUT-08-28:MLUPS: 536.2887650464874
OUT-08-29:MLUPS: 536.7178471720636
OUT-08-30:MLUPS: 536.6978912984252
OUT-09-01:MLUPS: 536.7030899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634
这发现 8-24 是好的,但 8-26 是坏的。在 8-25 上有另一个问题,阻止了获取结果。因此,我们需要在 8-24 和 8-26 之间的每小时进行调查。较早的减速可以忽略,仅需在这些日期之间再进行一次小时调查即可。
这在两个日期之间的每个小时检出 JAX 和 XLA,重建所有内容并运行测试。这些脚本结构不同。我们启动工作容器并保持它。然后在容器内,我们只触发增量 XLA 构建,第一次构建除外。因此,在第一次迭代后速度要快得多。
# Execute this script inside the container:
# docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
cd /opt/xla-source
git remote update
cd /opt/jax-source
git remote update
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
cd /tmp
git clone https://github.com/Autodesk/XLB
cd XLB
for d in `seq -w 24 26`; do
for h in `seq -w 0 24`; do
echo $m $d $h
/bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
done
done
echo "param: $@"
cd /opt/xla-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
cd /opt/jax-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
rm /opt/jax-source/dist/jax*.whl
build-jax.sh # The script is in the nightly container
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
现在,您可以在新的输出文件上执行 grep 命令,查看问题出现的小时。
通过这样,您需要检查这些小时之间的 JAX 和 XLA 历史记录。也许有几个提交需要测试。如果您想要花哨一点,可以使用 git bisect。
是的!如果这是一个崩溃回归,能够进行二分法测试将非常有用。但这会更加复杂。如果有人想贡献这样的说明,请提交 PR 😉
对于速度回归,二分法可以隐藏一些信息。我们不会那么容易地看到这里有两个回归。 0899164106 OUT-09-04:MLUPS: 536.5339818238837 OUT-09-05:MLUPS: 536.6507808565617 OUT-09-06:MLUPS: 536.7144494518315 OUT-09-08:MLUPS: 536.7376612408998 OUT-09-09:MLUPS: 536.7798324141778 OUT-09-10:MLUPS: 536.726157440174 OUT-09-11:MLUPS: 536.7446210750584 OUT-09-12:MLUPS: 536.6707332269023 OUT-09-13:MLUPS: 536.6777936517823 OUT-09-14:MLUPS: 536.7581523280307 OUT-09-15:MLUPS: 536.6156273667873 OUT-09-16:MLUPS: 536.7320935035265 OUT-09-17:MLUPS: 536.7104991444398 OUT-09-18:MLUPS: 536.7492269469092 OUT-09-19:MLUPS: 536.6760131792959 OUT-09-20:MLUPS: 536.7361260076634
这发现 8-24 是好的,但 8-26 是坏的。在 8-25 上有另一个问题,阻止了获取结果。因此,我们需要在 8-24 和 8-26 之间的每小时进行调查。较早的减速可以忽略,仅需在这些日期之间再进行一次小时调查即可。
## 每小时调查。
这在两个日期之间的每个小时检出 JAX 和 XLA,重建所有内容并运行测试。这些脚本结构不同。我们启动工作容器并保持它。然后在容器内,我们只触发增量 XLA 构建,第一次构建除外。因此,在第一次迭代后速度要快得多。
+ test_runner2.sh:
```py
# Execute this script inside the container:
# docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash
cd /opt/xla-source
git remote update
cd /opt/jax-source
git remote update
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax
cd /tmp
git clone https://github.com/Autodesk/XLB
cd XLB
for d in `seq -w 24 26`; do
for h in `seq -w 0 24`; do
echo $m $d $h
/bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h
done
done
echo "param: $@"
cd /opt/xla-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
cd /opt/jax-source
git checkout `git rev-list -1 --before="$*" origin/main`
git show -q
rm /opt/jax-source/dist/jax*.whl
build-jax.sh # The script is in the nightly container
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed
python3 examples/performance/MLUPS3d.py 256 200
现在,您可以在新的输出文件上执行 grep 命令,查看问题出现的小时。
通过这样,您需要检查这些小时之间的 JAX 和 XLA 历史记录。也许有几个提交需要测试。如果您想要花哨一点,可以使用 git bisect。
是的!如果这是一个崩溃回归,能够进行二分法测试将非常有用。但这会更加复杂。如果有人想贡献这样的说明,请提交 PR 😉
对于速度回归,二分法可以隐藏一些信息。我们不会那么容易地看到这里有两个回归。