前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >KNN图像分类

KNN图像分类

作者头像
AngelNH
发布2020-04-16 15:44:07
5750
发布2020-04-16 15:44:07
举报
文章被收录于专栏:AngelNI

真味是淡至如常。

KNN图像分类

链接

摘自大佬的笔记,拿来细细品味,别是一番滋味。

代码语言:javascript
复制
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import h5py
import scipy
from PIL import Image
from scipy import ndimage

def distance(X_test, X_train):
    """
    输入:
    X_test -- 由numpy数组表示的测试集,大小为(图片长度 * 图片高度 * 3 , 测试样本数)
    X_train -- 由numpy数组表示的训练集,大小为(图片长度 * 图片高度 * 3 , 训练样本数)
    输出:
    distances -- 测试数据与各个训练数据之间的距离,大小为(测试样本数, 训练样本数量)的numpy数组
    """
    num_test = X_test.shape[1]
    num_train = X_train.shape[1]
    distances = np.zeros((num_test, num_train))
    # (X_test - X_train)*(X_test - X_train) = -2X_test*X_train + X_test*X_test + X_train*X_train
    dist1 = np.multiply(np.dot(X_test.T,X_train), -2)    # -2X_test*X_train, shape (num_test, num_train)
    dist2 = np.sum(np.square(X_test.T), axis=1, keepdims=True)    # X_test*X_test, shape (num_test, 1)
    dist3 = np.sum(np.square(X_train), axis=0,keepdims=True)    # X_train*X_train, shape(1, num_train)
    distances = np.sqrt(dist1 + dist2 + dist3)

    return distances
def predict(X_test, X_train, Y_train, k = 1):
    """ 
    输入:
    X_test -- 由numpy数组表示的测试集,大小为(图片长度 * 图片高度 * 3 , 测试样本数)
    X_train -- 由numpy数组表示的训练集,大小为(图片长度 * 图片高度 * 3 , 训练样本数)
    Y_train -- 由numpy数组(向量)表示的训练标签,大小为 (1, 训练样本数)
    k -- 选取与训练集最近邻的数量
    输出:
    Y_prediction -- 包含X_test中所有预测值的numpy数组(向量)
    distances -- 由numpy数组表示的测试数据与各个训练数据之间的距离,大小为(测试样本数, 训练样本数)
    """
    distances = distance(X_test, X_train)
    num_test = X_test.shape[1]
    Y_prediction = np.zeros(num_test)
    for i in range(num_test):
        dists_min_k = np.argsort(distances[i])[:k]     # 按照距离递增次序进行排序,选取距离最小的k个点 
        y_labels_k = Y_train[0,dists_min_k]     # 确定前k个点的所在类别
        Y_prediction[i] = np.argmax(np.bincount(y_labels_k)) # 返回前k个点中出现频率最高的类别作为测试数据的预测分类

    return Y_prediction, distances
def model(X_test, Y_test, X_train, Y_train, k = 1, print_correct = False):
    """
    输入:
    X_test -- 由numpy数组表示的测试集,大小为(图片长度 * 图片高度 * 3 , 测试样本数)
    X_train -- 由numpy数组表示的训练集,大小为(图片长度 * 图片高度 * 3 , 训练样本数)
    Y_train -- 由numpy数组(向量)表示的训练标签,大小为 (1, 训练样本数)
    Y_test -- 由numpy数组(向量)表示的测试标签,大小为 (1, 测试样本数)
    k -- 选取与训练集最近邻的数量
    print_correct -- 设置为true时,打印正确率
    输出:
    d -- 包含模型信息的字典
    """
    Y_prediction, distances = predict(X_test, X_train, Y_train, k)
    num_correct = np.sum(Y_prediction == Y_test)
    accuracy = np.mean(Y_prediction == Y_test)
    if print_correct:
        print('Correct %d/%d: The test accuracy: %f' % (num_correct, X_test.shape[1], accuracy))
    d = {"k": k,
         "Y_prediction": Y_prediction, 
         "distances" : distances,
         "accuracy": accuracy}
    return d
def load_CIFAR_batch(filename):
    with open(filename, 'rb') as f:
        datadict = pickle.load(f,encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
        Y = np.array(Y)
    return X, Y
def load_CIFAR10():
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join('F:\C-and-Python-Algorithn\python\Tensorflow\data', 'cifar-10-batches-py', 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)    
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join('F:\C-and-Python-Algorithn\python\Tensorflow\data', 'cifar-10-batches-py', 'test_batch'))
    return Xtr, Ytr, Xte, Yte


X_train, y_train, X_test, y_test = load_CIFAR10()


classes = ['plane', 'car', 'bird', 'cat', 'dear', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)
num_each_class = 7
for y, cls in enumerate(classes):
    idxs = np.flatnonzero(y_train == y)
    idxs = np.random.choice(idxs, num_each_class, replace=False)
    for i, idx in enumerate(idxs):
        plt_idx = i * num_classes + (y + 1)
        plt.subplot(num_each_class, num_classes, plt_idx)
        plt.imshow(X_train[idx].astype('uint8'))
        plt.axis('off')
        if i == 0:
            plt.title(cls)
plt.show()

X_train = np.reshape(X_train, (X_train.shape[0], -1)).T
X_test = np.reshape(X_test, (X_test.shape[0], -1)).T
Y_set_train = y_train[:10000].reshape(1,-1)
Y_set_test = y_test[:1000].reshape(1,-1)
X_set_train = X_train[:,:10000]
X_set_test = X_test[:,:1000]


models = {}
for k in [1, 3, 5, 10]:
    print ("k = " + str(k))
    models[str(k)] = model(X_set_test, Y_set_test, X_set_train, Y_set_train, k, print_correct = True)
    print ('\n' + "-------------------------------------------------------" + '\n')


models = {}
k = []
accuracys = []
for i in range(1,11):
    models[str(i)] = model(X_set_test, Y_set_test, X_set_train, Y_set_train, i, print_correct = False)
    k.append(models[str(i)]["k"])
    accuracys.append(models[str(i)]["accuracy"])
plt.plot(k, accuracys)
plt.ylabel('accuracy')
plt.xlabel('k')
plt.show()
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-03-05|,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • KNN图像分类
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档