前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【动手学深度学习笔记】之自定义层

【动手学深度学习笔记】之自定义层

作者头像
树枝990
发布2020-08-19 09:36:06
4280
发布2020-08-19 09:36:06
举报
文章被收录于专栏:拇指笔记

1.自定义层

神经网络中存在着全连接层、卷积层、池化层和循环层等各种各样的层。虽然PyTorch提供了大量常用的层,但有时还是需要我们自定义层。本篇文章介绍如何使用Module类来自定义层。

1.1 不含模型参数的自定义层

下面以实例介绍一下通过继承Module类定义不含模型参数的自定义层。

代码语言:javascript
复制
class layer(nn.Module):
    def __init__(self,**keywargs):
        #直接继承Module的__init__()
        super(layer,self).__init__(**keywargs)
    def forward(self,x):
        #定义前向传导
	return x-x.mean()

定义的这个层并没有模型参数。实例化例子如下

代码语言:javascript
复制
layer = layer()
layer(torch.tensor([1,2,3,4,5],dtype=torch.float))

Out[1]:tensor([-2., -1.,  0.,  1.,  2.])

同样可以使用Sequential类将这个层添加到网络。

代码语言:javascript
复制
net = nn.Sequential(net.Linear(8,128),layer())

1.2 含模型参数的自定义层

为自定义层添加模型参数有以下三种方式。

使用Parameter类

上一篇文章介绍过,当一个Tensor类型为Parameters时,它将会被自动添加到参数列表中。

代码语言:javascript
复制
class net1(nn.Module):
    def __init__(self):
        super(net1,self).__init__()
        self.weight = nn.Parameter(torch.rand(4,4))
        self.bais = nn.Parameter(torch.rand(4,1))
    def forward(self,x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net1 = net1()
for name,param in net1.named_parameters():
	print(name,param)
    
Out[1]:
    
weight Parameter containing:
tensor([[0.3217, 0.8082, 0.2425, 0.3970],
        [0.6009, 0.2262, 0.7150, 0.6720],
        [0.4062, 0.6335, 0.6234, 0.2680],
        [0.1824, 0.0825, 0.8183, 0.2564]], requires_grad=True)
bais Parameter containing:
tensor([[0.3952],
        [0.4866],
        [0.9082],
        [0.1949]], requires_grad=True)

使用ParameterList类

ParameterList类接收Parameters实例的列表作为输入然后得到一个参数列表,与List类似,可以使用索引访问,append添加。

代码语言:javascript
复制
class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))
    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net = MyDense()
print(net)

Out[1]:

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)

使用ParameterDict类

ParameterDict类接收一个Parameter实例的字典作为输入,返回一个参数字典,同样可以使用updata()添加参数,使用key()返回所有键值,使用item()返回所有键值对等字典操作。

代码语言:javascript
复制
class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
                'linear1': nn.Parameter(torch.randn(4, 4)),
                'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)

Out[1]:
    
MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)

使用ParameterDict类,可以通过选择不同的键,来进行不同的正向传播。

代码语言:javascript
复制
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

Out[1]:

tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)
tensor([[-0.8783]], grad_fn=<MmBackward>)
tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)

上述这些中方法创建的层,都可以像PyTorch中其他层一样,通过Sequential类、ModuleList类和ModuleDict类等方法构造模型。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-03-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 拇指笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.自定义层
    • 1.1 不含模型参数的自定义层
      • 1.2 含模型参数的自定义层
        • 使用Parameter类
        • 使用ParameterList类
        • 使用ParameterDict类
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档