首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从Jax中的联合累积密度函数计算联合概率密度函数?

如何从Jax中的联合累积密度函数计算联合概率密度函数?
EN

Stack Overflow用户
提问于 2021-11-16 10:38:36
回答 1查看 163关注 0票数 1

我在python中定义了一个联合累积密度函数,它是jax数组的函数,并返回一个值。类似于:

代码语言:javascript
复制
def cumulative(inputs: array) -> float:
    ...

要有梯度,我知道我可以做grad(cumulative),但这只是给了我累积的一阶偏导数,相对于输入变量。相反,我要做的是计算,假设F是我的函数,并且是联合概率密度函数:

部分推导的顺序并不重要。

所以,我有几个问题:

  • 如何在Jax中有效地计算这一点?我假设一旦计算出结果函数,我就不能调用梯度n乘以else)?
  • alternatively,
  • ,结果函数会比原始函数具有更高的调用复杂度吗(它是增加了O(n),还是常量,或者某种东西是

,还是某种东西),与整个数组相比,对于输入数组中的一个变量,我如何计算一个偏导数?(我将重复这个n次,每个变量一次)

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-16 14:39:34

JAX通常将梯度视为相对于单个参数而言的梯度,而不是参数中的元素。在这个上下文中,一个与您想要做的类似(但不完全相同)的内置函数是jax.hessian,它计算第二个导数的恒心矩阵;例如:

代码语言:javascript
复制
import jax
import jax.numpy as jnp

def f(x):
  return jnp.prod(x ** 2)

x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
#  [72. 18. 24.]
#  [48. 24.  8.]]

对于与数组中单个元素有关的高阶导数,我认为您必须手动嵌套梯度。您可以使用类似于以下内容的助手函数来完成此操作:

代码语言:javascript
复制
def grad_all(f):
  def gradfun(x):
    args = tuple(x)
    f_args = lambda *args: f(jnp.array(args))
    for i in range(len(args)):
      f_args = jax.grad(f_args, argnums=i)
    return f_args(*args)
  return gradfun

print(grad_all(f)(x))
# 48.0
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69987613

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档