接着前面2期rbf相关的应用分享一下rbf在分类场景的应用,数据集采用iris
前期参考
一、数据集
iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类(setosa,versicolor, virginica),每类50个数据,每个数据包含4个属性。每一个数据包含4个独立的属性,这些属性变量测量植物的花朵(比如萼片和花瓣的长度等)信息。要求以iris数据为对象,来进行不可测信息(样本类别)的估计。数据随机打乱,然后训练集:测试集=7:3进行训练,并和实际结果作比较
二、编程步骤、思路
(1)读取训练数据通过load函数读取训练数据,并对数据进行打乱,提取对应的数据分为训练和验证数据,训练集和验证集7:3
iris = load('iris.txt');
inputData = iris(:,1:4);
outputData = iris(:,5);
flag = length(outputData);
orderTrain = randperm(flag);
nbertrain = round(0.7*flag);% 提取训练和验证数据 70% 训练,30% 验证
XTrain = inputData(orderTrain(1:nbertrain),1:4)';
YTrain = outputData(orderTrain(1:nbertrain))';
XValidation = inputData(orderTrain(nbertrain+1:flag),1:4)';
YValidation = outputData(orderTrain(nbertrain+1:flag))';
(2)建立一个RBF网络使用matlab的newrb函数,设定误差均方根值目标-0.02;径向基层的分布常数-1;最大的神经元个数-25
eg = 0.02; % 误差均方根值目标
sc = 1; % 径向基层的分布常数
mn = 25; % 最大的神经元个数
训练模型
net = newrb(XTrain,YTrain,eg,sc);
NEWRB, neurons = 0, MSE = 0.656327
预测准确率: 97.7778 %
(3)使用新的数据集测试这个网络将待识别的样本数据(XValidation)放在net变量,然后运行即可,
Y = net(XValidation);
最后的结果进行归一化计算,得到对应的预测类别 输出仿真结果
output = zeros(1,length(Y));
for i = 1:length(Y)
[m,n] = min(abs(Y(i)-[1 2 3]));
output(i) = n ;
end
绘制结果成图
分析:从实验运行结果可以看出,本程序的识别率准确率为97.7778 % ,
思考:本次使用了RBF神经网络,RBF是一种前馈型的神经网络,它的激励函数一般是高斯函数,高斯函数是通过计算输入与函数中心点的距离来算权重的。BP神经网络学习速率是固定的,因此网络的收敛速度慢,需要较长的训练时间。对于一些复杂问题,BP算法需要的训练时间可能非常长,这主要是由于学习速率太小造成的。而RBF神经网络是种高效的前馈式网络,它具有其他前向网络所不具有的最佳逼近性能和全局最优特性,并且结构简单,训练速度快,所以它也比BP网络更优。
完整代码
clc
close all
clear
iris = load('iris.txt');
inputData = iris(:,1:4);
outputData = iris(:,5);
flag = length(outputData);
orderTrain = randperm(flag);
nbertrain = round(0.7*flag);% 提取训练和验证数据 70% 训练,30% 验证
XTrain = inputData(orderTrain(1:nbertrain),1:4)';
YTrain = outputData(orderTrain(1:nbertrain))';
XValidation = inputData(orderTrain(nbertrain+1:flag),1:4)';
YValidation = outputData(orderTrain(nbertrain+1:flag))';
% net = newrbe(XTrain,YTrain);
eg = 0.02; % 误差均方根值目标
sc = 1; % 径向基层的分布常数
mn = 25; % 最大的神经元个数
net = newrb(XTrain,YTrain,eg,sc);
Y = net(XValidation);
output = zeros(1,length(Y));
for i = 1:length(Y)
[m,n] = min(abs(Y(i)-[1 2 3]));
output(i) = n ;
end
figure
plot(YValidation,'r+');
ylabel('label');
hold on;
plot(output,'b*');
hold off;
legend({'Target','Output'})
ylim([0 4])
figure
plot(YValidation,'r');
ylabel('label');
hold on;
plot(output,'b');
hold off;
legend({'Target','Output'})
ylim([0 4])
error = YValidation-output;
figure
bar(error)
title('error')
ylim([-1.2 1.2])
figure
plot(error)
title('error')
ylim([-1.2 1.2])
correcr_rate = length(error(error==0))/length(YValidation)*100;
disp(['预测准确率: ', num2str(correcr_rate), ' % '])