clear all; close all; clc;
imds = imageDatastore('lung augmented', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomize',true);
net = alexnet();
numClasses = numel(categories(imdsTrain.Labels));
imageSize = [224 224];
lgraph = layerGraph(net.Layers);
lgraph = removeLayers(lgraph, 'fc8');
lgraph = removeLayers(lgraph, 'prob');
lgraph = removeLayers(lgraph, 'output');
inputLayer = imageInputLayer([imageSize 1], 'Name', net.Layers(1).Name,...
'DataAugmentation', net.Layers(1).DataAugmentation, ...
'Normalization', net.Layers(1).Normalization);
lgraph = replaceLayer(lgraph,net.Layers(1).Name,inputLayer);
newConv1_Weights = net.Layers(2).Weights;
newConv1_Weights = mean(newConv1_Weights(:,:,1:3,:), 3);
newConv1 = convolution2dLayer(net.Layers(2).FilterSize(1), net.Layers(2).NumFilters,...
'Name', net.Layers(2).Name,...
'NumChannels', inputLayer.InputSize(3),...
'Stride', net.Layers(2).Stride,...
'DilationFactor', net.Layers(2).DilationFactor,...
'Padding', net.Layers(2).PaddingSize,...
'Weights', newConv1_Weights,...
'Bias', net.Layers(2).Bias,...
'BiasLearnRateFactor', net.Layers(2).BiasLearnRateFactor);
lgraph = replaceLayer(lgraph,net.Layers(2).Name,newConv1);
lgraph = addLayers(lgraph, fullyConnectedLayer(numClasses,'Name', 'fc2'));
lgraph = addLayers(lgraph, softmaxLayer('Name', 'softmax'));
lgraph = addLayers(lgraph, classificationLayer('Name','output'));
lgraph = connectLayers(lgraph, 'drop7', 'fc2');
lgraph = connectLayers(lgraph, 'fc2', 'softmax');
lgraph = connectLayers(lgraph, 'softmax', 'output');
augimdsTrain = augmentedImageDatastore(imageSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(imageSize,imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',64, ...
'MaxEpochs',30, ...
'InitialLearnRate',0.0001, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options)
[YPred, probs] = classify(net,augimdsValidation);
accuracy = mean(YPred ==imdsValidation.Labels)
figure
cc = confusionchart (imdsValidation.Labels, YPred);
2 Comments
Direct link to this comment
https://au.mathworks.com/matlabcentral/answers/604027-alexnet-pretrained-network#comment_1034335
Direct link to this comment
https://au.mathworks.com/matlabcentral/answers/604027-alexnet-pretrained-network#comment_1034335
Direct link to this comment
https://au.mathworks.com/matlabcentral/answers/604027-alexnet-pretrained-network#comment_1034344
Direct link to this comment
https://au.mathworks.com/matlabcentral/answers/604027-alexnet-pretrained-network#comment_1034344
Sign in to comment.