前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >利用Tensorflow2.0实现手写数字识别

利用Tensorflow2.0实现手写数字识别

作者头像
用户7569543
发布2020-07-19 21:20:38
1.1K0
发布2020-07-19 21:20:38
举报
文章被收录于专栏:数据挖掘与AI算法

前面两节课我们已经简单了解了神经网络的前向传播和反向传播工作原理,并且尝试用numpy实现了第一个神经网络模型。手动实现(深度)神经网络模型听起来很牛逼,实际上却是一个费时费力的过程,特别是在神经网络层数很多的情况下,多达几十甚至上百层网络的时候我们就很难手动去实现了。这时候可能我们就需要更强大的深度学习框架来帮助我们快速实现深度神经网络模型,例如Tensorflow/Pytorch/Caffe等都是非常好的选择,而近期大热的keras是Tensorflow2.0版本中非常重要的高阶API,所以本节课老shi打算先给大家简单介绍下Tensorflow的基础知识,最后借助keras来实现一个非常经典的深度学习入门案例——手写数字识别。废话不多说,马上进入正题。

什么是Tensorflow

Tensorflow是谷歌2015年推出的一款深度学习框架,与Pytorch类似,都是目前比较热门的深度学习框架。但Tensorflow与传统的模型搭建方式不同,它是采用数据流图的方式来计算, 所以我们首先得创建一个数据流图,然后再将我们的数据(数据以张量tensor的形式存在)放到数据流图中去计算,节点Nodes在图中表示数学操作,图中的边edges则表示在节点间相互联系的多维数组, 即张量(tensor)。训练模型时tensor会不断地从数据流图中的一个节点flow到另一个节点, 这也是Tensorflow名字的由来。计算图Graph规定了各个变量之间的计算关系,建立好的计算图需要编译以确定其内部细节,而此时的计算图还是一个“空壳子”,里面并没有任何实际的数据,只有当你把需要运算的输入数据放进去后,才能在整个模型中形成数据流,从而得到模型的输出值。打个比方,就像用管道搭建的供水系统,当你在拼接水管的时候,水管里面其实是没有水的,只有等所有的管子都接好了,才能进行供水。具体如下图所示

Tensorflow中的基本概念

计算图(Graph):计算图描述了计算的过程,Tensorflow使用计算图来表示计算任务。

张量(Tensor):Tensorflow使用tensor表示数据。每个tensor是一个类型化的多维数组。规模最小的张量是0阶张量,即标量,也就是一个数;当我们把一些数有序地排列起来,就形成了1阶张量,也就是向量;如果我们继续把一组向量有序排列起来,就得到了一个2阶张量,也就是一个矩阵 ;把矩阵堆起来就是3阶张量,也就得到了一个立方体,我们常见的3通道(3色RGB)的彩色图片也是一个立方体;如果我们继续把立方体堆起来,就得到一个4阶的张量,以此类推。

操作(op):计算图中的节点被称为op(operation的缩写),即操作 op=节点Nodes;一个op获得0个或多个Tensor,执行计算后,就会产生0个或多个Tensor。

会话(Session):计算图必须在“会话”的上下文中执行。会话将计算图的op分发到如CPU或GPU之类的设备上执行。

变量(Variable):运行过程中可以被改变的量,用于维护状态。

Tensorflow2.0相比Tensorflow1.x版本的改进

1、支持tf.data加载数据,使用tf.data创建的输入管道读取训练数据,支持从内存(Numpy)方便地输入数据;

2、取消了会话Session,由静态计算图变成动态计算图,直接打印结果,不需要执行会话的过程;

3、使用tf.keras构建、训练和验证模型,或使用Premade来验证模型,可以直接标准的打包模型(逻辑回归、随机森林),也可以直接使用tf.estimator API 。如果不想从头训练模型,可以使用迁移学习来训练一个使用TensorflowHub模块的Keras或Estimator;

4、使用分发策略进行分发训练,分发策略API可以在不更改定义的情况下,轻松在不同的硬件配置上分发和训练模型,支持一系列的硬件加速器,例如GPU、TPU等;

5、使用SaveModel作为模型保存模块,更好对接线上部署。

最后,我们使用Tensorflow2.0高阶API keras来实现深度学习经典入门案例——手写数字识别,以下是案例代码,有兴趣的同学可以跟着实现一遍。下节课给大家带来卷积神经网络CNN,敬请期待!!

代码语言:javascript
复制
#coding:utf8import numpy as npnp.random.seed(123)#后面只使用keras.model搭建一个简单的全连接网络模型,不用tf.keras中的特性,在此直接用import keras也可以from tensorflow import kerasfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Dense,Activationfrom keras.optimizers import RMSprop
# 数据导入(x_train,y_train),(x_test,y_test) = mnist.load_data()print(x_train.shape,y_train.shape)print(x_test.shape,y_test.shape)

# 数据预处理x_train = x_train.reshape(x_train.shape[0],-1)  / 255.0x_test = x_test.reshape(x_test.shape[0],-1) / 255.0y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)

# 直接使用keras.Sequential()搭建全连接网络模型model = Sequential()model.add(Dense(128, input_shape=(784,)))model.add(Activation('relu'))model.add(Dense(10))model.add(Activation('softmax'))

#lr为学习率,epsilon防止出现0,rho/decay分别对应公式中的beta_1和beta_2rmsprop = RMSprop(lr=0.001,rho=0.9,epsilon=1e-08,decay=0.00001) model.compile(optimizer=rmsprop,loss='categorical_crossentropy',metrics=['accuracy'])print("---------------training--------------")model.fit(x_train,y_train,epochs=5,batch_size=32)print('\n')print("--------------testing----------------")loss,accuracy = model.evaluate(x_test,y_test)print('loss:',loss)print('accuracy:',accuracy)
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 多赞云数据 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档