前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MNIST__数字识别__SOFTMAX

MNIST__数字识别__SOFTMAX

原创
作者头像
代号404
修改2018-08-04 23:05:25
9180
修改2018-08-04 23:05:25
举报
文章被收录于专栏:Deep Learning 笔记

本次MNIST的手写数字识别未采用input_data.py文件,想尝试一下用原始的数据集来运行这个DEMO。

需要注意的一点是,源码中的图片标签采用的的ONE-HOT编码,而数据集中的标签用的是具体的数字。

例如:图片上的数字和标签的值是5,其对应的ONT-HOT编码为[0,0,0,0,0,1,0,0,0,0](分别对应数值【0,1,2,3,4,5,6,7,8,9】) ,也就是长度为10的一维数组的第6个元素为1,其余的全为0。

源码结构:

1.读取MNIST

2.创建占位符(用读取的数据填充这些空占位符)

3.选用交叉熵作为损失函数

4.使用梯度下降法(步长0.02),来使损失函数最小

5.初始化变量

6.开始计算

7.输出识别率

源码:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
import struct
#  解析IDX文件格式的MNIST数据集,需要用struct模块对二进制文件进行读取操作
#  struct模块中最重要的三个函数是pack() , unpack() 和calcsize()
#  calculate 英 [ˈkælkjuleɪt]  vt.计算  

#  按照给定的格式(fmt)解析字节流string,返回解析出来的tuple
#  tuple = unpack(fmt, string)
#  format  英 [ˈfɔ:mæt] 格式;使格式化 (format在代码中简化为fmt)
#  tuple 英 [tʌpl] 美 [tʌpl]   n.元组,数组

#  按照给定的格式化字符串,把数据封装成字符串(实际上是类似于c结构体的字节流)
#  string = struct.pack(fmt, v1, v2, ...)

#  计算给定的格式(fmt)占用多少字节的内存
#  offset = calcsize(fmt)

import matplotlib.pyplot as plt
#  matplotlib.pyplot是一个命令型函数集合,功能齐全的绘图模块

#------------------------------------ 1 ------------------------------------------
def images_load(filename):
    #   def image_load(filename)用于读取图片数据
    #   file_name表示要访问的文件名 
    
    with open(filename, 'rb') as contents:
    #rb表示该文件以只读方式打开,使用with open()as 的好处在于:读取文件内容后会
    #自动关闭文件,无需手动关闭。
    
        data_buffers = contents.read()
    #   从一个打开的文件读取数据 
    #   buffer   英 [ˈbʌfə(r)]  
    #   n. 缓冲器; 起缓冲作用的人或物; [化] 缓冲液,缓冲剂; [计] 缓冲区
    #   vt. 缓冲       (个人感觉,了解单词的意思,代码会变的亲切一些)      
   
        magic,num,rows,cols = struct.unpack_from('>IIII',data_buffers, 0)
    #   读取图片文件前4个整型数字  
    
        bits = num * rows * cols
    #  整个images数据大小为60000*28*28        
    
        images = struct.unpack_from('>' + str(bits) + 'B', data_buffers, struct.calcsize('>IIII'))
    #   读取images数据
    
        images = np.reshape(images, [num, rows * cols])
    #   转换为[60000,784]型数组
    
    return images
#--------------------------------------- 2 --------------------------------------
def labels_load(filename):
    
    contants = open(filename, 'rb')
    #这里用open()打开文件,读取结束后要用close()关闭    
    
    data_buffers = contants.read()
    
    magic,num = struct.unpack_from('>II', data_buffers, 0) 
    #   读取label文件前2个整形数字,label的长度为num
    #   magic翻译成“魔数”,用于校验下载的文件是否属于MNIST数据集
   
    labels = struct.unpack_from('>' + str(num) + "B", data_buffers, struct.calcsize('>II'))
    #  读取labels数据
        
    contants.close()
    #  关闭文件    
    
    labels = np.reshape(labels, [num])
    #  转换为一维数组
    
    return labels   
#---------------------------------------- 3 ----------------------------------------
#读取训练和测试文件
filename_train_images = 'E:\\MNIST\\train-images.idx3-ubyte'
filename_train_labels = 'E:\\MNIST\\train-labels.idx1-ubyte'
filename_test_images = 'E:\\MNIST\\t10k-images.idx3-ubyte'
filename_test_labels = 'E:\\MNIST\\t10k-labels.idx1-ubyte'
train_images=images_load(filename_train_images)
train_labels=labels_load(filename_train_labels)
test_images=images_load(filename_test_images)
test_labels=labels_load(filename_test_labels)

#------------------------------------- 4 ------------------------------------------

x = tf.placeholder("float", [None, 784]) #输入占位符(每张手写数字有28X28个像素点)
y_ = tf.placeholder("float", [None,10]) #输入占位符(用one-hot编码表示标签的值)

w = tf.Variable(tf.zeros([784,10])) #权重
b = tf.Variable(tf.zeros([10])) #偏置
y = tf.nn.softmax(tf.matmul(x,w) + b) 
# 输入矩阵x与权重矩阵w相乘,加上偏置矩阵b,然后求softmax(sigmoid函数升级版,可以分成多类)
# softmax会将xW+b分成10类,对应数字0-9

cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 计算交叉熵

train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)
# 使用梯度下降法(步长0.02),来使偏差和最小

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 初始化变量

def train_num(n_t):
    xst=train_images[:n_t,:]  
    zst=train_labels[:n_t]  
    yst=np.zeros((n_t,10))
    for i in range(0,n_t-1):
        yst[i][zst[i]]=1
    return xst,yst
#训练图片的数量,标签转换为ONE-HOT编码

def test_num(n_t):
    xst=test_images[:n_t,:]  
    zst=test_labels[:n_t]  
    yst=np.zeros((n_t,10))
    for i in range(0,n_t-1):
        yst[i][zst[i]]=1
    return xst,yst
#测试图片的数量,标签转换为ONE-HOT编码
#======================================= 5 ===========================
xs_t,ys_t=test_num(10000)  
#测试图片10000张

xs,ys=train_num(1300)
#用1300张图片进行训练

sess.run(train_step, feed_dict={x:xs,y_:ys })  
correct_prediction_1 = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy_1 = tf.reduce_mean(tf.cast(correct_prediction_1, "float"))
# 计算训练精度  

print(sess.run(accuracy_1, feed_dict={x: xs_t, y_: ys_t})) 
#输出识别的准确率

#=========================================================

print('GOOD WORK')
#  点个赞
代码语言:javascript
复制
0.7147
GOOD WORK
#运行结果 0.7147,看起来很糟糕……

将训练数据的值由1300提高到60000,结果是0.6803,居然降低了。好吧,总感觉哪里不太对,可又说不上来~

参考资料:

ONE-HOT使用体会 : https://blog.csdn.net/lanhaier0591/article/details/78702558

训练Tensorflow识别手写数字 : https://www.cnblogs.com/tengge/p/6363586.html

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 源码结构:
  • 源码:
  • 参考资料:
相关产品与服务
内容识别
内容识别(Content Recognition,CR)是腾讯云数据万象推出的对图片内容进行识别、理解的服务,集成腾讯云 AI 的多种强大功能,对存储在腾讯云对象存储 COS 的数据提供图片标签、图片修复、二维码识别、语音识别、质量评估等增值服务。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档