Get Started with Deep Network Designer

This example shows how to fine-tune a pretrained GoogLeNet network to classify a new collection of images. This process is called transfer learning and is usually much faster and easier than training a new network, because you can apply learned features to a new task using a smaller number of training images. To interactively prepare a network for transfer learning, use Deep Network Designer.

Load Pretrained Network

Load a pretrained GoogLeNet network. If the required support package is not installed, then the software provides a download link.

net = googlenet;

Import Network into Deep Network Designer

Open Deep Network Designer.
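You can open the app from the MATLAB command line:

deepNetworkDesigner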


Click Import and select the network from the workspace. Deep Network Designer displays a zoomed-out view of the whole network. Explore the network plot. To zoom in with the mouse, use Ctrl+scroll wheel.

Edit Network for Transfer Learning

To retrain a pretrained network to classify new images, replace the final layers with new layers adapted to the new data set. You must change the number of classes to match your data.

Drag a new fullyConnectedLayer from the Layer Library onto the canvas. Set OutputSize to the number of classes in the new data, in this example, 5.

To learn faster in the new layers than in the transferred layers, set WeightLearnRateFactor and BiasLearnRateFactor to 10. Delete the last fully connected layer and connect up your new layer instead.

Replace the output layer. Scroll to the end of the Layer Library and drag a new classificationLayer onto the canvas. Delete the original output layer and connect up your new layer instead.
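If you prefer to script these edits instead of using the app, the same replacements can be sketched with layerGraph and replaceLayer. This sketch assumes the GoogLeNet final layer names loss3-classifier and output, and the hypothetical new layer names fc_new and output_new:

lgraph = layerGraph(net);
newFC = fullyConnectedLayer(5,'Name','fc_new', ...
    'WeightLearnRateFactor',10,'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newFC);
newOutput = classificationLayer('Name','output_new');
lgraph = replaceLayer(lgraph,'output',newOutput);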

Check Network

To make sure your edited network is ready for training, click Analyze, and ensure the Deep Learning Network Analyzer reports zero errors.

Export Network for Training

Return to the Deep Network Designer and click Export. Deep Network Designer exports the network to a new variable called lgraph_1 containing the edited network layers. You can now supply the layer variable to the trainNetwork function. You can also generate MATLAB® code that recreates the network architecture and returns it as a layerGraph object or a Layer array in the MATLAB workspace.
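After exporting, you can rerun the readiness check from the command line on the exported variable:

analyzeNetwork(lgraph_1)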

Load Data and Train Network

Unzip and load the new images as an image datastore. Divide the data into 70% training data and 30% validation data.

imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
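To confirm that the datastore found the expected five classes, you can inspect the label counts:

countEachLabel(imds)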

Resize images to match the pretrained network input size.

augimdsTrain = augmentedImageDatastore([224 224],imdsTrain);
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
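Instead of hard-coding [224 224], you can read the input size from the network itself, because the first layer of the pretrained network is its image input layer:

inputSize = net.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);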

Specify training options.

  • Specify the mini-batch size, that is, how many images to use in each iteration.

  • Specify a small number of epochs. An epoch is a full training cycle on the entire training data set. For transfer learning, you do not need to train for as many epochs. Shuffle the data every epoch.

  • Set InitialLearnRate to a small value to slow down learning in the transferred layers.

  • Specify validation data and a small validation frequency.

  • Turn on the training plot to monitor progress while you train.

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'Shuffle','every-epoch', ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',6, ...
    'Verbose',false, ...
    'Plots','training-progress');

To train the network, supply the layers exported from the app, lgraph_1, the training images, and options, to the trainNetwork function. By default, trainNetwork uses a GPU if available (requires Parallel Computing Toolbox™). Otherwise, it uses a CPU. Training is fast because the data set is so small.

netTransfer = trainNetwork(augimdsTrain,lgraph_1,options);

Test Trained Network

Classify the validation images using the fine-tuned network, and calculate the classification accuracy.

[YPred,probs] = classify(netTransfer,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 1
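For a per-class view of the validation results, you can also plot a confusion matrix:

figure
confusionchart(imdsValidation.Labels,YPred)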

Display four sample validation images with predicted labels and predicted probabilities.

idx = randperm(numel(augimdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

To learn more and try other pretrained networks, see Deep Network Designer.
