本文梳理了tf 2.0以上版本的API结构,用于帮助国内的初学者更好更快的了解这个框架,并为检索官方的API文档提供一些关键词。
官方API文档:https://tensorflow.google.cn/api_docs/python/tf?hl=zh-cn
tf中的数据类型为张量:tf.Tensor()
,可以类比numpy中的np.array()
一些特殊的张量:
tf.Variable
:变量。用来存储需要被修改、需要被持久化保存的张量,模型的参数一般都是用变量来存储的。tf.constant
:常量,定义后值和维度不可改变。tf.sparse.SparseTensor
:稀疏张量。除上述特殊张量外,其余创建方式同numpy类似,示例:
t = tf.ones([5,3], dtype=tf.float32)
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(a.shape)
与numpy类似,可以对Tensor进行切片、索引;可以对这些Tensor做各种运算,例如:加减乘除、地板除、布尔运算。
tf.data
加载数据,高效的数据输入管道也可以极大的减少模型训练时间,管道执行的过程包括:从硬盘中读取数据(Extract)、数据的预处理如数据清洗、格式转换(Transform)、加载到计算设备(Load)tf.keras
构建、训练和验证模型,另外tf.estimator
中打包了一些标准的机器学习模型供我们直接使用,当我们不想从头开始训练一个模型时,可以使用TensorFlow Hub
模块来进行迁移学习。tf.distribute.Strategy
实现分布式的训练Checkpoints
或SavedModel
存储模型,前者依赖于创建模型的源代码;而后者与源代码无关,可以用于其他语言编写的模型。加载数据示例代码:
import tensorflow as tf
import multiprocessing
import matplotlib.pyplot as plt
N_CPUS = multiprocessing.cpu_count()
BATCH_SIZE = 32
SEED = 0
def load_and_preproess_image(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192,192])
image /= 255.0
return image
# 1. 构建图片路径
# 其中 all_image_paths = ['图片1路径','图片2路径',...,'图片n路径']
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
# 2. 构建图片数据的数据集
image_ds = path_ds.map(load_and_preproess_image, num_parallel_calls=N_CPUS)
# 3. 构建标签的数据集
label_ds = tf.data.Dataset.from_tensor_slices(all_image_labels)
# 4. 将图片和类标压缩为(图片,标签)对
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
# 5. 可视化数据
plt.figure(figsize=(8,8))
for n,image_label in enumerate(image_label_ds.take(4)):
plt.subplot(2,2,n+1)
plt.imshow(image_label[0])
plt.grid(False)
plt.xlabel(image_label[1])
# 6. 打乱数据集
image_count = len(all_image_paths)
ds = image_label_ds.shuffle(buffer_size=image_count, seed=SEED)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) # 让训练和每批次数据加载并行
构建和训练模型示例代码
from tensorflow.keras import layers
# 创建网络,两种方法二选一
model = tf.keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)), # 全连接层
layers.Dense(10, activation='softmax')
])
# 或者
# model = tf.keras.Sequential()
# model.add(layers.Dense(64, activation='relu', input_shape=(32,)))
# model.add(layers.Dense(10, activation='softmax'))
# 编译网络
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
# 网络训练(可以是numpy数据(见官方文档),也可以是Dataset数据)
# verbose=1表示以进度条的形式显示训练信息, 验证集可以直接给也可以设置比例
model.fit(ds, epochs=2, validation_split=0.2, verbose=1)
# 模型评估(可以是numpy数据(见官方文档),也可以是Dataset数据)
model.evaluate(ds, steps=30)
# 预测
result = model.predict(data, batch_size=50)
print(result[0])
inputs = tf.keras.Input(shape=(32,))
# 网络层像函数一样被调用,输出和输入都是张量
x = layers.Dense(64, activation='relu')(inputs)
predictions = layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# 编译和训练同上
模型训练的技巧——callbacks的使用
callbacks = [
# 若验证集上的损失“val_loss”连续两个epoch都没有变化,则提前结束训练
tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 使用TensorBoard把训练的记录保存到 "./logs"
tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
model.fit(ds, epochs=5, callbacks=callbacks, validation_data=val_dataset)
如果安装的是gpu版本的TensorFlow会自动使用gpu,查看可用的GPU的代码:
from tensorflow.python.client import device_lib
def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type=='GPU']
print(get_available_gpus())
单机环境下的多GPU训练:
strategy = tf.distribute.MirroredStrategy()
# 优化器及模型的构建和编译必须放在scope()中
with strategy.scope():
model = tf.keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(32,)), # 全连接层
layers.Dense(10, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.SGD(0.2), loss='binary_crossentropy')
模型的保存和恢复示例代码:
# 完整模型的保存和读取
model.save('my_model')
model = tf.keras.models.load_model('my_model')
# 模型的权重参数的保存和读取
model.save_weights('my_model.h5', save_format='h5')
model.load_weights('my_model.h5')
# 单独保存模型的结构
json_string = model.to_json()
加载数据tf.data
构建、训练和验证模型tf.keras
兼容模块tf.compat.v1
,这个模块里有完整的TensorFlow1.x的API。