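A while ago I wrote a series of introductory machine-learning posts. This time I am covering introductory deep-learning datasets, starting with CIFAR-10, which is mainly used for image recognition. The dataset contains images and labels. Each image is 32×32 pixels, and the labels cover 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). The idea is to train a deep-learning model on this labeled data so that it can recognize the object in an image. For example, such a neural network can tell cats from dogs.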
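The official site offers the dataset in several formats; we use the binary (bin) version. First, let's look at the first 25 records. Because the images are only 32×32 pixels, many of them are hard to identify even for a human.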
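In the binary version, each record is 1 label byte followed by 3072 pixel bytes: all 1024 red values first, then green, then blue, each plane stored row by row. The label byte is an index into the class list above. Here is a minimal sketch that decodes one record; the helper `decode_record` and the synthetic test record are illustrative assumptions, not part of the dataset tools.

```python
import numpy as np

RECORD_BYTES = 1 + 32 * 32 * 3  # 1 label byte + 3072 pixel bytes
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def decode_record(record: bytes):
    """Split one CIFAR-10 binary record into (label, 32x32x3 uint8 image)."""
    arr = np.frombuffer(record, dtype=np.uint8)
    if arr.size != RECORD_BYTES:
        raise ValueError("record must be exactly 3073 bytes")
    label = int(arr[0])
    # Pixels are channel-first: 1024 R, then 1024 G, then 1024 B
    chw = arr[1:].reshape(3, 32, 32)
    # Move the channel axis last so matplotlib's imshow can display it
    hwc = chw.transpose(1, 2, 0)
    return label, hwc


# Synthetic record: label 3, red plane = 10, green plane = 20, blue plane = 30
record = bytes([3]) + bytes([10] * 1024) + bytes([20] * 1024) + bytes([30] * 1024)
label, img = decode_record(record)
print(CLASSES[label], img.shape, img[0, 0])  # cat (32, 32, 3) [10 20 30]
```

Reshaping to (3, 32, 32) and then transposing, rather than reshaping straight to (32, 32, 3), is what keeps the three colour planes from being interleaved incorrectly.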
Figure: the first 25 records of CIFAR-10
Code:
import numpy as np
import matplotlib.pyplot as plt

filename = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/test_batch.bin'
label_meta = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/batches.meta.txt'
# Class names, one per line, in label-index order
with open(label_meta, "r") as f:
    labels_txt = f.read().strip().split("\n")

# Each record: 1 label byte followed by 32*32*3 = 3072 pixel bytes
record_bytes = 1 + 32 * 32 * 3
with open(filename, "rb") as bytestream:
    buf = bytestream.read(25 * record_bytes)
data = np.frombuffer(buf, dtype=np.uint8).reshape(25, record_bytes)
labels_images = np.hsplit(data, [1])
labels = labels_images[0].reshape(25)
# Pixels are stored channel-first (all R, then all G, then all B)
images = labels_images[1].reshape(25, 3, 32, 32)

fig, axes1 = plt.subplots(5, 5, figsize=(4, 5))
# for itr, label in enumerate(labels):
#     print(itr, ":", labels_txt[label])
i = 0
for j in range(5):
    for k in range(5):
        # Convert (3, 32, 32) to (32, 32, 3) for imshow
        img = images[i].transpose(1, 2, 0)
        axes1[j][k].set_axis_off()
        axes1[j][k].imshow(img)
        axes1[j][k].set_title(labels_txt[labels[i]])
        i = i + 1
plt.show()
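The official TensorFlow tutorials include an example program that trains on CIFAR-10. Code download: https://github.com/tensorflow/models (the code is in models/tutorials/image/cifar10/). Start training with:
python cifar10_train.py
If the dataset has not been downloaded yet, it is downloaded first. The output looks like this: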
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
2019-02-20 13:42:05.167927: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-02-20 13:42:09.260566: step 0, loss = 4.67 (304.9 examples/sec; 0.420 sec/batch)
2019-02-20 13:42:13.762996: step 10, loss = 4.63 (284.3 examples/sec; 0.450 sec/batch)
2019-02-20 13:42:18.095651: step 20, loss = 4.49 (295.4 examples/sec; 0.433 sec/batch)
2019-02-20 13:42:22.444906: step 30, loss = 4.50 (294.3 examples/sec; 0.435 sec/batch)
2019-02-20 13:42:27.136578: step 40, loss = 4.40 (272.8 examples/sec; 0.469 sec/batch)
2019-02-20 13:42:31.833072: step 50, loss = 4.32 (272.5 examples/sec; 0.470 sec/batch)
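In these log lines, examples/sec is the batch size divided by the time per batch. A small sketch that parses one line and checks this, assuming the batch size of 128 given in the tutorial's notes; the regular expression and `parse_step` are illustrative helpers, not part of the tutorial.

```python
import re

BATCH_SIZE = 128  # tutorial's batch size

LOG_LINE = re.compile(
    r"step (\d+), loss = ([\d.]+) \(([\d.]+) examples/sec; ([\d.]+) sec/batch\)"
)


def parse_step(line: str):
    """Extract (step, loss, examples_per_sec, sec_per_batch) from a training log line."""
    m = LOG_LINE.search(line)
    if m is None:
        return None
    step, loss, eps, spb = m.groups()
    return int(step), float(loss), float(eps), float(spb)


line = ("2019-02-20 13:42:09.260566: step 0, loss = 4.67 "
        "(304.9 examples/sec; 0.420 sec/batch)")
step, loss, eps, spb = parse_step(line)
# batch_size / sec_per_batch matches examples/sec up to rounding of the printed sec/batch
print(step, loss, round(BATCH_SIZE / spb, 1), eps)
```

The small gap between the computed and logged values comes from sec/batch being printed with only three decimals.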
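The official training figures are below. My machine, a 2018 MacBook Air with a dual-core i7, runs at almost the same step time as a Tesla K20m (about 0.42 to 0.47 sec/batch versus 0.35 to 0.60).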
A binary to train CIFAR-10 using a single GPU.
Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.
Speed: With batch_size 128.
System | Step Time (sec/batch) | Accuracy
------------------------------------------------------------------
1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
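The numbers in these notes are consistent with each other: CIFAR-10 has 50,000 training images, so 100K steps of 128 images cover 256 epochs. The same arithmetic gives a rough wall-clock estimate for 100K steps at the 0.42 to 0.47 sec/batch seen in the log above. This is a back-of-the-envelope sketch that assumes a constant step time and ignores checkpointing overhead.

```python
TRAIN_IMAGES = 50_000  # CIFAR-10 training-set size
BATCH_SIZE = 128
STEPS = 100_000


def epochs(steps: int, batch_size: int = BATCH_SIZE, n: int = TRAIN_IMAGES) -> float:
    """Number of passes over the training set covered by `steps` batches."""
    return steps * batch_size / n


def hours(steps: int, sec_per_batch: float) -> float:
    """Wall-clock training time in hours at a constant step time."""
    return steps * sec_per_batch / 3600


print(epochs(STEPS))                 # 256.0
print(round(hours(STEPS, 0.42), 1))  # 11.7
print(round(hours(STEPS, 0.47), 1))  # 13.1
```

So reaching the 100K-step accuracy on this laptop would take roughly half a day of training.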
When training finishes, you can run the evaluation script, which makes predictions on the 10,000 test images and reports the accuracy.
python cifar10_eval.py
With training set to 1,000 steps, the accuracy was about 60%:
2019-02-20 15:59:41.109588: precision @ 1 = 0.606
In my experiments, the accuracy reached 86% after 100K training steps.
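The "precision @ 1" printed by cifar10_eval.py is the fraction of test images whose highest-scoring class equals the true label (the script checks this with tf.nn.in_top_k). Here is a NumPy sketch of the same metric, generalized to top-k; the function name and the toy logits are illustrative only.

```python
import numpy as np


def precision_at_k(logits: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Fraction of rows whose true label is among the k highest logits."""
    # Indices of the k largest logits in each row (their order does not matter here)
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = np.any(top_k == labels[:, None], axis=1)
    return float(hits.mean())


# Toy example: 4 images, 3 classes
logits = np.array([
    [2.0, 0.5, 0.1],  # predicts class 0
    [0.2, 1.5, 0.3],  # predicts class 1
    [0.9, 0.1, 0.8],  # predicts class 0 (class 2 is second)
    [0.1, 0.2, 3.0],  # predicts class 2
])
labels = np.array([0, 1, 2, 0])
print(precision_at_k(logits, labels, k=1))  # 0.5
print(precision_at_k(logits, labels, k=2))  # 0.75
```

A reported value of 0.606 therefore means about 6,060 of the 10,000 test images had the correct class as the top prediction.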
Test code
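Predictions on large images are poor: the image has to be downscaled to under 50 px with a good resampling algorithm first, and in my tests the prediction accuracy was below 50%.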
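One plausible reason large photos fare badly is that plain bilinear resizing (the `ResizeMethod.BILINEAR` used in the test code) reads only a 2×2 neighbourhood per output pixel, so most of a large image is skipped and fine detail aliases. Averaging whole blocks of pixels (area or box-filter downsampling) is a simple alternative. This NumPy sketch assumes the image dimensions are integer multiples of the target size; `box_downsample` is a hypothetical helper, not part of the tutorial code.

```python
import numpy as np


def box_downsample(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Shrink an HxWxC image by averaging non-overlapping pixel blocks.

    Assumes H and W are integer multiples of out_h and out_w.
    """
    h, w, c = img.shape
    if h % out_h or w % out_w:
        raise ValueError("image size must be a multiple of the output size")
    fh, fw = h // out_h, w // out_w
    # Group pixels into (out_h, fh, out_w, fw, c) blocks and average each block
    blocks = img.reshape(out_h, fh, out_w, fw, c).astype(np.float64)
    return blocks.mean(axis=(1, 3))


# 240x240 black/white checkerboard: every 10x10 block holds 50 black and
# 50 white pixels, so each output pixel averages to 127.5
y, x = np.indices((240, 240))
board = (((x + y) % 2) * 255).astype(np.uint8)[..., None].repeat(3, axis=2)
small = box_downsample(board, 24, 24)
print(small.shape, small.min(), small.max())  # (24, 24, 3) 127.5 127.5
```

The result is float64; scale or cast it as your model's input pipeline expects.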
# -*- coding:utf-8 -*-
import tensorflow as tf  # TensorFlow 1.x, as used by the tutorial
from prettytable import PrettyTable
import cifar10  # from models/tutorials/image/cifar10/
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

FLAGS = tf.app.flags.FLAGS
# Directory holding the trained model checkpoints
tf.app.flags.DEFINE_string('checkpoint_dir', '/Users/wangsen/ai/13/models-master/tutorials/image/cifar10/cifar10_train',
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_string('class_dir', '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/',
                           """Directory containing batches.meta.txt""")
tf.app.flags.DEFINE_string('test_file', '/Users/wangsen/Desktop/1.jpeg', """Image used for testing""")
IMAGE_SIZE = 24  # cifar10.inference expects 24x24 inputs


def evaluate_images(images):  # run the prediction
    logits = cifar10.inference(images)
    load_trained_model(logits=logits)


def load_trained_model(logits):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the trained weights from the checkpoint
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found')
            return
        # Read the 10 class names as strings, one per line
        cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t')
        # Turn logits into probabilities, then take the 3 most likely classes
        top_k_pred = tf.nn.top_k(tf.nn.softmax(logits), k=3)
        output = sess.run(top_k_pred)
        probability = np.array(output[0]).flatten()  # flatten the probabilities to 1-D
        index = np.array(output[1]).flatten()  # matching class indices
        # Display the results as a table
        table = PrettyTable(["index", "class", "probability"])
        table.align["index"] = "l"
        table.padding_width = 1
        for i in np.arange(index.size):
            table.add_row([index[i], cifar10_class[index[i]], probability[i]])
        print(table)
        img = mpimg.imread(FLAGS.test_file)  # read the test image
        plt.imshow(img)  # show the image
        plt.axis('off')  # hide the axes
        plt.show()


def img_read(filename):
    if not tf.gfile.Exists(filename):
        tf.logging.fatal('File does not exist: %s', filename)
    image_data = tf.image.convert_image_dtype(
        tf.image.decode_jpeg(tf.read_file(filename), channels=3), dtype=tf.float32)
    image = tf.image.resize_images(image_data, (IMAGE_SIZE, IMAGE_SIZE),
                                   method=tf.image.ResizeMethod.BILINEAR)
    # The tutorial's input pipeline standardizes every image, so do the same here
    image = tf.image.per_image_standardization(image)
    # Add a batch dimension: (24, 24, 3) -> (1, 24, 24, 3)
    image = tf.expand_dims(image, 0)
    return image


def main(argv=None):  # pylint: disable=unused-argument
    filename = FLAGS.test_file
    images = img_read(filename)
    evaluate_images(images)


if __name__ == '__main__':
    tf.app.run()
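The `probability` column in the table only holds probabilities if the network's raw logits are passed through softmax before `tf.nn.top_k`; otherwise it shows unnormalized scores that do not sum to 1. Here is a NumPy sketch of that step (softmax, then the three largest entries), using made-up logits for a single image.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(probs: np.ndarray, k: int = 3):
    """Return (values, indices) of the k largest entries, largest first."""
    idx = np.argsort(probs)[::-1][:k]
    return probs[idx], idx


# Hypothetical logits for one image over the 10 CIFAR-10 classes
logits = np.array([0.1, 2.5, -1.0, 0.3, 0.0, 1.2, -0.5, 0.4, 3.1, 0.2])
probs = softmax(logits)
values, indices = top_k(probs, k=3)
print(indices)                # [8 1 5]
print(round(probs.sum(), 6))  # 1.0
```

Softmax does not change which classes rank highest, only the scale of the reported numbers, so the top-3 indices are the same with or without it.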