How do I use trainNetwork() to retrain a pre-trained model?
How can I replace the decoder and regression layers in my pretrained CAE model with fully connected layers, softmax layers and classification layers to retrain the model into a classifier?
This is the model I created.
lgraph = layerGraph();
tempLayers = [
imageInputLayer([224 224 3],"Name","imageinput")
convolution2dLayer([3 3],256,"Name","conv_1","Padding","same","Stride",[2 2])
reluLayer("Name","relu_1")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_3","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],128,"Name","conv_2","Padding","same","Stride",[2 2])
reluLayer("Name","relu_2")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_2","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
convolution2dLayer([3 3],64,"Name","conv_3","Padding","same","Stride",[2 2])
reluLayer("Name","relu_3")
maxPooling2dLayer([1 1],"Name","maxpoolForUnpool_1","HasUnpoolingOutputs",true,"Padding","same")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
transposedConv2dLayer([3 3],64,"Name","transposed-conv_1","Cropping","same")
reluLayer("Name","relu_4")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_1")
transposedConv2dLayer([3 3],128,"Name","transposed-conv_2","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_5")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_2")
transposedConv2dLayer([3 3],256,"Name","transposed-conv_3","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_6")];
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
maxUnpooling2dLayer("Name","maxunpool_3")
transposedConv2dLayer([3 3],3,"Name","transposed-conv_4","Cropping","same","Stride",[2 2])
reluLayer("Name","relu_7")
regressionLayer("Name","regressionoutput")];
lgraph = addLayers(lgraph,tempLayers);
% clean up helper variable
clear tempLayers;
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/out","conv_2");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/indices","maxunpool_3/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_3/size","maxunpool_3/size");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/out","conv_3");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/indices","maxunpool_2/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_2/size","maxunpool_2/size");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/out","transposed-conv_1");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/indices","maxunpool_1/indices");
lgraph = connectLayers(lgraph,"maxpoolForUnpool_1/size","maxunpool_1/size");
lgraph = connectLayers(lgraph,"relu_4","maxunpool_1/in");
lgraph = connectLayers(lgraph,"relu_5","maxunpool_2/in");
lgraph = connectLayers(lgraph,"relu_6","maxunpool_3/in");
Answers (1)
Rahul
on 13 Oct 2022
Go to Apps --> Deep Network Designer --> Blank Network.
After you build the network by dragging, dropping, and connecting the layers, click Export --> Generate Code to get the equivalent MATLAB code for your model. If you are still unsure, please send the entire architecture and I will create the network for you.
5 Comments
Rahul
on 14 Oct 2022
The code below is just a demo CNN architecture. You can refer to it when building your own CNN architecture.
layers = [ ...
    imageInputLayer([28 28 1])              % image input layer
    convolution2dLayer(5,20)                % 2-D convolutional layer
    reluLayer("Name","relu1")               % ReLU activation layer
    maxPooling2dLayer(2,"Stride",2)         % 2-D max pooling layer
    fullyConnectedLayer(2048,"Name","FC1")  % fully connected layer 1
    reluLayer("Name","relu2")               % ReLU activation layer
    fullyConnectedLayer(1024,"Name","FC2")  % fully connected layer 2
    reluLayer("Name","relu3")               % ReLU activation layer
    fullyConnectedLayer(10)                 % fully connected layer 3 (10 = number of classes)
    softmaxLayer                            % softmax layer to compute class probabilities
    classificationLayer];                   % classification output layer (marks this as a classification task)
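To tie this back to the original question, here is a minimal sketch of one possible way to reuse the encoder of the CAE as the front end of a classifier. It is not the only approach. Instead of editing the layer graph, it copies the encoder layers by name out of the trained network, so they keep their learned weights, and appends a new classification head. The three maxpoolForUnpool layers are left out on the assumption that they are only needed to feed the unpooling layers: with a [1 1] pool size and the default stride of 1 they pass their input through unchanged. The function name encoderToClassifier, the variables net (the trained CAE), numClasses and imdsTrain (a labeled image datastore), and the training option values are all hypothetical and do not come from the original post.

```matlab
function layers = encoderToClassifier(srcLayers,numClasses)
% encoderToClassifier  Sketch: build a classifier from the CAE encoder.
%   srcLayers  - Layer array of the trained CAE, e.g. net.Layers (assumed)
%   numClasses - number of target classes (assumed)
%
% Hypothetical usage (net, numClasses, imdsTrain are assumptions):
%   layers     = encoderToClassifier(net.Layers,numClasses);
%   options    = trainingOptions("adam","InitialLearnRate",1e-4,"MaxEpochs",10);
%   classifier = trainNetwork(imdsTrain,layers,options);

% Encoder layer names taken from the model in the question. The [1 1]
% pooling layers are skipped because they do not change the data.
encoderNames = ["imageinput","conv_1","relu_1", ...
                "conv_2","relu_2","conv_3","relu_3"];

% Locate each encoder layer by name in the source layer array.
allNames = string({srcLayers.Name});
[found,idx] = ismember(encoderNames,allNames);
assert(all(found),"Some encoder layers were not found in srcLayers.");

% Copied layers carry over their learned Weights and Bias.
encoder = srcLayers(idx);

% Append a new classification head in place of the decoder and
% regression layer.
layers = [
    encoder(:)
    fullyConnectedLayer(numClasses,"Name","fc")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
end
```

Because the result is a plain layer array rather than a layer graph, it can be passed straight to trainNetwork. If the encoder should stay fixed during retraining, setting WeightLearnRateFactor and BiasLearnRateFactor to 0 on the copied convolution layers is one option.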