在数学中我们都学过偏导数
,而这里我们提到的偏函数,指的是
。也就是说,在代码实现的过程中,虽然我们实现的一个函数可能带有很多个变量,但是可以用偏函数的形式把其中一些不需要拆分和变化的变量转变为固有变量。比较典型的两个例子是计算偏导数和多进程优化。虽然大部分支持自动微分的框架都有相应的支持偏导数的接口,多进程操作中也可以指定额外的args,但是这些自带的方法在形式上都是比较tricky的,感觉并不如使用偏函数优雅和简洁。这里我们主要介绍python中可能会用到的偏函数功能--partial
。
我们先来一个最简单的乘法函数
。假如说我们想得到该函数关于y的偏导数,注意,这里y是第二个输入的变量,不是第一个位置,一般自动微分框架都默认都第一个位置的变量计算偏导数。相关代码实现如下所示:
from functools import partial
def mul(x, y):
print (locals())
return x * y
x = 2
y = 3
res_0 = mul(x, y)
partial_mul = partial(mul, x=x)
res_1 = partial_mul(y=3)
print ('The result is: {}'.format(res_0))
print ('The result is: {}'.format(res_1))
这段代码的运行结果为:
{'x': 2, 'y': 3}
{'x': 2, 'y': 3}
The result is: 6
The result is: 6
我们现在来分析一下上面这个案例中所体现的信息:
现在我们稍微修改一下上面的案例,我们要用concurrent这个并行工具去分别执行上述乘法任务,同时输入的x也变成了一个多维的数组。然后为了验证并行算法,这里每计算一次元素乘法,我们都用time.sleep
方法让进程休眠2秒钟时间。由于此时的参数y还是一个标量,但是每次乘法计算我们都需要输入这个标量,因此我们直接将其封装到一个partial偏函数中,使得函数变成:
,然后对x这个入参进行并行化操作:
import numpy as np
import concurrent.futures
from functools import partial
import time
# 定义休眠函数
def mul(x, y):
time.sleep(2)
return x * y
# 定义入参
x = np.array([1, 2, 3], np.float32)
y = 3.
# 有阻塞计算
time_0 = time.time()
res_0 = []
for _x in x:
res_0.append(mul(_x, y))
res_0 = np.array(res_0, np.float32)
time_1 = time.time()
# 并行计算
partial_mul = partial(mul, y=y)
time_2 = time.time()
with concurrent.futures.ProcessPoolExecutor(max_workers=x.shape[0]) as executor:
res = executor.map(partial_mul, x)
res_1 = np.array(list(res), np.float32)
time_3 = time.time()
print ('The result is: {}, and for loop time cost is: {}s'.format(res_0, time_1 - time_0))
print ('The result is: {}, and concurrent time cost is : {}s'.format(res_1, time_3 - time_2))
如果有感兴趣的童鞋也可以去尝试一下,在这种场景下的并行运算,如果参量y不是一个可迭代式的变量,是无法用zip压缩传到map函数中去的。上述代码的运行结果如下:
The result is: [3. 6. 9.], and for loop time cost is: 6.005392789840698s
The result is: [3. 6. 9.], and concurrent time cost is : 2.0451698303222656s
这个计算时长其实就约等于休眠时长,因为这里我们开启了3个进程来进行休眠,因此并行时长是2s。
这里我们用Jax的自动微分框架做一个示例,没有安装Jax和Jaxlib的想运行需要自行安装相关软件。虽然在Jax的grad函数中,支持argnums这样的参数配置,但从代码层面角度来说,总是显得可读性并不好。正常情况下我们算偏导数
其实更合理的表述应该是
。而如果按照Jax这种写法,更像是从
两个元素中取了第一个元素。当然,这只是表述上的问题,也是我个人的理解,其实并不影响程序的正确性。这里使用partial偏函数的相关案例如下所示:
from functools import partial
from jax import grad
from jax import numpy as jnp
# Jax要求grad函数输出结果为标量,所以要加一项求和
def mul(x, y):
f = x * y
return f.sum()
# 定义输入变量
x = jnp.array([1, 2, 3], jnp.float32)
y = 3.
# 定义偏函数和对应偏导数
partial_mul = partial(mul, y=y)
grad_mul = grad(partial_mul)
print (grad_mul(x))
执行结果如下:
[3. 3. 3.]
本文介绍了在Python中使用偏函数partial的方法,并且介绍了两个使用partial函数的案例,分别是concurrent并行场景和基于jax的自动微分场景。在这些相关的场景下,我们用partial函数更多时候可以使得代码的可读性更好,在性能上其实并没有什么提升。如果不想使用partial函数,类似的功能也可以使用参考链接中所介绍的方法,实现一个装饰器,也可以做到一样的功能。