How to replace the AlexNet pretrained network with other pretrained networks?

shivan artosh on 2 Oct 2020
Commented: Walter Roberson on 5 Oct 2020
Hello,
I have this code and I need to replace AlexNet with vgg16, vgg19, ResNet18 and densenet201, one at a time.
Could you please tell me which parts of this code should be changed?
clear all; close all; clc;
imds = imageDatastore('lung augmented', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames'); % for JPG images
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomize',true);
net = alexnet(); % analyzeNetwork(lgraph)
numClasses = numel(categories(imdsTrain.Labels)); % number of classes = number of folders
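imageSize = net.Layers(1).InputSize(1:2); % input size expected by the chosen network (227x227 for AlexNet)
lgraph = layerGraph(net); % layerGraph(net) also keeps the branch connections of DAG networks such as resnet18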
lgraph = removeLayers(lgraph, 'fc8');
lgraph = removeLayers(lgraph, 'prob');
lgraph = removeLayers(lgraph, 'output');
% create and add layers
inputLayer = imageInputLayer([imageSize 1], 'Name', net.Layers(1).Name,...
'DataAugmentation', net.Layers(1).DataAugmentation, ...
'Normalization', net.Layers(1).Normalization);
lgraph = replaceLayer(lgraph,net.Layers(1).Name,inputLayer);
newConv1_Weights = net.Layers(2).Weights;
newConv1_Weights = mean(newConv1_Weights(:,:,1:3,:), 3); % taking the mean of the kernel channels
newConv1 = convolution2dLayer(net.Layers(2).FilterSize(1), net.Layers(2).NumFilters,...
'Name', net.Layers(2).Name,...
'NumChannels', inputLayer.InputSize(3),...
'Stride', net.Layers(2).Stride,...
'DilationFactor', net.Layers(2).DilationFactor,...
'Padding', net.Layers(2).PaddingSize,...
'Weights', newConv1_Weights,...
'Bias', net.Layers(2).Bias,...
'BiasLearnRateFactor', net.Layers(2).BiasLearnRateFactor);
lgraph = replaceLayer(lgraph,net.Layers(2).Name,newConv1);
lgraph = addLayers(lgraph, fullyConnectedLayer(numClasses,'Name', 'fc2'));
lgraph = addLayers(lgraph, softmaxLayer('Name', 'softmax'));
lgraph = addLayers(lgraph, classificationLayer('Name','output'));
lgraph = connectLayers(lgraph, 'drop7', 'fc2');
lgraph = connectLayers(lgraph, 'fc2', 'softmax');
lgraph = connectLayers(lgraph, 'softmax', 'output');
% -------------------------------------------------------------------------
augimdsTrain = augmentedImageDatastore(imageSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(imageSize,imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',64, ...
'MaxEpochs',30, ... % I have also tried 20, 10 and 5
'InitialLearnRate',0.0001, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options)
[YPred, probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
figure
cc = confusionchart(imdsValidation.Labels, YPred);
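Several lines in the script above are specific to AlexNet: net.Layers(2) is assumed to be the first convolution, the input size is fixed, and layerGraph(net.Layers) loses the branch connections of DAG networks such as ResNet-18. Below is a minimal sketch of the input-side changes written so they do not depend on the model. It assumes Deep Learning Toolbox plus the support package for the chosen network, and keeps the original idea of a one-channel (grayscale) input; isConv, conv1 and w are illustrative names only.

```matlab
% Sketch: adapt the input side of a pretrained network to one-channel
% (grayscale) images without hardcoding AlexNet-specific details.
net = resnet18();                           % or alexnet(), vgg16(), vgg19()
lgraph = layerGraph(net);                   % keeps the branch connections of DAG networks
imageSize = net.Layers(1).InputSize(1:2);   % [227 227] for AlexNet, [224 224] for VGG/ResNet

% New one-channel input layer that reuses the original name and normalization
inputLayer = imageInputLayer([imageSize 1], ...
    'Name', net.Layers(1).Name, ...
    'Normalization', net.Layers(1).Normalization);
lgraph = replaceLayer(lgraph, net.Layers(1).Name, inputLayer);

% First convolution layer, found by type instead of assuming net.Layers(2)
isConv = arrayfun(@(L) isa(L, 'nnet.cnn.layer.Convolution2DLayer'), net.Layers);
conv1 = net.Layers(find(isConv, 1));

% Average the pretrained kernels over the RGB channels: [h w 3 n] -> [h w 1 n]
w = mean(conv1.Weights, 3);
newConv1 = convolution2dLayer(conv1.FilterSize, conv1.NumFilters, ...
    'Name', conv1.Name, ...
    'NumChannels', 1, ...
    'Stride', conv1.Stride, ...
    'DilationFactor', conv1.DilationFactor, ...
    'Padding', conv1.PaddingSize, ...
    'Weights', w, ...
    'Bias', conv1.Bias);
lgraph = replaceLayer(lgraph, conv1.Name, newConv1);
```

The fully connected, softmax and classification layers at the end of the network still have to be adapted to numClasses separately.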
Walter Roberson on 2 Oct 2020
I am not clear what you mean by "exchanging" those.
shivan artosh on 2 Oct 2020
I mean substituting alexnet with another network, e.g. vgg16, vgg19, densenet..., in this line:
net = alexnet(); % analyzeNetwork(lgraph)
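If the images are grayscale (an assumption, inferred from the one-channel input layer in the question), an alternative that leaves every pretrained layer untouched is to let the datastore replicate the gray channel into three channels, so that the network line is the only one tied to the input format. A sketch, assuming the same 'lung augmented' folder as in the question:

```matlab
% Alternative to editing the input and first convolution layers:
% convert grayscale images to 3-channel RGB on the fly.
imds = imageDatastore('lung augmented', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');

net = vgg16();                          % or alexnet(), vgg19(), resnet18(), ...
inputSize = net.Layers(1).InputSize;    % e.g. [224 224 3]

augimdsTrain = augmentedImageDatastore(inputSize(1:2), imdsTrain, ...
    'ColorPreprocessing', 'gray2rgb');  % copies the gray channel into R, G and B
augimdsValidation = augmentedImageDatastore(inputSize(1:2), imdsValidation, ...
    'ColorPreprocessing', 'gray2rgb');
```

With this route, only the classifier layers at the end of the network still need to be adapted to numClasses.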


Answers (1)

Walter Roberson on 2 Oct 2020
nets = {alexnet(), vgg16(), vgg19(), resnet18()}; % I do not see densenet201 available
numnet = length(nets);
for netidx = 1 : numnet
    net = nets{netidx};
    % now do your stuff, starting from the assignment to numClasses
end
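A variation of this loop, sketched under the assumption that memory is a concern: storing function handles instead of network objects loads each model only when its turn comes, and the try/catch skips a network whose support package is not installed. netFcns and accuracies are illustrative names; the body is elided just as in the loop above.

```matlab
% Function handles: each network is created only inside its own iteration,
% instead of keeping all of them in memory at once.
netFcns = {@alexnet, @vgg16, @vgg19, @resnet18, @densenet201};
accuracies = nan(1, numel(netFcns));
for netidx = 1:numel(netFcns)
    try
        net = netFcns{netidx}();   % requires the matching support package
    catch err
        fprintf('Skipping %s: %s\n', func2str(netFcns{netidx}), err.message);
        continue
    end
    % ... the rest of the script, starting from the assignment to numClasses,
    % with the final accuracy stored in accuracies(netidx)
end
```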
I suspect that the exact names of the existing layers to remove will differ from model to model.
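One way around the differing names, sketched here as a hypothetical helper (replaceClassifier is not a MathWorks function), is to locate the classifier layers by type instead of by name and swap them with replaceLayer, which keeps the existing connections so no connectLayers calls are needed. It assumes the network ends in a fully connected, a softmax and a classification output layer, as AlexNet does (fc8, prob, output); the new layer names 'fc_new' and 'output_new' are arbitrary.

```matlab
function lgraph = replaceClassifier(net, numClasses)
% REPLACECLASSIFIER  Hypothetical helper: replaces the last fully connected
% layer and the classification output layer of a pretrained network with
% new ones sized for numClasses. The softmax layer is left as it is.
lgraph = layerGraph(net);   % works for series and DAG networks
layers = lgraph.Layers;

% Last fully connected layer (e.g. 'fc8' in AlexNet, 'fc1000' in ResNet-18)
isFC = arrayfun(@(L) isa(L, 'nnet.cnn.layer.FullyConnectedLayer'), layers);
fcName = layers(find(isFC, 1, 'last')).Name;

% Classification output layer, whatever it is called in this model
isOut = arrayfun(@(L) isa(L, 'nnet.cnn.layer.ClassificationOutputLayer'), layers);
outName = layers(find(isOut, 1, 'last')).Name;

lgraph = replaceLayer(lgraph, fcName, fullyConnectedLayer(numClasses, 'Name', 'fc_new'));
lgraph = replaceLayer(lgraph, outName, classificationLayer('Name', 'output_new'));
end
```

Saved as replaceClassifier.m, it could stand in for the removeLayers, addLayers and connectLayers calls in the question: lgraph = replaceClassifier(net, numClasses);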
shivan artosh on 4 Oct 2020
I used only resnet18, but I got this error:
Error using nnet.cnn.LayerGraph>iValidateLayerName (line 663)
Layer 'fc18' does not exist.
Error in nnet.cnn.LayerGraph/removeLayers (line 234)
iValidateLayerName( ...
Error in SHIVANaugmented_test (line 20)
lgraph = removeLayers(lgraph, 'fc18');
Walter Roberson on 5 Oct 2020
It looks to me as if resnet18 has a layer named 'fc1000', not 'fc18'.
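Rather than guessing names such as 'fc18', the names can be printed for each network. A short sketch for resnet18 (analyzeNetwork(net) shows the same information graphically); the variable names are illustrative:

```matlab
net = resnet18();
lgraph = layerGraph(net);

% The last few entries normally include the classifier layers to remove or replace
names = {lgraph.Layers.Name}';
disp(names(end-3:end))

% Layer that feeds 'fc1000' (the counterpart of 'drop7' in AlexNet),
% i.e. the layer a new fully connected layer would be connected to
c = lgraph.Connections;
disp(c(strcmp(c.Destination, 'fc1000'), :))
```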
