前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Tensorflow 回调快速入门

Tensorflow 回调快速入门

作者头像
磐创AI
发布于 2021-10-27 08:04:26
发布于 2021-10-27 08:04:26
1.4K00
代码可运行
举报
运行总次数:0
代码可运行


磐创AI分享

作者 | ashish0765

编译 | Flin

来源 | analyticsvidhya

什么是 Tensorflow 回调?

Tensorflow 回调是在训练深度学习模型时在特定时刻执行的函数或代码块。

我们都熟悉深度学习模型的训练过程。随着模型变得越来越复杂,训练时间也显着增加。因此,模型通常需要花费数小时来训练。

在训练模型之前的工作中,我们修复了所有选项和参数,例如学习率、优化器、损失等并开始模型训练。一旦训练过程开始,就无法暂停训练,以防我们想要更改一些参数。

此外,在某些情况下,当模型已经训练了几个小时,而我们想在后期调整一些参数时,这是不可能的。而这就是 TensorFlow 回调派上用场的地方。

如何使用回调

  1. 首先定义回调
  2. 在调用 model.fit() 时传递回调
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Stop training if NaN is encountered
NanStop = TerminateOnNaN()
# Decrease lr by 10% 
LrValAccuracy = ReduceLROnPlateau(monitor='val_accuracy', patience=1, factor= 0.9, mode='max', verbose=0)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model.fit(X_train,y_train,
epochs=10,
validation_data=(X_test,y_test),
callbacks = [NanStop, LrValAccuracy])

让我们来看看一些最有用的回调

提前停止

当我们训练模型时,我们通常会查看指标以监控模型的表现。通常,如果我们看到极高的指标,我们可以得出结论,我们的模型过度拟合,如果我们的指标很低,那么我们就欠拟合了。

如果指标增加到某个范围以上,我们可以停止训练以防止过度拟合。EarlyStopping 回调允许我们做到这一点。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto'
)
  • monitor:你在训练时要监视的指标
  • min_delta:你要考虑作为对前一个时期的改进的指标的最小变化量
  • patience:你等待指标等待的时期数。否则,你将停止训练。
  • verbose:0:不打印任何内容,1:显示进度条,2:仅打印时期号
  • mode :
  • “auto” – 尝试从给定的指标中自动检测行为
  • “min” – 如果指标停止下降,则停止训练
  • “max” – 如果指标停止增加则停止训练

Lambda回调

此回调用于在训练过程中的特定时间调用某些 lambda 函数。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None,
    on_train_begin=None, on_train_end=None, **kwargs
)

在这里,我们可以传递我们需要在指定时间执行的任何 lambda 函数。

让我们看看参数是什么意思

  • on_epoch_begin:在每个时期开始时调用该函数。
  • on_epoch_begin:在每个时期结束时调用该函数。
  • on_batch_begin:在每批开始时调用该函数。
  • on_batch_end:在每批结束时调用该函数。
  • on_train_begin:模型开始训练时调用该函数
  • on_train_end:模型训练完成时调用
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
print_batch_callback = LambdaCallback(
    on_batch_begin=lambda bat,log: print(bat),
    on_batch_begin=lambda bat,log: print(bat)
)

学习率调度器

训练过程中最常见的任务之一是改变学习率。通常,随着模型接近损失最小值(最佳拟合),我们逐渐开始降低学习率以获得更好的收敛性。

让我们看一个简单的例子,我们希望每 3 个 epoch 将学习率降低 5%。这里我们需要向 schedule 函数传递一个参数,该参数指定学习率变化的逻辑。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def schedule(epoch,lr):
  if epoch % 3 == 0:
    lr = lr - (lr*.05)
    return lr
  return lr

# Decrease lr by 5% for every 3rd epoch
LrScheduler = tf.keras.callbacks.LearningRateScheduler(schedule,verbose=1)
模型检查点

我们使用这个回调来以不同的频率保存我们的模型。这允许我们在中间步骤保存权重,以便在需要时我们可以稍后加载权重。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch'
)

filepath:模型所在的位置 monitor:要监视的度量 save_best_only:True:仅保存最好的模型,False:保存所有的模型时,指标改善 mode:min, max或auto save_weights_only:False:仅保存模型权重, True:同时保存模型权重和模型架构

例如,让我们看一个例子,保存具有最佳精度的模型

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
filePath = "models/Model1_weights.{epoch:02d}.hdf5"
model_checkpoint_callback = tf.keras.callbacksModelCheckpoint(
    filepath=filePath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max')

这里我们使用一些模板字符串指定文件路径。{epoch:02d} 保存模型时由时期号代替

减少LROnPlateau

当特定指标停止增加并达到平台期时,此回调用于降低训练率。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=10, verbose=0,
    mode='auto', min_delta=0.0001, cooldown=0, min_lr=0, **kwargs
)

factor:LR 减少的系数。新学习率 = old_learning_rate * 因子 min_delta:需要被视为改进的最小变化 cooldown:等待 LR 减少的时期数 min_lr:学习率不能低于该最小值

终止OnNaN

当任何损失变为 NaN 时,此回调将停止训练过程

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tf.keras.callbacks.TerminateOnNaN()

Tensorboard

Tensorboard 允许我们显示有关训练过程的信息,如指标、训练图、激活函数直方图和其他梯度分布。

要使用Tensorboard,我们首先需要设置一个 log_dir,Tensorboard文件被保存到其中。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
log_dir="logs"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True)
  • log_dir:保存文件的目录
  • histogram_freq:计算直方图和梯度图的时期频率
  • write_graph:我们是否需要在Tensorboard中显示和可视化图形

编写自己的回调

除了内置的回调之外,我们还可以为不同的目的定义和使用我们自己的回调。例如,假设我们要定义自己的度量标准,该度量标准在每个 epoch 结束时计算。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# Monitor MicroF1 and AUC Score
class Metrics_Callback(tf.keras.callbacks.Callback):
  def __init__(self,x_val,y_val):
    self.x_val = x_val
    self.y_val = y_val
  def on_train_begin(self, logs={}):
    self.history = {"auc_score":[],"micro_f1":[]}
  def on_epoch_end(self, epoch, logs={}):
    auc_score = roc_auc_score(self.y_val, model.predict_proba(self.x_val))
    y_true = [0 if x[0]==1.0 else 1 for x in self.y_val]
    f1_s = f1_score(y_true,self.model.predict_classes(self.x_val), average='micro')
    self.history["auc_score"].append(auc_score)
    self.history["micro_f1"].append(f1_s)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Metrics = Metrics_Callback(X_test,y_test)

这里我们要计算每个 epoch 结束时的 F1 分数和 AUC 分数。在 init 方法中,我们读取计算分数所需的数据。然后在每个 epoch 结束时,我们在 on_epoch_end 函数中计算指标。

我们可以使用以下方法在不同的时间执行代码——

on_epoch_begin:在每个时期开始时调用。

on_epoch_begin:在每个时期结束时调用。

on_batch_begin:在每批开始时调用。

on_batch_end:在每批结束时调用。

on_train_begin:模型开始训练时调用

on_train_end:模型训练完成时调用

结论

这些是一些常用和最流行的回调。TensorFlow 官方文档为我们提供了有关各种其他回调及其相关用例的详细信息。

TensorFlow 官方文档:https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-10-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 磐创AI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
轻松理解Keras回调
随着计算机处理能力的提高,人工智能模型的训练时间并没有缩短,主要是人们对模型精确度要求越来越高。为了提升模型精度,人们设计出越来越复杂的深度神经网络模型,喂入越来越海量的数据,导致训练模型也耗时越来越长。这就如同PC产业,虽然CPU遵从摩尔定律,速度越来越快,但由于软件复杂度的提升,我们并没有感觉计算机运行速度有显著提升,反而陷入需要不断升级电脑硬件的怪圈。
云水木石
2019/08/09
2K0
神经网络训练中回调函数的实用教程
回调操作可以在训练的各个阶段执行,可能是在epoch之间,在处理一个batch之后,甚至在满足某个条件的情况下。回调可以利用许多创造性的方法来改进训练和性能,节省计算资源,并提供有关神经网络内部发生的事情的结论。
磐创AI
2020/09/04
1.3K0
TensorFlow2.0(9):神器级可视化工具TensorBoard
TensorBoard是TensorFlow中的又一神器级工具,想用户提供了模型可视化的功能。我们都知道,在构建神经网络模型时,只要模型开始训练,很多细节对外界来说都是不可见的,参数如何变化,准确率怎么样了,loss还在减小吗,这些问题都很难弄明白。但是,TensorBoard通过结合web应用为我们提供了这一功能,它将模型训练过程的细节以图表的形式通过浏览器可视化得展现在我们眼前,通过这种方式我们可以清晰感知weight、bias、accuracy的变化,把握训练的趋势。
Ai学习的老章
2019/12/25
3.7K0
TensorFlow2.0 实战强化专栏(一):Chars74项目
字符识别是一种经典的模式识别问题,字符识别在现实生活中也有着非常广泛的应用,目前对于特定环境下的拉丁字符识别已经取得了很好的效果,但是对于一些复杂场景下的字符识别依然还有很多困难,例如通过手持设备拍摄以及自然场景中的图片等,Chars74K正是针对这些困难点搜集的数据集(http://www.ee.surrey.ac.uk/CVSSP/demos/chars74k/)
磐创AI
2020/03/04
2K1
使用TensorFlow2预测国内疫情结束时间
国内的新冠肺炎疫情从发现至今已经持续3个多月了,这场起源于吃野味的灾难给大家的生活造成了诸多方面的影响。
lyhue1991
2020/07/20
8530
使用TensorFlow2预测国内疫情结束时间
【tensorflow2.0】回调函数callbacks
tf.keras的回调函数实际上是一个类,一般是在model.fit时作为参数指定,用于控制在训练过程开始或者在训练过程结束,在每个epoch训练开始或者训练结束,在每个batch训练开始或者训练结束时执行一些操作,例如收集一些日志信息,改变学习率等超参数,提前终止训练过程等等。
西西嘛呦
2020/08/26
1.5K0
Deep learning with Python 学习笔记(9)
使用 model.fit()或 model.fit_generator() 在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推力,之后你便再也无法控制其飞行轨迹或着陆点。如果想要避免不好的结果(并避免浪费纸飞机),更聪明的做法是不用纸飞机,而是用一架无人机,它可以感知其环境,将数据发回给操纵者,并且能够基于当前状态自主航行。下面要介绍的技术,可以让model.fit() 的调用从纸飞机变为智能的自主无人机,可以自我反省并动态地采取行动
范中豪
2019/09/10
6730
Deep learning with Python 学习笔记(9)
Tensorflow Keras:mnist分类demo
tf2集成的keras非常好用,对一些简单的模型可以快速搭建,下面以经典mnist数据集为例,做一个demo,展示一些常用的方法
Mirza Zhao
2023/06/26
5670
【tensorflow2.0】处理时间序列数据
国内的新冠肺炎疫情从发现至今已经持续3个多月了,这场起源于吃野味的灾难给大家的生活造成了诸多方面的影响。
西西嘛呦
2020/08/26
9430
【tensorflow2.0】处理时间序列数据
使用Keras Tuner进行自动超参数调优的实用教程
在本文中将介绍如何使用 KerasTuner,并且还会介绍其他教程中没有的一些技巧,例如单独调整每一层中的参数或与优化器一起调整学习率等。Keras-Tuner 是一个可帮助您优化神经网络并找到接近最优的超参数集的工具,它利用了高级搜索和优化方法,例如 HyperBand 搜索和贝叶斯优化。所以只需要定义搜索空间,Keras-Tuner 将负责繁琐的调优过程,这要比手动的Grid Search强的多!
deephub
2022/11/11
9720
使用Keras Tuner进行自动超参数调优的实用教程
使用Python实现深度学习模型:模型监控与性能优化
在深度学习模型的实际应用中,模型的性能监控与优化是确保其稳定性和高效性的关键步骤。本文将介绍如何使用Python实现深度学习模型的监控与性能优化,涵盖数据准备、模型训练、监控工具和优化策略等内容。
Echo_Wish
2024/07/08
3470
使用Python实现深度学习模型:模型监控与性能优化
keras doc 10终结篇 激活函数 回调函数 正则项 约束项 预训练模型
激活函数可以通过设置单独的激活层实现,也可以在构造层对象时通过传递activation参数实现。
CreateAMind
2018/07/25
2.4K0
keras doc 10终结篇 激活函数 回调函数 正则项 约束项 预训练模型
深度学习从小白到入门 —— 基于keras的深度学习基本概念讲解
神经网络中的每个神经元 对其所有的输入进行加权求和,并添加一个被称为偏置(bias) 的常数,然后通过一些非线性激活函数来反馈结果。
机械视角
2019/10/23
6970
深度学习从小白到入门 —— 基于keras的深度学习基本概念讲解
深度学习快速参考:1~5
欢迎使用《深度学习快速参考》! 在本书中,我将尝试使需要解决深度学习问题的数据科学家,机器学习工程师和软件工程师更容易使用,实用和使用深度学习技术。 如果您想训练自己的深度神经网络并且陷入困境,那么本指南很有可能会有所帮助。
ApacheCN_飞龙
2023/04/23
1.1K0
TensorBoard的最全使用教程:看这篇就够了
机器学习通常涉及在训练期间可视化和度量模型的性能。有许多工具可用于此任务。在本文中,我们将重点介绍 TensorFlow 的开源工具套件,称为 TensorBoard,虽然他是TensorFlow 的一部分,但是可以独立安装,并且服务于Pytorch等其他的框架。
deephub
2022/03/12
37.7K0
TensorBoard的最全使用教程:看这篇就够了
【私人笔记】深度学习框架keras踩坑记
Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时间把你的想法转换为实验结果,是做好研究的关键。本人是keras的忠实粉丝,可能是因为它实在是太简单易用了,不用多少代码就可以将自己的想法完全实现,但是在使用的过程中还是遇到了不少坑,本文做了一个归纳,供大家参考。
小草AI
2019/05/29
4.5K1
TensorFlow 2.0入门
谷歌于2019年3月6日和7日在其年度TensorFlow开发者峰会上发布了最新版本的TensorFlow机器学习框架。这一新版本使用TensorFlow的方式进行了重大改进。TensorFlow拥有最大的开发者社区之一,从机器学习库到完善的机器学习生态系统已经走过了漫长的道路。
代码医生工作室
2019/06/21
2K0
TensorFlow 2.0入门
相关推荐
轻松理解Keras回调
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验