前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Python代码中的偏函数

Python代码中的偏函数

作者头像
DechinPhy
发布2023-12-22 14:22:05
2010
发布2023-12-22 14:22:05
举报
文章被收录于专栏:Dechin的专栏

技术背景

在数学中我们都学过偏导数

\frac{\partial f(x,y)}{\partial x}

,而这里我们提到的偏函数,指的是

f(y)(x)

。也就是说,在代码实现的过程中,虽然我们实现的一个函数可能带有很多个变量,但是可以用偏函数的形式把其中一些不需要拆分和变化的变量转变为固有变量。比较典型的两个例子是计算偏导数和多进程优化。虽然大部分支持自动微分的框架都有相应的支持偏导数的接口,多进程操作中也可以指定额外的args,但是这些自带的方法在形式上都是比较tricky的,感觉并不如使用偏函数优雅和简洁。这里我们主要介绍python中可能会用到的偏函数功能--partial

Partial简单案例

我们先来一个最简单的乘法函数

f(x,y)=xy

。假如说我们想得到该函数关于y的偏导数,注意,这里y是第二个输入的变量,不是第一个位置,一般自动微分框架都默认都第一个位置的变量计算偏导数。相关代码实现如下所示:

代码语言:javascript
复制
from functools import partial

def mul(x, y):
    print (locals())
    return x * y

x = 2
y = 3
res_0 = mul(x, y)
partial_mul = partial(mul, x=x)
res_1 = partial_mul(y=3)
print ('The result is: {}'.format(res_0))
print ('The result is: {}'.format(res_1))

这段代码的运行结果为:

代码语言:javascript
复制
{'x': 2, 'y': 3}
{'x': 2, 'y': 3}
The result is: 6
The result is: 6

我们现在来分析一下上面这个案例中所体现的信息:

  1. 在使用partial函数时使用的是关键字参数,即时原本的变量不是一个关键字参数,而是一个位置参数。
  2. 虽然得到的偏函数partial_mul运行的方式跟函数一致,但其实它是一个partial的对象类型。
  3. 在生成partial_mul对象时已经执行过一遍函数,因此函数中的打印语句被打印了两次。
  4. 偏函数的计算结果肯定是跟原函数保持一致的,但是在一些特殊场景下,我们可能会用到这种单变量的偏函数。

Concurrent多核并行场景

现在我们稍微修改一下上面的案例,我们要用concurrent这个并行工具去分别执行上述乘法任务,同时输入的x也变成了一个多维的数组。然后为了验证并行算法,这里每计算一次元素乘法,我们都用time.sleep方法让进程休眠2秒钟时间。由于此时的参数y还是一个标量,但是每次乘法计算我们都需要输入这个标量,因此我们直接将其封装到一个partial偏函数中,使得函数变成:

f(x,y)=f(y)(x)=P(x)

,然后对x这个入参进行并行化操作:

代码语言:javascript
复制
import numpy as np
import concurrent.futures
from functools import partial
import time
# 定义休眠函数
def mul(x, y):
    time.sleep(2)
    return x * y
# 定义入参
x = np.array([1, 2, 3], np.float32)
y = 3.
# 有阻塞计算
time_0 = time.time()
res_0 = []
for _x in x:
    res_0.append(mul(_x, y))
res_0 = np.array(res_0, np.float32)
time_1 = time.time()
# 并行计算
partial_mul = partial(mul, y=y)
time_2 = time.time()
with concurrent.futures.ProcessPoolExecutor(max_workers=x.shape[0]) as executor:
    res = executor.map(partial_mul, x)
res_1 = np.array(list(res), np.float32)
time_3 = time.time()
print ('The result is: {}, and for loop time cost is: {}s'.format(res_0, time_1 - time_0))
print ('The result is: {}, and concurrent time cost is : {}s'.format(res_1, time_3 - time_2))

如果有感兴趣的童鞋也可以去尝试一下,在这种场景下的并行运算,如果参量y不是一个可迭代式的变量,是无法用zip压缩传到map函数中去的。上述代码的运行结果如下:

代码语言:javascript
复制
The result is: [3. 6. 9.], and for loop time cost is: 6.005392789840698s
The result is: [3. 6. 9.], and concurrent time cost is : 2.0451698303222656s

这个计算时长其实就约等于休眠时长,因为这里我们开启了3个进程来进行休眠,因此并行时长是2s。

Jax自动微分场景

这里我们用Jax的自动微分框架做一个示例,没有安装Jax和Jaxlib的想运行需要自行安装相关软件。虽然在Jax的grad函数中,支持argnums这样的参数配置,但从代码层面角度来说,总是显得可读性并不好。正常情况下我们算偏导数

\frac{\partial f(x,y)}{\partial x}

其实更合理的表述应该是

\frac{\partial P(x)}{\partial x}

。而如果按照Jax这种写法,更像是从

[\frac{\partial f(x, y)}{\partial x}, \frac{\partial f(x, y)}{\partial x}]

两个元素中取了第一个元素。当然,这只是表述上的问题,也是我个人的理解,其实并不影响程序的正确性。这里使用partial偏函数的相关案例如下所示:

代码语言:javascript
复制
from functools import partial
from jax import grad
from jax import numpy as jnp
# Jax要求grad函数输出结果为标量,所以要加一项求和
def mul(x, y):
    f = x * y
    return f.sum()
# 定义输入变量
x = jnp.array([1, 2, 3], jnp.float32)
y = 3.
# 定义偏函数和对应偏导数
partial_mul = partial(mul, y=y)
grad_mul = grad(partial_mul)
print (grad_mul(x))

执行结果如下:

代码语言:javascript
复制
[3. 3. 3.]

总结概要

本文介绍了在Python中使用偏函数partial的方法,并且介绍了两个使用partial函数的案例,分别是concurrent并行场景和基于jax的自动微分场景。在这些相关的场景下,我们用partial函数更多时候可以使得代码的可读性更好,在性能上其实并没有什么提升。如果不想使用partial函数,类似的功能也可以使用参考链接中所介绍的方法,实现一个装饰器,也可以做到一样的功能。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023-12-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 技术背景
  • Partial简单案例
  • Concurrent多核并行场景
  • Jax自动微分场景
  • 总结概要
相关产品与服务
GPU 云服务器
GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于生成式AI,自动驾驶,深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档