在PyTorch中找出图像分类器的混淆矩阵并作图的步骤如下:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
model = torch.load('path_to_model.pth')
test_dataset = torchvision.datasets.ImageFolder('path_to_test_data', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.eval()
def get_predictions(model, data_loader):
predictions = []
targets = []
for images, labels in data_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
predictions.extend(predicted.tolist())
targets.extend(labels.tolist())
return predictions, targets
predictions, targets = get_predictions(model, test_loader)
confusion_mat = confusion_matrix(targets, predictions)
plt.figure(figsize=(num_classes, num_classes))
plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, class_names, rotation=90)
plt.yticks(tick_marks, class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
其中,path_to_model.pth
是训练好的模型的路径,path_to_test_data
是测试数据集的路径,transform
是数据预处理的方法,batch_size
是每个批次的样本数量,num_classes
是分类器的类别数,class_names
是类别的名称。
这样,你就可以在PyTorch中找出图像分类器的混淆矩阵并作图了。
领取专属 10元无门槛券
手把手带您无忧上云