前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-ResNet网络构建(上)

CIFAR10数据集实战-ResNet网络构建(上)

作者头像
用户6719124
发布2020-01-14 10:51:06
1.1K0
发布2020-01-14 10:51:06
举报
文章被收录于专栏:python pytorch AI机器学习实践

本部分介绍如何采用ResNet解决CIFAR10分类问题。

之前讲到过,ResNet包含了短接模块(short cut)。本节主要介绍如何实现这个模块。

先建立resnet.py文件。

如图

先引入相关包

代码语言:javascript
复制
import torch
import torch.nn as nn

准备构建resnet单元

代码语言:javascript
复制
class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self):
        super(ResBlk, self).__init__()
        # 完成初始化

由ResNet特点可知,需要传入channel_in和channel_out才能进行运算,因此在定义中需要加入两个变量。

代码语言:javascript
复制
def __init__(self, ch_in, ch_out):

接下来像之前一样,写入其原先的卷积层。

代码语言:javascript
复制
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
# 进行正则化处理,以使train过程更快更稳定
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)

Resnet 模块的左侧的部分写好了,

先不急着写右侧,先写左侧的forward代码

先引入工具包

代码语言:javascript
复制
import torch.nn.functional as F

书写代码

代码语言:javascript
复制
def forward(self, x):
    # 这里输入的是[b, ch, h, w]
    out = F.relu(self.bn1(self.conv1(x)))
    out = F.relu(self.bn2(self.conv2(out)))

下面开始写short cut代码

代码语言:javascript
复制
out = x + out
# 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加

同时要考虑,若两元素中的ch_in和ch_out不匹配,则运行时会报错。因此需要在前面指定添加if函数

代码语言:javascript
复制
if ch_out != ch_in:
    self.extra = nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(ch_out),
    )

这段代码的意思即为实现[b, ch_in, h, w] => [b, ch_out, h, w]的转化

写好后,将element.wise add部分的x替换

代码语言:javascript
复制
out = self.extra(x) + out

这里也要考虑若ch_in和ch_out原先就相匹配的情况,则需要先进行定义。

代码语言:javascript
复制
self.extra = nn.Sequential()

最后在定义后,返回结果out

至此resnet block模块构建完毕

现代码为

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        # 完成初始化

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 进行正则化处理,以使train过程更快更稳定
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )



    def forward(self, x):
        # 这里输入的是[b, ch, h, w]
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))


        out = self.extra(x) + out
        # 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加

        return out
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-01-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档