torch.get_num_threads()
返回的是当前 PyTorch 使用的线程数。如果你在代码中设置了 NUM_THREADS = 12
,但是调用 torch.get_num_threads()
仍然返回 1,可能是以下几个原因:
OMP_NUM_THREADS
或 MKL_NUM_THREADS
(取决于使用的后端),并可能使用这些环境变量设置的线程数而不是代码中的设置。为了解决这个问题,你可以尝试以下步骤:
import os
import torch
# 设置环境变量
os.environ['OMP_NUM_THREADS'] = '12'
os.environ['MKL_NUM_THREADS'] = '12'
# 设置 PyTorch 线程数
torch.set_num_threads(12)
# 验证设置是否生效
print(torch.get_num_threads())
确保在导入 PyTorch 之前设置环境变量和调用 torch.set_num_threads()
。
参考链接:
如果你遵循了上述步骤,但问题仍然存在,可能需要进一步检查你的系统配置或 PyTorch 的安装情况。
领取专属 10元无门槛券
手把手带您无忧上云