本部分介绍如何采用ResNet解决CIFAR10分类问题。
之前讲到过,ResNet包含了短接模块(short cut)。本节主要介绍如何实现这个模块。
先建立resnet.py文件。
如图
先引入相关包
import torch
import torch.nn as nn
准备构建resnet单元
class ResBlk(nn.Module):
# 与上节一样,同样resnet的block单元,继承nn模块
def __init__(self):
super(ResBlk, self).__init__()
# 完成初始化
由ResNet特点可知,需要传入channel_in和channel_out才能进行运算,因此在定义中需要加入两个变量。
def __init__(self, ch_in, ch_out):
接下来像之前一样,写入其原先的卷积层。
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
# 进行正则化处理,以使train过程更快更稳定
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
Resnet 模块的左侧的部分写好了,
先不急着写右侧,先写左侧的forward代码
先引入工具包
import torch.nn.functional as F
书写代码
def forward(self, x):
# 这里输入的是[b, ch, h, w]
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
下面开始写short cut代码
out = x + out
# 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加
同时要考虑,若两元素中的ch_in和ch_out不匹配,则运行时会报错。因此需要在前面指定添加if函数
if ch_out != ch_in:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
nn.BatchNorm2d(ch_out),
)
这段代码的意思即为实现[b, ch_in, h, w] => [b, ch_out, h, w]的转化
写好后,将element.wise add部分的x替换
out = self.extra(x) + out
这里也要考虑若ch_in和ch_out原先就相匹配的情况,则需要先进行定义。
self.extra = nn.Sequential()
最后在定义后,返回结果out
至此resnet block模块构建完毕
现代码为
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResBlk(nn.Module):
# 与上节一样,同样resnet的block单元,继承nn模块
def __init__(self, ch_in, ch_out):
super(ResBlk, self).__init__()
# 完成初始化
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
# 进行正则化处理,以使train过程更快更稳定
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
self.extra = nn.Sequential()
if ch_out != ch_in:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
nn.BatchNorm2d(ch_out),
)
def forward(self, x):
# 这里输入的是[b, ch, h, w]
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.extra(x) + out
# 这便是element.wise add,实现了[b, ch_in, h, w] 和 [b, ch_out, h, w]两个的相加
return out
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!