首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

如何在PyTorch中找出图像分类器的混淆矩阵并作图

在PyTorch中找出图像分类器的混淆矩阵并作图的步骤如下:

  1. 导入必要的库和模块:
代码语言:txt
复制
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
  1. 加载训练好的模型和测试数据集:
代码语言:txt
复制
model = torch.load('path_to_model.pth')
test_dataset = torchvision.datasets.ImageFolder('path_to_test_data', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
  1. 设置模型为评估模式:
代码语言:txt
复制
model.eval()
  1. 定义一个函数来获取模型的预测结果:
代码语言:txt
复制
def get_predictions(model, data_loader):
    predictions = []
    targets = []
    for images, labels in data_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        predictions.extend(predicted.tolist())
        targets.extend(labels.tolist())
    return predictions, targets
  1. 调用上述函数获取预测结果:
代码语言:txt
复制
predictions, targets = get_predictions(model, test_loader)
  1. 计算混淆矩阵:
代码语言:txt
复制
confusion_mat = confusion_matrix(targets, predictions)
  1. 绘制混淆矩阵图:
代码语言:txt
复制
plt.figure(figsize=(num_classes, num_classes))
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, class_names, rotation=90)
plt.yticks(tick_marks, class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

其中,path_to_model.pth是训练好的模型的路径,path_to_test_data是测试数据集的路径,transform是数据预处理的方法,batch_size是每个批次的样本数量,num_classes是分类器的类别数,class_names是类别的名称。

这样,你就可以在PyTorch中找出图像分类器的混淆矩阵并作图了。

页面内容是否对你有帮助?
有帮助
没帮助

相关·内容

  • 基于支持向量机的手写数字识别详解(MATLAB GUI代码,提供手写板)

    摘要:本文详细介绍如何利用MATLAB实现手写数字的识别,其中特征提取过程采用方向梯度直方图(HOG)特征,分类过程采用性能优异的支持向量机(SVM)算法,训练测试数据集为学术及工程上常用的MNIST手写数字数据集,博主为SVM设置了合适的核函数,最终的测试准确率达99%的较高水平。根据训练得到的模型,利用MATLAB GUI工具设计了可以手写输入或读取图片进行识别的系统界面,同时可视化图片处理过程及识别结果。本套代码集成了众多机器学习的基础技术,适用性极强(用户可修改图片文件夹实现自定义数据集训练),相信会是一个非常好的学习Demo。本博文目录如下:

    05

    ROC曲线的含义以及画法

    ROC的全名叫做Receiver Operating Characteristic(受试者工作特征曲线 ),又称为感受性曲线(sensitivity curve)。得此名的原因在于曲线上各点反映着相同的感受性,它们都是对同一信号刺激的反应,只不过是在几种不同的判定标准下所得的结果而已。其主要分析工具是一个画在二维平面上的曲线——ROC 曲线。ROC曲线以真正例率TPR为纵轴,以假正例率FPR为横轴,在不同的阈值下获得坐标点,并连接各个坐标点,得到ROC曲线。 对于一个分类任务的测试集,其本身有正负两类标签,我们对于这个测试集有一个预测标签,也是正负值。分类器开始对样本进行分类时,首先会计算该样本属于正确类别的概率,进而对样本的类别进行预测。比如说给出一组图片,让分类器判断该图片是否为汉堡,分类器在开始分类前会首先计算该图片为汉堡的概率,进而对该图片的类别进行预测,是汉堡或者不是汉堡。我们用概率来表示横坐标,真实类别表示纵坐标,分类器在测试集上的效果就可以用散点图来表示,如图所示

    01
    领券