Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >21 | 使用PyTorch完成医疗图像识别大项目:训练模型

21 | 使用PyTorch完成医疗图像识别大项目:训练模型

作者头像
机器学习之禅
发布于 2022-07-11 07:51:15
发布于 2022-07-11 07:51:15
74900
代码可运行
举报
文章被收录于专栏:机器学习之禅机器学习之禅
运行总次数:0
代码可运行

昨天我们已经完成了训练和验证模型的主体代码,在进行训练之前,我们还需要处理一下输出信息。前面我们已经记录了一部分信息到trnMetrics_g和valMetrics_g中,每迭代一个周期,就会输出一次结果方便我们查看。如果发现模型的结果很差,比如说出现了无法收敛的情况,我们就可以中止模型训练,不用再浪费更多时间,因为一个深度模型训练需要花费很长的时间。

日志处理

日志处理方法也挺长的,甚至比模型训练还要长,不过里面的内容比较简单

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def logMetrics(   #接收参数
            self,
            epoch_ndx,  #迭代次数
            mode_str,  #当前模式,是训练还是验证
            metrics_t, #接收结果信息
            classificationThreshold=0.5, #分类判断阈值
    ):#这个初始化是为了后面写入tensorboard准备的
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))#构建掩码,这里面的静态变量我们上节已经声明过了,根据阈值判断获取负样本标注结果和预测结果
        negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
        negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold#这块采用了一个trick的方法来获取正样本结果,直接对上面的负样本结果取反,在二分类是可以这么操作,如果是多分类就不能这么操作了
        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask#统计训练集正负样本数
        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())#统计预测正确的正负样本数
        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())#这一块在计算平均损失,总体平均损失,负样本平均损失,正样本平均损失
        metrics_dict = {}
        metrics_dict['loss/all'] = \
            metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = \
            metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = \
            metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()#计算准确率
        metrics_dict['correct/all'] = (pos_correct + neg_correct) \            / np.float32(metrics_t.shape[1]) * 100
        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100#紧接着就是日志记录
        log.info(  #整体损失和整体准确率
            ("E{} {:8} {loss/all:.4f} loss, "
                 + "{correct/all:-5.1f}% correct, "
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info( #负样本损失和负类别准确率
            ("E{} {:8} {loss/neg:.4f} loss, "
                 + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_neg',
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(#正样本损失和正类别准确率
            ("E{} {:8} {loss/pos:.4f} loss, "
                 + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_pos',
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )#下面的部分跟写入tensorboard相关
        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        writer.add_pr_curve(
            'pr',
            metrics_t[METRICS_LABEL_NDX],
            metrics_t[METRICS_PRED_NDX],
            self.totalTrainingSamples_count,
        )

        bins = [x/50.0 for x in range(51)]

        negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
        posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)

        if negHist_mask.any():
            writer.add_histogram(
                'is_neg',
                metrics_t[METRICS_PRED_NDX, negHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
        if posHist_mask.any():
            writer.add_histogram(
                'is_pos',
                metrics_t[METRICS_PRED_NDX, posHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

接下来我们尝试启动训练。调试代码通常需要花一点时间,尤其是这种已经写好的代码。不过这个还好,唯一需要注意的就是代码中有用到cache,如果在运行的时候出现了错误,当修改了错误之后可能需要重新启动一下。我们用之前定义的run方法在JupyterNotebook中启动

可以看到这次成功了,开始构建数据cache,总共56个,我这里还没有用全部数据,只用了两个set,结果还是出了问题,报错是空间不足了,这个项目确实有点占空间,抓紧清理了一下我的硬盘,然后重新启动。

正好赶上中午吃饭时间,吃完饭休息一会再看已经跑完了。这一步处理缓存在我这个电脑上耗时接近40分钟。

接下来跑一下训练试试,我们只迭代一个周期。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
run('testproject.training.LunaTrainingApp', '--epochs=1')

这个输出的日志比较长,就不全贴了,耗时约40分钟

从这个结果可以看到,训练集和验证集损失都很低,准确率竟然高达99.7%,这只是用了其中的两个数据集,竟然就达到了这种效果?再仔细看一下下面两行,负样本的准确率是99.9%,但是正样本的准确率是0%???109个负样本一个都没有分出来!这就好像有1000个病人说不舒服要来看病,医生告诉他们都没有病,不需要治疗,结果最后有一个病人死掉了,医生说我的诊断准确率是99.9%!但是这有什么意义呢,我们需要的是把有问题的那些识别出来,但是现在我们的模型只知道把所有的都归为非结节就能够拿到一个高分,显然我们不想要这样的模型。

使用TensorBoard绘制训练指标

本来想一天写完的,结果要训练10个epoch花的时间太长了,我就把电脑放在这里自己跑,拿起了塞尔达玩了一会,谁知道昨天就过去了。今早上起来看10个epoch确实跑完了,但是结果更加诡异,我的模型竟然把数据都分成了结节类。

为了更好的观察模型训练的情况,我们这里把结果数据输出到TensorBoard中,如果你没有安装过,这里可以装一下requirements.txt文件,里面已经提供了需要的包。不过我试了一下,仍然有一些包可能会缺失,这就需要自己动手安装一下了,比如我这里缺少importlib-metadata。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
#安装需要的包,用pip比较快,但是偶尔会有跟conda不兼容的问题
pip install -r requirements.txt
conda install importlib-metadata

环境配置是很讨厌的一个环节,动不动就会有问题。接下来启动TensorBoard

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
tensorboard --logdir runs/

在浏览器里面输入http://localhost:6606/,就可以看到下面的页面

左侧最上面是一些配置,中间的smoothing是平滑系数,最下面是我们之前训练过的模型结果,我这里因为执行了多次,有多个结果,如果训练的次数比较多,可以选择自己需要查看的那个训练,其他的取消掉。 右侧就是结果了,上面一排是准确率,下面一排是损失。通过在这个图表上观察,我们可以更清晰的看出来我们的训练效果有多糟糕。

我们看一下用于写入TensorBoard的代码,在trainning.py文件中。这里构建了两个写入器,一个是训练集结果写入器,一个是验证集结果写入器。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)

再就是我们前面已经写过的代码,在logMetrics使用writer写入。 关于TensorBoard还有很多功能,这里暂时不介绍了,因为我也不太熟练。如果我们能够在训练中很好的使用它,能够更好的帮助我们理解模型训练的效果,如果你对TensorBoard感兴趣可以研究一下。

不管怎么样,我们已经可以从TensorBoard上看出来模型效果非常不理想了,后面我们该研究怎么去优化效果。现在我先把10个数据集都加进去训练一下。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-06-15,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习之禅 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Python print函数参数详解以及效果展示
print(…)  print(value, …, sep=’ ‘, end=’\n’, file=sys.stdout, flush=False)  Prints the values to a stream, or to sys.stdout by default.  Optional keyword arguments:  file: a file-like object (stream); defaults to the current sys.stdout.  sep: string inserted between values, default a space.  end: string appended after the last value, default a newline.  flush: whether to forcibly flush the stream.
用户7886150
2020/11/23
9700
Python编程 print输出函数
语法: print(self, *args, sep=' ' , end='\n' , file=None)
网络豆
2022/11/20
8020
Python编程 print输出函数
python之 print()函数的输出学问(函数解析以及格式化输出)
前言:内容比较简单基础,但是很有用,方便。本篇主要针对print()函数的输出进行说明,所以不会构建长篇大论的大标题小标题。简洁明了!
兰舟千帆
2022/07/16
8960
python之 print()函数的输出学问(函数解析以及格式化输出)
Python3.5里print()的用法
print(*objects, sep=' ', end='\n', file=sys.stdout, flush=False)
用户7886150
2021/01/15
6960
#5 Python变量与输入输出
学习一门编程语言,最基本的无非不过学习其变量规则、条件语句、循环语句和函数,接下来的几节将开始记录这些基本的语法,本节主要记录变量规则!
py3study
2020/01/17
1.2K0
python3.x的print()函数默
        print(j, 'x', i, '=', j*i,,end='\t')
py3study
2020/01/09
4300
【基础教程】Python print()函数高级用法
前面使用 print() 函数时,都只输出了一个变量,但实际上 print() 函数完全可以同时输出多个变量,而且它具有更多丰富的功能。
matinal
2020/11/27
1.1K0
Python打印print函数深入解析
 尊重劳动成果,请访问CSDN著者原文链接 http://blog.csdn.net/zixiao217/article/details/51929078  学会在IDLE中使用help(BIF)命令查看BIF的说明
青山师
2023/05/04
4290
Python怎么去写单元测试用例去测试hello world呢
逛着博客园,看到乙醇大佬的一篇随笔 https://www.cnblogs.com/nbkhic/p/9370446.html,于是就在想怎么测试这句hello world
未来sky
2018/08/30
8070
python怎么换行输出的数字对齐_print语句输出换行,format格式化输出「建议收藏」
其实本来挺简单的一个函数,奈何每次用都忘记了怎么换行输出,所以想想算了还是自己做个记录,免得每次都要去查.
全栈程序员站长
2022/07/31
2.1K0
你真的懂print('Hello World!')?我不信
相信很多同学入门Python的第一行代码都是print('Hello World!')
Ai学习的老章
2020/12/09
8720
你真的懂print('Hello World!')?我不信
python3基础:文件操作
相对路径:顾名思义就是相对于当前文件的路径。网页中一般表示路径使用这个方法。 绝对路径:绝对路径就是主页上的文件或目录在硬盘上真正的路径。 比如 c:/apache/cgi-bin 下的,那么 c:/apache/cgi-bin就是cgi-bin目录的绝对路径
py3study
2020/01/10
7930
python3基础:文件操作
python学习笔记2.2-print函数以及格式化输出
文章主要介绍了Python中常用的print格式化输出方法,包括使用占位符、格式限定符、精度和类型等。同时,还介绍了如何通过控制台输入和输出,以及使用格式化库(str.format())和f-string(格式化字符串)来实现更复杂的输出格式。
锦小年
2018/01/02
1.4K0
python学习笔记2.2-print函数以及格式化输出
python中\r的意义及用法
原文出处:https://www.cnblogs.com/zzliu/p/10156658.html
SL_World
2021/09/18
1.3K0
Python 输出日志 print 函数的应用(python专栏001)
在Python中,print()函数是一个用于输出内容到标准输出设备的函数,通常用于调试程序和显示程序运行结果
早起的年轻人
2023/04/27
3770
Python原地输出效果实现
由于 GIL 的存在,所以每次只有一个线程在运行,所以 slow_function() 的作用就是强制 sleep 主线程,使子线程得到执行
用户7685359
2020/08/24
6330
Python原地输出效果实现
通过内置对象理解 Python(三)
这个函数将存储常量 2 以及变量名 number,但显然它不能包含 number 的实际值,因为只有在函数实际运行时才会给该参数赋值。
老齐
2021/11/15
5640
python基础——输入与输出【input 和 print】
📝前言: 上一篇文章python基础——入门必备知识中讲解了一些关于python的基础知识,可以让我们更好的理解程序代码中内容的含义,不至于一头雾水。今天我就来介绍一下,python中两个常见的输入和输出语句 input 和 print
用户11029137
2024/03/19
3220
python基础——输入与输出【input 和 print】
Python内置(3)exec&eval、globals&locals、input&print、5个基本类型、object
exec (execute执行)的缩写。将一些Python代码作为字符串接收,并将其作为Python代码运行。默认情况下,exec将在与其余代码相同的范围内运行,这意味着它可以读取和操作变量,就像Python文件中的任何其他代码段一样。
一只大鸽子
2022/12/06
6410
Prin()输出函数和使用方法
我们在之前的文章中我们用的最多的就是print()这个函数来打印一些数据,这就是我们今天要讲的输出语句,通过print()不仅可以输出变量,还有很多其他功能。下面就来详细讲解一下。
python自学网
2021/11/29
8130
Prin()输出函数和使用方法
推荐阅读
相关推荐
Python print函数参数详解以及效果展示
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验