Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >ResNet18代码实现[通俗易懂]

ResNet18代码实现[通俗易懂]

作者头像
全栈程序员站长
发布于 2022-08-24 12:34:42
发布于 2022-08-24 12:34:42
1K0
举报

大家好,又见面了,我是你们的朋友全栈君。

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers, Sequential, Model, datasets, optimizers

# 自定义的预处理函数

def preprocess(x, y):

# 调用此函数时会自动传入x,y对象,shape为[b,28,28],[b]

# 标准化到0-1

x = 2*tf.cast(x, dtype=tf.float32) / 255.-1

# 转成整型张量

y = tf.cast(y, dtype=tf.int32)

# 返回的x,y将替换传入的x,y参数,从而实现数据的预处理功能

return x, y

# 在线下载,加载CIFAR10数据集

(x,y),(x_test,y_test)= datasets.cifar10.load_data()

# 删除y的一个不必要的维度,[b,1] → [b]

y= tf.squeeze(y,axis= 1)

y_test= tf.squeeze(y_test, axis= 1)

# 打印训练集和测试集的形状

# print(x.shape,y.shape, x_test.shape, y_test.shape)

# 构建训练集对象,随机打乱,预处理,批量化

train_db= tf.data.Dataset.from_tensor_slices((x,y))

train_db= train_db.shuffle(1000).map(preprocess).batch(512)

# 构建测试集对象,预处理,批量化

test_db= tf.data.Dataset.from_tensor_slices((x_test,y_test))

test_db= test_db.map(preprocess).batch(512)

# 从训练集中采样一个Batch,并观察

sample= next(iter(train_db))

# print(‘sample:’,sample[0].shape,sample[1].shape,tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))

class BasicBlock(layers.Layer):

# 残差模块

def __init__(self, filter_num, stride= 1):

super(BasicBlock, self).__init__()

#第一个卷积单元

self.conv1= layers.Conv2D(filter_num, kernel_size=(3,3), strides= stride, padding= ‘same’)

self.bn1= layers.BatchNormalization()

self.relu= layers.Activation(‘relu’)

# 第二个卷积单元

self.conv2= layers.Conv2D(filter_num, kernel_size=(3,3), strides= 1, padding= ‘same’ )

self.bn2= layers.BatchNormalization()

# 通过1*1卷积完成shape匹配

if stride != 1:

self.downsample= Sequential()

self.downsample.add(layers.Conv2D(filter_num, kernel_size= (1,1), strides= stride))

else: # shape匹配,直接短接

self.downsample= lambda x:x

def call(self, inputs, training= None):

# 前向计算函数

# [b,h,w,c], 通过第一个卷积单元

out= self.conv1(inputs)

out= self.bn1(out)

out= self.relu(out)

# 通过第二个卷积单元

out= self.conv2(out)

out= self.bn2(out)

# 通过identity模块

identity= self.downsample(inputs)

# 两条路径输出直接相加

output= layers.add([out,identity])

output= tf.nn.relu(output)

return output

class ResNet(Model):

def __init__(self, layer_dims, num_classes= 10): #[2,2,2,2]

super(ResNet, self).__init__()

# 根网络,预处理

self.stem= Sequential([

layers.Conv2D(64, kernel_size= (3,3), strides= (1,1)),

layers.BatchNormalization(),

layers.Activation(‘relu’),

layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding= ‘same’)

])

# 堆叠4个Block,每个Block包含了多个BasicBlock,设置步长不一样

self.layer1= self.build_resblock(64, layer_dims[0])

self.layer2= self.build_resblock(128, layer_dims[1], stride= 2)

self.layer3= self.build_resblock(256, layer_dims[2], stride= 2)

self.layer4= self.build_resblock(512, layer_dims[3], stride= 2)

# 通过Pooling层将高宽降低为1*1

self.avgpool= layers.GlobalAveragePooling2D()

# 最后连接一个全连接层分类

self.fc= layers.Dense(num_classes)

def build_resblock(self, filter_num, blocks, stride= 1):

# 辅助函数,堆叠filter_num个BasicBlock

res_blocks= Sequential()

# 只有第一个BasicBlock的步长可能不为1, 实现下采样

res_blocks.add(BasicBlock(filter_num, stride))

# 其他BasicBlock步长都为1

for _ in range(1, blocks):

res_blocks.add(BasicBlock(filter_num, stride= 1))

return res_blocks

def call(self, inputs, training= None):

# 前向计算函数:通过根网络

x= self.stem(inputs)

# 一次通过4个模块

x= self.layer1(x)

x= self.layer2(x)

x= self.layer3(x)

x= self.layer4(x)

# 通过池化层

x= self.avgpool(x)

# 通过全连接层

x= self.fc(x)

return x

def resnet18():

# 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet

return ResNet([2,2,2,2])

# def resnet34():

# # 通过调整模块内部BasicBlock的数量和配置实现不同的ResNet

# return ResNet([3,4,6,3])

model = resnet18() # ResNet18网络

model.build(input_shape=(None, 32, 32, 3))

# model.summary() # 统计网络参数

def main():

optimizer = optimizers.Adam(learning_rate=1e-4)

for epoch in range(10):

for step, (x,y) in enumerate(train_db):

with tf.GradientTape() as tape:

# [b, 32, 32, 3] => [b, 1, 1, 512]

logits= model(x)

y_onehot = tf.one_hot(y, depth=10)

# compute loss

loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)

loss = tf.reduce_mean(loss)

# 对所有参数求梯度

grads= tape.gradient(loss, model.trainable_variables)

# 自动更新

optimizer.apply_gradients(zip(grads,model.trainable_variables))

if step %10 == 0:

print(epoch, step, ‘loss:’, float(loss))

total_num = 0

total_correct = 0

for x,y in test_db:

# out = model(x)

# out = tf.reshape(out, [-1, 512])

logits = model(x)

prob = tf.nn.softmax(logits, axis=1)

pred = tf.argmax(prob, axis=1)

pred = tf.cast(pred, dtype=tf.int32)

correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)

correct = tf.reduce_sum(correct)

total_num += x.shape[0]

total_correct += int(correct)

acc = total_correct / total_num

print(epoch, ‘acc:’, acc)

if __name__ == ‘__main__’:

main()

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141598.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年5月9,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
前端架构师之10_JavaScript_DOM
第1级DOM(DOM Level 1,或DOM1)。为XML和HTML文档中的元素、节点、属性等提供了必备的属性和方法。结合了Netscape及微软公司开发的DHTML(动态HTML)思想。
张哥编程
2024/12/13
1790
JS之文档对象模型DOM
<html> <head> <meta http-equiv="Content-Type" content="text/html; charset=gbk"> <title>History和Location使用</title> </head> <body> <input type="button" value="返回" onclick="history.back();" /> </body> </html> DOM 解析模型,将文档加载到 内存,形成一个树形结构 <html> 就是根节点,每个标签会成为
Java帮帮
2018/03/19
3.4K0
JS之文档对象模型DOM
JavaScript DOM基础
DOM(Document Object Model)即文档对象模型,针对HTML和XML文档的API(应用程序接口)。 一.DOM介绍 DOM中的三个字母,D(文档)可以理解为整个Web加载的网页文档;O(对象)可以理解为类似window对象之类的东西,可以调用属性和方法,这里我们说的是document对象;M(模型)可以理解为网页文档的树型结构。 DOM有三个等级,分别是DOM1、DOM2、DOM3,并且DOM1在1998年10月成为W3C标准。DOM1所支持的浏览器包括IE6+、Firefox、Safa
汤高
2018/01/11
1.4K0
HTML DOM(二):节点的增删改查
       上一篇:HTML DOM(一)        上一篇讲述了DOM的基本知识,从其得知,在DOM眼中,HTML的每个成分都可以看作是节点(文档节点、元素节点、文本节点、属性节点、注释节点,
高爽
2017/12/28
1.7K0
关于DOM的理解
当创建了一个网页并把它加载到web浏览器中时,DOM就悄然而生。浏览器根据网页文档创建一个文档对象。
Tz一号
2020/09/10
1K0
前端之BOM和DOM
BOM(Browser Object Model)浏览器对象模型,它使得JS能够与浏览器进行‘对话’(交互,通过JS对页面内容进行操作)。
GH
2019/12/16
2.8K0
前端之BOM和DOM
前端day13-JS(WebApi)学习笔记(attribute语法、DOM节点操作)
小技巧:如果API写的是Emement复数的形式,也就是后面加了s(Emements)那么它返回的就是一个伪数组 否则就是单个对象,一般只有id才会是单个对象,其他方式获取(标签名 类名)都是伪数组.
帅的一麻皮
2020/04/19
3.1K0
前端day13-JS(WebApi)学习笔记(attribute语法、DOM节点操作)
JavaScript 编程精解 中文第三版 十四、文档对象模型
当你在浏览器中打开网页时,浏览器会接收网页的 HTML 文本并进行解析,其解析方式与第 11 章中介绍的解析器非常相似。浏览器构建文档结构的模型,并使用该模型在屏幕上绘制页面。
ApacheCN_飞龙
2022/12/01
1.5K0
JavaScript 编程精解 中文第三版 十四、文档对象模型
从零开始学习BOM&amp;DOM
ECMAScript,描述了该语言的语法和基本对象,如类型、运算、流程控制、面向对象、异常等。
虎妞先生
2022/09/19
6140
从零开始学习BOM&amp;DOM
前端之HTML DOM操作
当网页被加载时,浏览器会创建页面的文档对象模型(Document Object Model)。
山河木马
2019/03/05
6090
JavaWeb——JavaScript精讲之DOM、BOM对象与案例实战(动态添加删除表格)
上一博文种讲解了JavaScript基础的ECMAScript,包括基本语法和部分对象,本文中继续讲解JavaScript中比较重要的两部分内容BOM、DOM及事件,后文中有对应的实战练习。
Winter_world
2020/09/25
2.3K0
JavaWeb——JavaScript精讲之DOM、BOM对象与案例实战(动态添加删除表格)
jQuery文档对象模型DOM的实际应用
DOM 在 JavaScript 课程中我们详细的探讨过,它是一种文档对象模型。方便开发者对 HTML 结构元素内容进行展示和修改。在 JavaScript 中,DOM 不但内容庞大繁杂,而且我们开发的过程中需要考虑更多的兼容性、扩展性。
王小婷
2018/12/19
1.2K0
DOM 文档对象模型
HTML 模板<html> <head> <title>我是网站标题</title> </head> <body> <div class="box"> <div class="box1"></div> </div> <div name="xiaoming"></div> <div id="box"></div> </body></html>访问节点通过 id 访问指定节点 getElement
菜园前端
2023/05/10
5160
javascript之DOM操作
http://www.cnblogs.com/kissdodog/archive/2012/12/25/2833213.html
bear_fish
2018/09/19
5560
js 深度解析DOM
因为document是window的一个属性,因为属性都是对象拥有的,所以他是一个object;
贵哥的编程之路
2020/11/03
5.1K0
js 深度解析DOM
E006Web学习笔记-JavaScript(四):DOM
将标记语言文档的各个部分,封装为对象,可以使用这些对象,对标记语言文档进行CRUD(增删改查)的动态操作;
訾博ZiBo
2025/01/06
890
E006Web学习笔记-JavaScript(四):DOM
3-DOM
将标记语言文档(HTML,XML…)的各个部分,封装为对象,可以使用这些对象,对标记语言文档进行CRUD动态操作
Ywrby
2022/10/27
1.4K0
第85节:Java中的JavaScript
后代选择器: 选择器1 选择器2 子元素选择器:选择器1 > 选择器2 选择器分组: 选择器1,选择器2,选择器3{} 属性选择器:选择器[属性名称='属性值']
达达前端
2019/07/03
2.7K0
第85节:Java中的JavaScript
【Java 进阶篇】深入理解 JavaScript DOM Node 对象
在前端开发中,与HTML文档进行交互是一项基本任务。文档对象模型(Document Object Model,简称DOM)为开发者提供了一种以编程方式访问和操作HTML文档的方式。DOM的核心是节点(Node)对象,它代表了文档中的各个部分。本博客将深入探讨JavaScript DOM Node对象,帮助您更好地理解它的作用和如何使用。
繁依Fanyi
2023/10/19
3700
Js DOM
要创建新的 HTML 元素 (节点)需要先创建一个元素,然后在已存在的元素中添加它。
hss
2022/02/25
3.9K0
相关推荐
前端架构师之10_JavaScript_DOM
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档