原文博客:Doi技术团队 链接地址:https://blog.doiduoyi.com/authors/1584446358138 初心:记录优秀的Doi技术团队学习经历
在本篇文章中,我们将会介绍TensorFlow的安装,TensorFlow是Google公司在2015年11月9日开源的一个深度学习框架。
TensorFlow目前支持4种开发语言,分别是Python(包括Python2和Python3)、Java、Go、C。笔者使用的环境如下:
基于这些环境,我们来安装TensorFlow吧,笔者将会通过两种操作系统来安装,分别是Ubuntu 16.04和Windows 10。
在Ubuntu上我们准备两种安装方式,分别是原生pip、Virtualenv 环境 和 Docker容器,下面我们就在三个环境下安装。
使用原生的pip安装时最简单的,直接安装使用一条命令就可以安装完成了。
首先确认Python环境,Ubuntu会自带Python环境的,不用我们自己安装,使用python3 -V
可以查询安装的Python环境,输出如下:
Python 3.5.2
安装TensorFlow需要使用pip
命令,默认是没有安装的,所以我们需要安装pip
命令:
sudo apt-get install python3-pip python3-dev
这里笔者要说一下,默认的镜像源太慢了,笔者修改成阿里镜像源了,修改方式如下:
sudo cp /etc/apt/sources.list /etc/apt/sources.list.bak
sudo vi /etc/apt/sources.list
deb http://mirrors.aliyun.com/ubuntu/ xenial main
deb-src http://mirrors.aliyun.com/ubuntu/ xenial main
deb http://mirrors.aliyun.com/ubuntu/ xenial-updates main
deb-src http://mirrors.aliyun.com/ubuntu/ xenial-updates main
deb http://mirrors.aliyun.com/ubuntu/ xenial universe
deb-src http://mirrors.aliyun.com/ubuntu/ xenial universe
deb http://mirrors.aliyun.com/ubuntu/ xenial-updates universe
deb-src http://mirrors.aliyun.com/ubuntu/ xenial-updates universe
deb http://mirrors.aliyun.com/ubuntu/ xenial-security main
deb-src http://mirrors.aliyun.com/ubuntu/ xenial-security main
deb http://mirrors.aliyun.com/ubuntu/ xenial-security universe
deb-src http://mirrors.aliyun.com/ubuntu/ xenial-security universe
sudo apt update
安装完成pip
命令之后,可以使用pip3 -V
查看是否已经安装成功及安装的版本,输出如下,官方要求pip的版本要不小于8.1:
pip 8.1.1 from /usr/lib/python3/dist-packages (python 3.5)
wget https://bootstrap.pypa.io/get-pip.py
sudo python3 get-pip.py
一切多准备完成,那就可以开始安装TensorFlow了,只要使用以下一条命令就可以:
sudo pip3 install tensorflow
pip
安装的同样的操作:sudo pip3 install -i https://mirrors.aliyun.com/pypi/simple/ tensorflow
安装完成之后,可以使用以下命令查看是否完成及安装的版本:
pip3 list
**注意:**如果在运行报以下错误,多数是电脑的CPU不支持AVX指令集:
非法指令 (核心已转储)
如何知道自己的电脑是不是支持AVX指令集呢,可以通用以下的命令查看,输出Yes
就是支持,No
就是不支持:
if cat /proc/cpuinfo | grep -i avx; then echo Yes; else echo No; fi
TensorFlow在1.6版本之后都会使用AVX指令集,如果读者的电脑不支持AVX指令集,就要安装低版本的,如下是安装1.5版本的:
pip3 install tensorflow==1.5
安装完成之后,可以进行测试,测试情阅读最后的测试部分。
首先通过以下的命令来安装 pip 和 Virtualenv:
sudo apt-get install python3-pip python3-dev python-virtualenv
然后通过下面的命令来创建 Virtualenv 环境:
virtualenv --system-site-packages -p python3 ~/tensorflow
最后通过下面的命令激活 Virtualenv 环境:
source ~/tensorflow/bin/activate
这时会发现控制台已经发生了变化,变成如下状态,这表明已经进入了 Virtualenv 环境:
(tensorflow) yeyupiaoling@tensorflow:~$
接下来的操作都是在这个Virtualenv 环境下操作,比我们的pip命令也是在这里的,可以使用pip3 -V
查看:
pip 10.0.1 from /home/yeyupiaoling/tensorflow/lib/python3.5/site-packages/pip (python 3.5)
我们在Virtualenv 环境里通过以下的命令即可完成安装TensorFlow:
pip3 install tensorflow
不支持AVX的请安装1.5版本:
pip3 install tensorflow==1.5
使用完成之后,可以通过以下命令退出Virtualenv 环境:
deactivate
要使用Docker,就要先安装Docker,以下命令就是安装Docker的命令:
sudo apt-get install docker.io
安装完成之后,可以使用docker --version
查看Docker的版本,如果有显示,就证明安装成功了。
然后我们可以通过以下的命令拉取TensorFlow的镜像,我们也可以通过dockerhub获取更多Docker镜像:
docker pull tensorflow/tensorflow:1.8.0-py3
如果电脑不支持AVX指令集的,请安装低版本的TensorFlow镜像:
docker pull tensorflow/tensorflow:1.5.1-py3
拉取完成镜像,就可以使用docker images
查看已经安装的镜像:
REPOSITORY TAG IMAGE ID CREATED SIZE
tensorflow/tensorflow 1.8.0-py3 a83a3dd79ff9 2 months ago 1.33 GB
使用TensorFlow的Docker镜像有个好处就是自带了jupyter notebook,启动镜像之后可以直接使用jupyter。
sudo docker run -it -p 80:8888 tensorflow/tensorflow:1.8.0-py3
然后终端会输出以下信息,要注意输出的token:
[I 07:08:38.160 NotebookApp] Writing notebook server cookie secret to /root/.local/share/jupyter/runtime/notebook_cookie_secret
[W 07:08:38.177 NotebookApp] WARNING: The notebook server is listening on all IP addresses and not using encryption. This is not recommended.
[I 07:08:38.186 NotebookApp] Serving notebooks from local directory: /notebooks
[I 07:08:38.186 NotebookApp] 0 active kernels
[I 07:08:38.187 NotebookApp] The Jupyter Notebook is running at:
[I 07:08:38.187 NotebookApp] http://[all ip addresses on your system]:8888/?token=ab489f0445846cb7f9d5c9613edcf7b9537cd245dbecf2a6
[I 07:08:38.187 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
[C 07:08:38.187 NotebookApp]
Copy/paste this URL into your browser when you connect for the first time,
to login with a token:
http://localhost:8888/?token=ab489f0445846cb7f9d5c9613edcf7b9537cd245dbecf2a6
然后我们在浏览器上输入IP地址,如何是在本地,那就就输入localhost,得到的页面如下,输入终端输出的token和新密码就可以登录使用jupyter了:
得到的jupyter网页如下:
如果停止运行镜像了,可以使用以下的命令找到之前使用这个进行run生成的一个容器:
sudo docker ps -a
会得到以下信息,其中最重要的是CONTAINER ID:
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
44aa680ac51f tensorflow/tensorflow:1.8.0-py3 "/run_jupyter.sh -..." 14 minutes ago Exited (0) 21 seconds ago cranky_elion
通过这个CONTAINER ID可以再次启动这个容器,这样就不用每次都run一个容器出来,占用磁盘容量,同时也可以保存原来的环境,可以使用以下的命令启动容器:
sudo docker start 44aa680ac51f
启动之后是在后台运行的,那么如何让容器有信息输入的同时会输出到控制台呢,可以用使用以下的命令实现:
sudo docker attach 44aa680ac51f
如果要以终端的方式进入到容器中,可以使用以下的命令:
sudo docker exec -it 44aa680ac51f /bin/bash
安装完成之后,可以进行测试,测试情阅读最后的测试部分。
在Windows上,笔者同样介绍三种安装方式,分别是原生pip、Docker容器、Windows的Linux子系统。
在Windows上默认是没有安装Python的,所以要先安装Python,这里笔者安装Python 3.6.5,首先到Python官网上下载对应的版本,必须是64位的Python。然后安装Python,安装过程笔者就不介绍了,主要安装完成之后,还有配置一下环境变量。在默认的安装路径为:
C:\Python36
首页我们要在环境变量的Path上添加以下l两条环境变量:
C:\Python36
C:\Python36\Scripts
如果读者同时还安装了Python2,笔者建议修改一下Python3的文件,首先把C:\Python36\Scripts
里面的pip.exe
删除,避免与Python2的冲突,然后把C:\Python36
的python.exe
和pythonw.exe
修改成python.exe
和pythonw3.exe
。以后在使用Python3的时候,分别是使用pip3
和python3
命令。
现在就开始安装TensorFlow,命令如下:
pip3 install tensorflow
如果在使用pip3
报以下错误:
Fatal error in launcher: Unable to create process using '"'
可以使用以下命令修复:
python3 -m pip install --upgrade pip
**注意:**如果以下错误,是因为缺少DLL动态库,可以看到最后提供下载动态库的链接:
Traceback (most recent call last):
File "C:\Python36\lib\site-packages\tensorflow\python\platform\self_check.py", line 47, in preload_check
ctypes.WinDLL(build_info.msvcp_dll_name)
File "C:\Python36\lib\ctypes\__init__.py", line 348, in __init__
self._handle = _dlopen(self._name, mode)
OSError: [WinError 126] 找不到指定的模块。
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test.py", line 1, in <module>
import tensorflow as tf
File "C:\Python36\lib\site-packages\tensorflow\__init__.py", line 24, in <module>
from tensorflow.python import *
File "C:\Python36\lib\site-packages\tensorflow\python\__init__.py", line 49, in <module>
from tensorflow.python import pywrap_tensorflow
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 30, in <module>
self_check.preload_check()
File "C:\Python36\lib\site-packages\tensorflow\python\platform\self_check.py", line 55, in preload_check
% build_info.msvcp_dll_name)
ImportError: Could not find 'msvcp140.dll'. TensorFlow requires that this DLL be installed in a directory that is named in your %PATH% environment variable. You may install this DLL by downloading Visual C++ 2015 Redistributable Update 3 from this URL: https://www.microsoft.com/en-us/download/details.aspx?id=53587
我们通过这个链接去下载并安装这个动态库即可:
https://www.microsoft.com/en-us/download/details.aspx?id=53587
如果在执行TensorFlow程度的是报以下错误,多数是CPU不支持AVX指令集:
Traceback (most recent call last):
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 14, in swig_import_helper
return importlib.import_module(mname)
File "C:\Python36\lib\importlib\__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 994, in _gcd_import
File "<frozen importlib._bootstrap>", line 971, in _find_and_load
File "<frozen importlib._bootstrap>", line 955, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 658, in _load_unlocked
File "<frozen importlib._bootstrap>", line 571, in module_from_spec
File "<frozen importlib._bootstrap_external>", line 922, in create_module
File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: DLL load failed: 动态链接库(DLL)初始化例程失败。
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 58, in <module>
from tensorflow.python.pywrap_tensorflow_internal import *
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 17, in <module>
_pywrap_tensorflow_internal = swig_import_helper()
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 16, in swig_import_helper
return importlib.import_module('_pywrap_tensorflow_internal')
File "C:\Python36\lib\importlib\__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
ModuleNotFoundError: No module named '_pywrap_tensorflow_internal'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test.py", line 1, in <module>
import tensorflow as tf
File "C:\Python36\lib\site-packages\tensorflow\__init__.py", line 24, in <module>
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
File "C:\Python36\lib\site-packages\tensorflow\python\__init__.py", line 49, in <module>
from tensorflow.python import pywrap_tensorflow
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 74, in <module>
raise ImportError(msg)
ImportError: Traceback (most recent call last):
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 14, in swig_import_helper
return importlib.import_module(mname)
File "C:\Python36\lib\importlib\__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 994, in _gcd_import
File "<frozen importlib._bootstrap>", line 971, in _find_and_load
File "<frozen importlib._bootstrap>", line 955, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 658, in _load_unlocked
File "<frozen importlib._bootstrap>", line 571, in module_from_spec
File "<frozen importlib._bootstrap_external>", line 922, in create_module
File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: DLL load failed: 动态链接库(DLL)初始化例程失败。
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 58, in <module>
from tensorflow.python.pywrap_tensorflow_internal import *
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 17, in <module>
_pywrap_tensorflow_internal = swig_import_helper()
File "C:\Python36\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 16, in swig_import_helper
return importlib.import_module('_pywrap_tensorflow_internal')
File "C:\Python36\lib\importlib\__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
ModuleNotFoundError: No module named '_pywrap_tensorflow_internal'
Failed to load the native TensorFlow runtime.
See https://www.tensorflow.org/install/install_sources#common_installation_problems
for some common reasons and solutions. Include the entire stack trace
above this error message when asking for help.
那就要安装低版本的TensorFlow:
pip3 install tensorflow==1.5
关于如果在Windows上安装Docker容器,可以参考笔者的《我的PaddlePaddle学习之路》笔记一——PaddlePaddle的安装》的在Windows上安装Docker容器部分,这里就不在展开介绍了。
启动容器之后,就可以拉取TensorFlow的镜像了:
docker pull tensorflow/tensorflow:1.8.0-py3
同样可以使用docker images
查看已经安装的镜像:
REPOSITORY TAG IMAGE ID CREATED SIZE
tensorflow/tensorflow 1.8.0-py3 a83a3dd79ff9 2 months ago 1.33 GB
关于如果在Windows上安装Linux子系统,可以参考笔者之前的文章《Windows10安装Linux子系统Ubuntu》
安装完成Linux子系统之后,就可以在PowerShell上输入bash
命令进入到Linux子系统,在这个子系统上安装TensorFlow请参考Ubuntu使用原生pip安装TensorFlow的方法,这个笔者就不在重复介绍了。
在这一部分,我们介绍如何在Ubuntu上使用TensorFlow的源码编译安装。
git clone https://github.com/tensorflow/tensorflow
sudo apt-get install python3-numpy python3-dev python3-pip python3-wheel
1、安装依赖库
sudo apt-get install pkg-config zip g++ zlib1g-dev unzip python
2、下载bazel-0.15.0-installer-linux-x86_64.sh
文件,下载地址如下:
https://github.com/bazelbuild/bazel/releases
3、运行安装Bazel
chmod +x bazel-0.15.0-installer-linux-x86_64.sh
./bazel-0.15.0-installer-linux-x86_64.sh --user
4、添加到环境变量,编写vim ~/.bashrc
,在最后的加上以下信息:
export PATH="$PATH:$HOME/bin"
cd tensorflow
git branch -a
输出的版本信息:
root@tensorflow:/home/yeyupiaoling/test/tensorflow# git branch -a
* master
remotes/origin/0.6.0
remotes/origin/HEAD -> origin/master
remotes/origin/achowdhery-patch-1
remotes/origin/andrewharp-patch-1
remotes/origin/martinwicke-patch-1
remotes/origin/martinwicke-patch-2
remotes/origin/master
remotes/origin/r0.10
remotes/origin/r0.11
remotes/origin/r0.12
remotes/origin/r0.7
remotes/origin/r0.8
remotes/origin/r0.9
remotes/origin/r1.0
remotes/origin/r1.1
remotes/origin/r1.2
remotes/origin/r1.3
remotes/origin/r1.4
remotes/origin/r1.5
remotes/origin/r1.6
remotes/origin/r1.7
remotes/origin/r1.8
remotes/origin/r1.9
remotes/origin/release-notes-1.9-tfdbg
remotes/origin/yifeif-patch-1
remotes/origin/yifeif-patch-2
remotes/origin/yifeif-patch-3
比如笔者想切换到1.5版本,使用以下命令即可切换:
git checkout origin/r1.5
./configure
/usr/bin/python3.5
,也可以指定是否要编译GPU版本的,具体读者可以查看笔者的配置信息,笔者多数是默认的。Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.15.0 installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3.5
Found possible Python library paths:
/usr/local/lib/python3.5/dist-packages
/usr/lib/python3/dist-packages
Please input the desired Python library path to use. Default is [/usr/local/lib/python3.5/dist-packages]
Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]:
jemalloc as malloc support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]: n
No Google Cloud Platform support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Hadoop File System support? [Y/n]:
Hadoop File System support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]:
Amazon AWS Platform support will be enabled for TensorFlow.
Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]:
Apache Kafka Platform support will be enabled for TensorFlow.
Do you wish to build TensorFlow with XLA JIT support? [y/N]:
No XLA JIT support will be enabled for TensorFlow.
Do you wish to build TensorFlow with GDR support? [y/N]:
No GDR support will be enabled for TensorFlow.
Do you wish to build TensorFlow with VERBS support? [y/N]:
No VERBS support will be enabled for TensorFlow.
Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]:
No OpenCL SYCL support will be enabled for TensorFlow.
Do you wish to build TensorFlow with CUDA support? [y/N]: N
No CUDA support will be enabled for TensorFlow.
Do you wish to download a fresh release of clang? (Experimental) [y/N]:
Clang will not be downloaded.
Do you wish to build TensorFlow with MPI support? [y/N]:
No MPI support will be enabled for TensorFlow.
Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]:
Not configuring the WORKSPACE for Android builds.
Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
--config=mkl # Build with MKL support.
--config=monolithic # Config for mostly static monolithic build.
Configuration finished
bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
/tmp/tensorflow_pkg/tensorflow-1.5.1-cp35-cp35m-linux_x86_64.whl
,可能读者的版本会跟笔者的不一样,根据实际的版本信息安装:sudo pip install /tmp/tensorflow_pkg/tensorflow-1.5.1-cp35-cp35m-linux_x86_64.whl
到这里就完成了TensorFlow的编译安装,安装完成之后,可以参考文章的最后一部分进行测试环境。
安装完成之后,我们要测试一下环境是不是已经成功安装并且可以正常使用了。
首先编译一个测试test1.py
文件:
import tensorflow as tf
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))
然后我们执行这个文件python3 test1.py
就可以运行它了,正常情况下会输出以下内容:
2018-07-08 15:11:05.240607: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
b'Hello, TensorFlow!'
我们也可以编写一个稍微有训练效果的程序test2.py
:
import tensorflow as tf
import numpy as np
# 使用 NumPy 生成假数据(phony data), 总共 100 个点.
x_data = np.float32(np.random.rand(2, 100)) # 随机输入
y_data = np.dot([0.100, 0.200], x_data) + 0.300
# 构造一个线性模型
#
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y = tf.matmul(W, x_data) + b
# 最小化方差
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# 初始化变量
init = tf.initialize_all_variables()
# 启动图 (graph)
sess = tf.Session()
sess.run(init)
# 拟合平面
for step in range(0, 201):
sess.run(train)
if step % 20 == 0:
print (step, sess.run(W), sess.run(b))
# 得到最佳拟合结果 W: [[0.100 0.200]], b: [0.300]
同样我们执行它python3 test1.py
可以得到以下信息:
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_should_use.py:118: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-07-08 15:14:15.455774: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
0 [[0.260745 0.56325 ]] [-0.00149411]
20 [[0.15775657 0.30871654]] [0.20844586]
40 [[0.11969341 0.23245212]] [0.27153042]
60 [[0.10656733 0.20975856]] [0.29113895]
80 [[0.10215606 0.2029533 ]] [0.29723996]
100 [[0.10069981 0.20089868]] [0.2991398]
120 [[0.10022521 0.20027474]] [0.2997318]
140 [[0.10007201 0.20008431]] [0.29991636]
160 [[0.10002291 0.20002596]] [0.2999739]
180 [[0.10000726 0.20000802]] [0.29999185]
200 [[0.1000023 0.20000248]] [0.29999745]
以上是在终端上操作的,那么使用Docker应该如何执行这些文件呢。有两种方法,一种就是以命令终端的方式进入到TensorFlow镜像中,之后的操作就跟在Ubuntu操作差不多了:
docker run -it -v $PWD:/work tensorflow/tensorflow:1.8.0-py3 /bin/bash
另一种就是挂载目录到镜像上,然后直接通过命令执行代码文件:
docker run -it -v $PWD:/work -w /work tensorflow/tensorflow:1.8.0-py3 python3 /work/test1.py
这里笔者使用官方提供的模型,这里官方提供的丰富的模型。这次笔者使用的是mobilenet_v1_1.0_224.tgz模型,我们下载这个模型之后解压可以以下文件:
我们使用到的模型文件是mobilenet_v1_1.0_224_frozen.pb
,其中mobilenet_v1_1.0_224_info.txt
是说明网络输入输出的字段,该文件的内容如下:
Model: mobilenet_v1_1.0_224
Input: input
Output: MobilenetV1/Predictions/Reshape_1
有了上面的模型,我们就来编写预测代码,全部的代码如下:
import numpy as np
import tensorflow as tf
from PIL import Image
# 数据预处理
def load_image(file):
im = Image.open(file)
im = im.resize((224, 224), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
# 减去均值
im -= 128.0
im /= 128.0
im = np.expand_dims(im, axis=0)
return im
def infer(image_path, pd_path):
# 将(frozen)TensorFlow模型载入内存
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(pd_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# 设置探测图的输入和输出张量
image_tensor = detection_graph.get_tensor_by_name('input:0')
detection_classes = detection_graph.get_tensor_by_name('MobilenetV1/Predictions/Reshape_1:0')
# infer image
image_np = load_image(image_path)
# Actual detection.
result = sess.run([detection_classes], feed_dict={image_tensor: image_np})
result = np.squeeze(result)
idx = np.argsort(-result)
label = idx[0] - 1
print("result label is :", label)
if __name__ == '__main__':
image_path = "0b77aba2-9557-11e8-a47a-c8ff285a4317.jpg"
pd_path = 'models/mobilenet_v1_1.0_224_frozen.pb'
infer(image_path, pd_path)
其中以下这个函数是数据预处理,处理方式要跟训练的时候一样:
# 数据预处理
def load_image(file):
im = Image.open(file)
im = im.resize((224, 224), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
# 减去均值
im -= 128.0
im /= 128.0
im = np.expand_dims(im, axis=0)
return im
以下的代码片段是把模型加载到内存中,这个模型就是我们使用的mobilenet_v1_1.0_224_frozen.pb
模型。
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(pd_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
设置模型的输出输出,这字段就是来自刚才的mobilenet_v1_1.0_224_info.txt
文件中。
image_tensor = detection_graph.get_tensor_by_name('input:0')
detection_classes = detection_graph.get_tensor_by_name('MobilenetV1/Predictions/Reshape_1:0')
以下的代码片段就是输入图片得到预测结果的,如果需要预测多张图片,可以把该代码片段放在循环中。要注意的是,输出的大小是1001,包括的第一个label是background。
image_np = load_image(image_path)
result = sess.run([detection_classes], feed_dict={image_tensor: image_np})
result = np.squeeze(result)
idx = np.argsort(-result)
label = idx[0]
print("result label is :", label)