我在python中定义了一个联合累积密度函数,它是jax数组的函数,并返回一个值。类似于:
def cumulative(inputs: array) -> float:
...要有梯度,我知道我可以做grad(cumulative),但这只是给了我累积的一阶偏导数,相对于输入变量。相反,我要做的是计算,假设F是我的函数,并且是联合概率密度函数:

部分推导的顺序并不重要。
所以,我有几个问题:
,还是某种东西),与整个数组相比,对于输入数组中的一个变量,我如何计算一个偏导数?(我将重复这个n次,每个变量一次)
发布于 2021-11-16 14:39:34
JAX通常将梯度视为相对于单个参数而言的梯度,而不是参数中的元素。在这个上下文中,一个与您想要做的类似(但不完全相同)的内置函数是jax.hessian,它计算第二个导数的恒心矩阵;例如:
import jax
import jax.numpy as jnp
def f(x):
return jnp.prod(x ** 2)
x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
# [72. 18. 24.]
# [48. 24. 8.]]对于与数组中单个元素有关的高阶导数,我认为您必须手动嵌套梯度。您可以使用类似于以下内容的助手函数来完成此操作:
def grad_all(f):
def gradfun(x):
args = tuple(x)
f_args = lambda *args: f(jnp.array(args))
for i in range(len(args)):
f_args = jax.grad(f_args, argnums=i)
return f_args(*args)
return gradfun
print(grad_all(f)(x))
# 48.0https://stackoverflow.com/questions/69987613
复制相似问题