很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片。但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法。现在,我们就参考官方入门课程《Deep MNIST for Experts》一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型。
我们先看下该节课程中涉及到mnist图片调用的代码:
对于刚接触TensorFlow的同学,要修改上述代码,可能会较为吃力。我也是经过一番摸索,才成功调用自己的图片集。
要实现输入自定义图片,需要自己先准备好一套图片集。为节省时间,我们把mnist的手写体数字集一张一张地解析出来,存放到自己的本地硬盘,保存为bmp格式,然后再把本地硬盘的手写体图片一张一张地读取出来,组成集合,再输入神经网络。mnist手写体数字集的提取方式详见《如何从TensorFlow的mnist数据集导出手写体数字图片》。
将mnist手写体数字集导出图片到本地后,就可以仿照以下python代码,实现自定义图片的训练:
上述python代码的执行结果截图如下:
对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行:
它们对应于feed_dict的两个placeholder:
这样一看,是不是很简单?
权威发布有关Imagination公司GPU、人工智能以及连接IP、无线IP最新资讯,提供有关物联网、可穿戴、通信、汽车电子、医疗电子等应用信息,每日更新大量信息,让你紧跟技术发展,欢迎关注!伸出小手按一下二维码我们就是好朋友!
领取专属 10元无门槛券
私享最新 技术干货