2.运行程序
步骤1:加载高光谱数据集
使用超立方体函数读取高光谱图像。
hcube = hypercube("indian_pines.dat");
使用colorize函数可视化图像的假彩色版本。
rgbImg = colorize(hcube,method="rgb");
imshow(rgbImg)
加载地面真相标签并指定类的数量。
gtLabel = load("indian_pines_gt.mat");
gtLabel = gtLabel.indian_pines_gt;
numClasses = 16;
步骤2:预处理训练数据
使用hyperpca函数将光谱带的数量减少到30个。
dimReduction = 30;
imageData = hyperpca(hcube,dimReduction);
规格化图像数据。
sd = std(imageData,[],3);
imageData = imageData./sd;
使用createImagePatchesFromHypercube函数,将高光谱图像分割成大小为25×25像素、具有30个通道的Patches。
windowSize = 25;
inputSize = [windowSize windowSize dimReduction];
[allPatches,allLabels] = createImagePatchesFromHypercube(imageData,gtLabel,windowSize);
indianPineDataTransposed = permute(allPatches,[2 3 4 1]);
dsAllPatches = augmentedImageDatastore(inputSize,indianPineDataTransposed,allLabels);
仅选择标记的立方体进行训练。
patchesLabeled = allPatches(allLabels>0,:,:,:);
patchLabels = allLabels(allLabels>0);
numCubes = size(patchesLabeled,1);
将数字标签转换为分类标签。
patchLabels = categorical(patchLabels);
将Patches随机分为训练和测试数据集。
[trainingIdx,valIdx,testIdx] = dividerand(numCubes,0.3,0,0.7);
dataInputTrain = patchesLabeled(trainingIdx,:,:,:);
dataLabelTrain = patchLabels(trainingIdx,1);
dataInputTest = patchesLabeled(testIdx,:,:,:);
dataLabelTest = patchLabels(testIdx,1);
转换输入数据。
dataInputTransposeTrain = permute(dataInputTrain,[2 3 4 1]);
dataInputTransposeTest = permute(dataInputTest,[2 3 4 1]);
创建读取训练和测试数据批的数据存储。
dsTrain = augmentedImageDatastore(inputSize,dataInputTransposeTrain,dataLabelTrain);
dsTest = augmentedImageDatastore(inputSize,dataInputTransposeTest,dataLabelTest);
步骤3:创建CSCNN分类网络
定义CSCNN架构。
layers = [
image3dInputLayer(inputSize,Name="Input",Normalization="None")
convolution3dLayer([3 3 7],8,Name="conv3d_1")
reluLayer(Name="Relu_1")
convolution3dLayer([3 3 5],16,Name="conv3d_2")
reluLayer(Name="Relu_2")
convolution3dLayer([3 3 3],32,Name="conv3d_3")
reluLayer(Name="Relu_3")
convolution3dLayer([3 3 1],8,Name="conv3d_4")
reluLayer(Name="Relu_4")
fullyConnectedLayer(256,Name="fc1")
reluLayer(Name="Relu_5")
dropoutLayer(0.4,Name="drop_1")
fullyConnectedLayer(128,Name="fc2")
dropoutLayer(0.4,Name="drop_2")
fullyConnectedLayer(numClasses,Name="fc3")
softmaxLayer(Name="softmax")
classificationLayer(Name="output")];
lgraph = layerGraph(layers);
使用深度网络设计器可视化网络。
deepNetworkDesigner(lgraph)
步骤4:指定训练选项
指定所需的网络参数。
numEpochs = 100;
miniBatchSize = 256;
initLearningRate = 0.001;
momentum = 0.9;
learningRateFactor = 0.01;
options = trainingOptions("adam", ...
InitialLearnRate=initLearningRate, ...
LearnRateSchedule="piecewise", ...
LearnRateDropPeriod=30, ...
LearnRateDropFactor=learningRateFactor, ...
MaxEpochs=numEpochs, ...
MiniBatchSize=miniBatchSize, ...
GradientThresholdMethod="l2norm", ...
GradientThreshold=0.01, ...
VerboseFrequency=100, ...
ValidationData=dsTest, ...
ValidationFrequency=100);
步骤5:训练网络
默认情况下,该示例为Indian Pines数据集下载预训练的分类器。
doTraining = true;
if doTraining
net = trainNetwork(dsTrain,lgraph,options);
modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
save("trainedIndianPinesCSCNN-"+modelDateTime+".mat","net");
else
dataDir = pwd;
trainedNetwork_url = "https://ssd.mathworks.com/supportfiles/image/data/trainedIndianPinesCSCNN.mat";
downloadTrainedNetwork(trainedNetwork_url,pwd);
load(fullfile(dataDir,"trainedIndianPinesCSCNN.mat"));
end
步骤6:基于训练的CSCNN的高光谱图像分类计算测试数据集的分类精度。
predictionTest = classify(net,dsTest);
accuracy = sum(predictionTest == dataLabelTest)/numel(dataLabelTest);
disp("Accuracy of the test data = "+num2str(accuracy))
Accuracy of the test data 0.99512
通过对所有图像像素进行分类来重建完整图像,包括标记的训练块中的像素、标记的测试块中的图像和未标记的像素。
prediction = classify(net,dsAllPatches);
prediction = double(prediction);
在标记的Patches上进行训练。
patchesUnlabeled = find(allLabels==0);
prediction(patchesUnlabeled) = 0;
重构分类像素以匹配地面真实图像的尺寸。
[m,n,d] = size(imageData);
indianPinesPrediction = reshape(prediction,[n m]);
indianPinesPrediction = indianPinesPrediction';
显示地面实况和预测分类。
cmap = parula(numClasses);
figure
tiledlayout(1,2,TileSpacing="Tight")
nexttile
imshow(gtLabel,cmap)
title("Ground Truth Classification")
nexttile
imshow(indianPinesPrediction,cmap)
colorbar
title("Predicted Classification")
链接:https://ww2.mathworks.cn/help/images/hyperspectral-image-classification-using-deep-learning.html
https://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes
本文分享自 图像处理与模式识别研究所 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!