jax和jaxlib是一起的,所以我们可以通过jax或者jaxlib去判断GPU是否用。
jax判断:
import jax print(jax.devices())
jaxlib判断:
from jax.lib import xla_bridge print(xla_bridge.get_backend().platform)
本文分享自 作者个人站点/博客 前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!