前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch中requires_grad_(), detach(), torch.no_grad()的区别

Pytorch中requires_grad_(), detach(), torch.no_grad()的区别

作者头像
Tyan
发布2020-06-02 22:46:22
5.7K0
发布2020-06-02 22:46:22
举报
文章被收录于专栏:SnailTyan

文章作者:Tyan 博客:noahsnail.com | CSDN | 简书

0. 测试环境

Python 3.6.9, Pytorch 1.5.0

1. 基本概念

Tensor是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32

  • 示例一
代码语言:javascript
复制
>>> a = torch.tensor([1.0])
>>> a.data
tensor([1.])
>>> a.grad
>>> a.requires_grad
False
>>> a.dtype
torch.float32
>>> a.item()
1.0
>>> type(a.item())
<class 'float'>

Tensor中只有一个数字时,使用torch.Tensor.item()可以得到一个Python数字。requires_gradTrue时,表示需要计算Tensor的梯度。requires_grad=False可以用来冻结部分网络,只更新另一部分网络的参数。

  • 示例二
代码语言:javascript
复制
>>> a = torch.tensor([1.0, 2.0])
>>> b = a.data
>>> id(b)
139808984381768
>>> id(a)
139811772112328
>>> b.grad
>>> a.grad
>>> b[0] = 5.0
>>> b
tensor([5., 2.])
>>> a
tensor([5., 2.])

a.data返回的是一个新的Tensor对象ba, bid不同,说明二者不是同一个Tensor,但ba共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b的元素时,a的元素也对应修改。

2. requires_grad_()与detach()

代码语言:javascript
复制
>>> a = torch.tensor([1.0, 2.0])
>>> a.data
tensor([1., 2.])
>>> a.grad
>>> a.requires_grad
False
>>> a.requires_grad_()
tensor([1., 2.], requires_grad=True)
>>> c = a.pow(2).sum()
>>> c.backward()
>>> a.grad
tensor([2., 4.])
>>> b = a.detach()
>>> b.grad
>>> b.requires_grad
False
>>> b
tensor([1., 2.])
>>> b[0] = 6
>>> b
tensor([6., 2.])
>>> a
tensor([6., 2.], requires_grad=True)
  • requires_grad_()

requires_grad_()函数会改变Tensorrequires_grad属性并返回Tensor,修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=Truerequires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。

  • detach()

detach()函数会返回一个新的Tensor对象b,并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。ba共享数据的存储空间,二者指向同一块内存。

:共享内存空间只是共享的数据部分,a.gradb.grad是不同的。

3. torch.no_grad()

torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。

代码语言:javascript
复制
>>> a = torch.tensor([1.0, 2.0], requires_grad=True)
>>> with torch.no_grad():
...     b = n.pow(2).sum()
...
>>> b
tensor(5.)
>>> b.requires_grad
False
>>> c = a.pow(2).sum()
>>> c.requires_grad
True

上面的例子中,当arequires_grad=True时,不使用torch.no_grad()c.requires_gradTrue,使用torch.no_grad()时,b.requires_gradFalse,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True会占用更多的计算资源及存储资源。

4. 总结

requires_grad_()会修改Tensorrequires_grad属性。

detach()会返回一个与计算图分离的新Tensor,新Tensor不会在反向传播中计算梯度,会在特定场合使用。

torch.no_grad()更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。

References

  1. https://pytorch.org/docs/stable/tensors.html
  2. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.requires_grad_
  3. https://pytorch.org/docs/stable/autograd.html#torch.Tensor.detach
  4. https://pytorch.org/docs/master/generated/torch.no_grad.html
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/06/01 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 0. 测试环境
  • 1. 基本概念
  • 2. requires_grad_()与detach()
  • 3. torch.no_grad()
  • 4. 总结
  • References
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档