Compress Image Classification Network for Deployment to Resource-Constrained Embedded Devices
This example shows how to reduce the memory footprint and computation requirements of an image classification network for deployment on resource-constrained embedded devices such as the Raspberry Pi™.
In many applications where transfer learning is used to retrain an image classification network for a new task, or where a new network is trained from scratch, the optimal network architecture is not known and the network might be overparameterized. An overparameterized network has redundancies. Network pruning is a powerful model compression tool that helps identify redundancies that can be removed with little impact on the final network output. When you use pruning in combination with network quantization, you can reduce the inference time and memory footprint of the network, making it easier to deploy to ARM® CPU platforms such as the Raspberry Pi.
This example shows how to:
Use transfer learning to retrain SqueezeNet, a pretrained convolutional neural network, to classify a new set of images from the CIFAR-10 data set.
Prune filters from the convolutional layers of the network by using first-order Taylor approximation.
Retrain the network after pruning to regain any loss in accuracy.
Evaluate the impact of pruning on classification accuracy.
Quantize the weights, biases, and activations of the convolution layers to 8-bit scaled integer data type.
Generate and deploy optimized C++ code to a Raspberry Pi.
Evaluate the impact of quantization on the classification accuracy of the pruned network.
Third-Party Prerequisites
Raspberry Pi hardware
ARM Compute Library (on the target ARM hardware)
Environment variables for the compilers and libraries. See Prerequisites for Deep Learning with MATLAB Coder.
Prepare Data
Download the CIFAR-10 data set [1]. The data set contains 60,000 images. Each image is 32-by-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.
datadir = tempdir; downloadCIFARData(datadir);
Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images. Use the CIFAR-10 test images for network validation.
[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);
You can display a random sample of the training images using the following code.
figure; idx = randperm(size(XTrain,4),20); im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]); imshow(im)
Create an augmentedImageDatastore object to use for network training. During training, the datastore randomly flips the training images along the vertical axis and randomly translates them up to four pixels horizontally and vertically. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.
imageSize = [32,32,3];
pixelRange = [-4,4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,TTrain, ...
    DataAugmentation=imageAugmenter,OutputSizeMode="randcrop");
augimdsValidation = augmentedImageDatastore(imageSize,XValidation, ...
    TValidation,DataAugmentation=imageAugmenter);
classes = categories(TTrain);
Retrain Network on CIFAR-10 Data Using Transfer Learning
SqueezeNet has been trained on over a million images and can classify images into 1000 object categories (such as keyboard, coffee mug, pencil, and many animals). This example fine-tunes the pretrained SqueezeNet network by using transfer learning. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch.
Retrain Network
Training the network can take a considerable amount of time, even on a good GPU. If you do not have a GPU, then training takes much longer. Training on a GPU or in parallel requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
To save time while running this example, load a pretrained network by setting doTraining to false. To train the network yourself, set doTraining to true.
doTraining = false;
if doTraining
    net = squeezenet; %#ok<UNRCH>
    lgraph = layerGraph(net);

    % Replace the input layer so that the network accepts 32-by-32-by-3 images.
    larray = [imageInputLayer(imageSize,'Name','data')];
    lgraph = replaceLayer(lgraph,'data',larray);
    [learnableLayer,classLayer] = findLayersToReplace(lgraph);
    numClasses = 10;

    % Replace the first convolution layer.
    newFirstConvLayer = convolution2dLayer([3,3], 64,'WeightLearnRateFactor', ...
        10,'BiasLearnRateFactor',10,"Name",'new_firstconv');
    lgraph = replaceLayer(lgraph,'conv1',newFirstConvLayer);

    % Replace the final convolution layer with one that has numClasses filters.
    newConvLayer = convolution2dLayer([1,1],numClasses, ...
        'WeightLearnRateFactor',10,'BiasLearnRateFactor',10,"Name",'new_conv');
    lgraph = replaceLayer(lgraph,'conv10',newConvLayer);

    % Replace the classification output layer.
    newClassificatonLayer = classificationLayer('Name','new_classoutput');
    lgraph = replaceLayer(lgraph,'ClassificationLayer_predictions',newClassificatonLayer);

    options = trainingOptions('adam', ...
        'MiniBatchSize',100, ...
        'MaxEpochs',15, ...
        'InitialLearnRate',2e-4/3, ...
        'Shuffle','every-epoch', ...
        'ValidationData',augimdsValidation, ...
        'ValidationFrequency',25, ...
        'ValidationPatience',5, ...
        'Verbose',false, ...
        'Plots','training-progress');

    transferNet = trainNetwork(augimdsTrain,lgraph,options);
else
    load('transferNet.mat','transferNet');
end
Save the trained network.
save('transferNet.mat','transferNet');
Evaluate Trained Network
Calculate the final accuracy of the network on the validation set (without data augmentation).
[YValPred,probs] = classify(transferNet,XValidation); accuracyOfTrainedNet = mean(YValPred == TValidation) * 100; disp("Validation accuracy of trained network: " + accuracyOfTrainedNet + "%")
Validation accuracy of trained network: 60.48%
Prune Network
Prune the network using the taylorPrunableNetwork
function. The network computes an importance score for each convolution filter in the network based on Taylor expansion [2][3]. Pruning is iterative; each time the loop runs, until a stopping criterion is met, the function removes a small number of the least important convolution filters and updates the network architecture.
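As a rough illustration of the first-order Taylor criterion described in [3], the sketch below scores each filter by the magnitude of the product of its activations and the loss gradients with respect to those activations, averaged over spatial positions and then over the mini-batch. The array sizes and random data are hypothetical, and the scores that taylorPrunableNetwork accumulates internally may be normalized or combined differently.

```matlab
% Hypothetical activations and gradients for one convolution layer,
% laid out as height-by-width-by-filters-by-batch.
rng(0)
numFilters = 8;
A = randn(4,4,numFilters,16);     % pruning activations
G = randn(4,4,numFilters,16);     % gradients of the loss w.r.t. A

% First-order Taylor estimate of the loss change when a filter is removed:
% average activation.*gradient over space, take the magnitude, then
% average over the mini-batch.
perObservation = abs(mean(A .* G,[1 2]));        % 1-by-1-by-filters-by-batch
taylorScores = squeeze(mean(perObservation,4));  % filters-by-1

% Filters with the smallest scores are the first pruning candidates.
[~,order] = sort(taylorScores,"ascend");
leastImportant = order(1:2);
```

In the example itself, the updateScore and updatePrunables functions take care of accumulating these scores and removing filters.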
Specify Pruning and Fine-Tuning Options
Set the pruning options.
maxPruningIterations sets the maximum number of iterations to use for the pruning process.
maxToPrune sets the maximum number of filters to prune in each iteration of the pruning cycle.
maxPruningIterations = 30; maxToPrune = 32;
Set the fine-tuning options.
learnRate = 1e-2/3; momentum = 0.9; miniBatchSize = 256; numMinibatchUpdates = 50; validationFrequency = 1;
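To get a feel for what these settings imply, the short calculation below estimates an upper bound on the number of filters removed and the approximate amount of fine-tuning data processed. The variable names maxFiltersRemoved, totalUpdates, and imagesSeen are introduced here for illustration only, and the figures are approximate because the loop can stop early or process a partial mini-batch.

```matlab
maxPruningIterations = 30;
maxToPrune = 32;
numMinibatchUpdates = 50;
miniBatchSize = 256;

% Upper bound on the number of filters removed over the whole pruning loop.
maxFiltersRemoved = maxPruningIterations * maxToPrune;

% Approximate number of fine-tuning updates and training images processed.
totalUpdates = maxPruningIterations * numMinibatchUpdates;
imagesSeen = totalUpdates * miniBatchSize;

fprintf("At most %d filters removed, about %d images processed.\n", ...
    maxFiltersRemoved,imagesSeen)
```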
Prune Network using Custom Pruning Loop
To implement a custom pruning loop, convert the network to a dlnetwork
object.
layerG = layerGraph(transferNet); layerG = removeLayers(layerG,layerG.OutputNames); net = dlnetwork(layerG);
Print a summary of the dlnetwork
object. The summary shows whether the network is initialized, the total number of learnable parameters, and information about the network inputs.
summary(net)
Initialized: true Number of learnables: 727.6k Inputs: 1 'data' 32×32×3 images
Create a Taylor prunable network from the original network.
prunableNet = taylorPrunableNetwork(net); maxPrunableFilters = prunableNet.NumPrunables;
Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:
Use the custom mini-batch preprocessing function preprocessMiniBatchTraining (defined at the end of this example) to convert the labels to one-hot encoded variables.
Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available.
mbqTrain = minibatchqueue(augimdsTrain, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);

mbqTest = minibatchqueue(augimdsValidation, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);
Initialize the training progress plots.
figure("Position",[10,10,700,700])
tl = tiledlayout(3,1);

% Loss during fine-tuning.
lossAx = nexttile;
lineLossFinetune = animatedline(Color=[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Fine-Tuning Iteration")
ylabel("Loss")
grid on
title("Mini-Batch Loss During Pruning")
xTickPos = [];

% Validation accuracy after each pruning iteration.
accuracyAx = nexttile;
lineAccuracyPruning = animatedline(Color=[0.098 0.325 0.85],LineWidth=2,Marker="o");
ylim([0 100])
xlabel("Pruning Iteration")
ylabel("Accuracy")
grid on
addpoints(lineAccuracyPruning,0,accuracyOfTrainedNet)
title("Validation Accuracy After Pruning")

% Number of prunable filters remaining after each pruning iteration.
numPrunablesAx = nexttile;
lineNumPrunables = animatedline(Color=[0.4660 0.6740 0.1880],LineWidth=2,Marker="^");
ylim([200 3000])
xlabel("Pruning Iteration")
ylabel("Prunable Filters")
grid on
addpoints(lineNumPrunables,0,double(maxPrunableFilters))
title("Number of Prunable Convolution Filters After Pruning")
Prune the network by repeatedly fine-tuning the network and removing the low-scoring filters.
For each pruning iteration, the following steps are used:
Fine-tune the network and accumulate Taylor scores for the convolution filters for numMinibatchUpdates mini-batch updates.
Prune the network using the updatePrunables function to remove up to maxToPrune convolution filters.
Compute the validation accuracy.
To fine-tune the network, loop over the mini-batches of the training data. For each mini-batch in the fine-tuning iteration, the following steps are used:
Evaluate the pruning loss, gradients of the pruning activations, pruning activations, model gradients, and the state using the dlfeval and modelLossPruning functions.
Update the network state.
Update the network parameters using the sgdmupdate function.
Update the Taylor scores of the prunable network using the updateScore function.
Display the training progress.
start = tic;
iteration = 0;
for pruningIteration = 1:maxPruningIterations

    shuffle(mbqTrain);
    velocity = [];

    % Loop over mini-batches.
    fineTuningIteration = 0;
    while hasdata(mbqTrain)
        iteration = iteration + 1;
        fineTuningIteration = fineTuningIteration + 1;

        [X, T] = next(mbqTrain);

        % The output order matches the modelLossPruning signature:
        % loss, pruning gradients, pruning activations, model gradients, state.
        [loss,pruningGradients,pruningActivations,netGradients,state] = ...
            dlfeval(@modelLossPruning, prunableNet, X, T);

        prunableNet.State = state;
        [prunableNet, velocity] = sgdmupdate(prunableNet, netGradients, velocity, learnRate, momentum);
        prunableNet = updateScore(prunableNet, pruningActivations, pruningGradients);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        addpoints(lineLossFinetune, iteration, double(loss))
        title(tl,"Processing Pruning Iteration: " + pruningIteration + " of " + maxPruningIterations + ...
            ", Elapsed Time: " + string(D))
        % Synchronize the x-axis of the accuracy and numPrunables plots with the loss plot.
        xlim(accuracyAx,lossAx.XLim)
        xlim(numPrunablesAx,lossAx.XLim)
        drawnow

        % Stop the fine-tuning loop when numMinibatchUpdates is reached.
        if (fineTuningIteration > numMinibatchUpdates)
            break
        end
    end

    % Prune filters based on previously computed Taylor scores.
    prunableNet = updatePrunables(prunableNet, MaxToPrune = maxToPrune);

    % Show results on the validation data set in a subset of pruning iterations.
    isLastPruningIteration = pruningIteration == maxPruningIterations;
    if (mod(pruningIteration, validationFrequency) == 0 || isLastPruningIteration)
        accuracy = modelAccuracy(prunableNet, mbqTest, classes, augimdsValidation.NumObservations);
        addpoints(lineAccuracyPruning, iteration, accuracy)
        addpoints(lineNumPrunables,iteration,double(prunableNet.NumPrunables))
    end

    xTickPos = [xTickPos, iteration]; %#ok<AGROW>
    xticks(lossAx,xTickPos)
    xticks(accuracyAx,[0,xTickPos])
    xticks(numPrunablesAx,[0,xTickPos])
    xticklabels(accuracyAx,["Unpruned",string(1:pruningIteration)])
    xticklabels(numPrunablesAx,["Unpruned",string(1:pruningIteration)])
    drawnow
end
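The parameter update inside the fine-tuning loop is stochastic gradient descent with momentum. As a minimal sketch of that update rule on plain numeric arrays (the parameter vector w and gradient g are hypothetical values, and sgdmupdate applies the same idea to every learnable in the network rather than to a single vector):

```matlab
learnRate = 1e-2/3;
momentum = 0.9;

w = [0.5; -0.2];             % hypothetical parameters
g = [0.1; -0.4];             % hypothetical gradient of the loss w.r.t. w
velocity = zeros(size(w));   % no history before the first update

% Classical momentum: the velocity is a decaying sum of past steps,
% and the parameters move along the velocity.
velocity = momentum * velocity - learnRate * g;
w = w + velocity;
```

Because the pruning loop resets velocity to an empty array at the start of each pruning iteration, the momentum history does not carry over once the network architecture changes.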
In contrast to typical training where the loss decreases with each iteration, pruning may increase the loss and reduce the validation accuracy due to the change of network structure when convolution filters are pruned. To further improve the accuracy of the network, you can retrain the network.
Once pruning is complete, convert the taylorPrunableNetwork
back to a dlnetwork
for retraining.
prunedNet = dlnetwork(prunableNet);
Retrain Network After Pruning
Retrain the network after pruning to regain any loss in accuracy. To retrain the network using the trainNetwork function:
Extract the layerGraph from the dlnetwork.
Add the removed classification layer from the original network to the layerGraph of the pruned network.
Train the layerGraph network.
prunedLayerGraph = layerGraph(prunedNet); outputLayerName = string(transferNet.OutputNames{1}); outputLayerIdx = {transferNet.Layers.Name} == outputLayerName; prunedLayerGraph = addLayers(prunedLayerGraph,transferNet.Layers(outputLayerIdx)); prunedLayerGraph = connectLayers(prunedLayerGraph,prunedNet.OutputNames{1},outputLayerName);
Specify the training options for the Adam optimizer. Set the maximum number of retraining epochs to 10 and start the training with an initial learning rate of 2e-4/3 (approximately 6.7e-5). Use a piecewise learning rate schedule that multiplies the learning rate by 0.1 every 2 epochs.
options = trainingOptions("adam", ...
    MaxEpochs = 10, ...
    MiniBatchSize = 100, ...
    InitialLearnRate = 2e-4/3, ...
    LearnRateSchedule = "piecewise", ...
    LearnRateDropFactor = 0.1, ...
    LearnRateDropPeriod = 2, ...
    L2Regularization = 0.02, ...
    ValidationData = augimdsValidation, ...
    ValidationPatience = 5, ...
    ValidationFrequency = 25, ...
    Verbose = false, ...
    Shuffle = "every-epoch", ...
    Plots = "training-progress");
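To see how quickly the piecewise schedule shrinks the learning rate, the sketch below tabulates the rate per epoch. It assumes that the drop is applied once every LearnRateDropPeriod epochs, starting after the first period; the variable names are introduced here for illustration.

```matlab
initialLearnRate = 2e-4/3;
dropFactor = 0.1;
dropPeriod = 2;
epochs = 1:10;

% Number of drops that have been applied by the start of each epoch.
numDrops = floor((epochs - 1) / dropPeriod);
learnRates = initialLearnRate * dropFactor .^ numDrops;

disp(table(epochs',learnRates',VariableNames=["Epoch","LearnRate"]))
```

Under this assumption the rate is four orders of magnitude smaller by the ninth epoch, so most of the recovery in accuracy would happen in the first few epochs.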
Train the network.
prunedDAGNet = trainNetwork(augimdsTrain,prunedLayerGraph,options);
Save the pruned network.
save('prunedDAGNet.mat','prunedDAGNet');
Compare Original Network and Pruned Network
Determine the impact of pruning on each layer.
[originalNetFilters,layerNames] = numConvLayerFilters(transferNet); prunedNetFilters = numConvLayerFilters(prunedDAGNet);
Visualize the number of filters in the original network and in the pruned network.
figure("Position",[10,10,900,900])
bar([originalNetFilters,prunedNetFilters])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters Per Layer")
xticks(1:(numel(layerNames)))
xticklabels(layerNames)
xtickangle(90)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters","Location","southoutside")
Large differences between the number of filters of the two networks indicate where many of the less important filters have been pruned.
Next, compare the accuracy of the original network and the pruned network.
tic
YPredOriginal = classify(transferNet,XValidation);
toc
Elapsed time is 1.435466 seconds.
accuOriginal = mean(YPredOriginal == TValidation)
accuOriginal = 0.6048
tic
YPredPruned = classify(prunedDAGNet,XValidation);
toc
Elapsed time is 2.194408 seconds.
accuPruned = mean(YPredPruned == TValidation)
accuPruned = 0.7843
Pruning can unequally affect the classification of different classes and introduce bias into the model, which might not be apparent from the accuracy value. To assess the impact of pruning at a class level, use a confusion matrix chart.
figure
confusionchart(TValidation,YPredOriginal,Normalization = "row-normalized");
title("Original Network")

figure
confusionchart(TValidation,YPredPruned,Normalization = "row-normalized");
title("Pruned Network")
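The diagonal of a row-normalized confusion matrix is the per-class recall. As a small sketch of how that number can be computed directly, using hypothetical toy labels rather than the CIFAR-10 predictions:

```matlab
% Hypothetical three-class problem with six observations.
classNames = ["cat","dog","ship"];
trueLabels = categorical(["cat","cat","dog","dog","ship","ship"],classNames);
predLabels = categorical(["cat","dog","dog","dog","ship","cat"],classNames);

% Fraction of the observations in each true class that are predicted correctly.
perClassRecall = zeros(numel(classNames),1);
for k = 1:numel(classNames)
    isClass = trueLabels == classNames(k);
    perClassRecall(k) = mean(predLabels(isClass) == trueLabels(isClass));
end

disp(table(classNames',perClassRecall,VariableNames=["Class","Recall"]))
```

Running the same loop with TValidation and each set of predictions would give a per-class comparison between the original and pruned networks.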
Next, estimate the model parameters for the original network and the pruned network to understand the impact of pruning on the overall network learnables and size.
analyzeNetworkMetrics(transferNet,prunedDAGNet,accuOriginal,accuPruned)
ans=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    Accuracy
                         __________________    ___________________________    ________
    Original Network         7.2763e+05                  2.7757                0.6048
    Pruned Network           3.8997e+05                  1.4876                0.7843
    Percentage Change           -46.406                 -46.406                29.679
This table compares the size and classification accuracy of the original and the pruned network. A decrease in network memory together with improved accuracy indicates a successful pruning operation.
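The memory column can be reproduced, to the displayed precision, from the learnables column under the assumption that each learnable is stored as a 4-byte single and that 1 MB is 2^20 bytes. The sketch below checks this with the rounded values from the table. The last line is a hypothetical back-of-the-envelope figure for 1-byte storage and is not a value reported by this example.

```matlab
originalLearnables = 7.2763e5;   % rounded values taken from the table
prunedLearnables   = 3.8997e5;
bytesPerSingle = 4;              % assumption: single-precision parameters
bytesPerMB = 2^20;               % assumption: binary megabytes

memOriginalMB = originalLearnables * bytesPerSingle / bytesPerMB;
memPrunedMB   = prunedLearnables   * bytesPerSingle / bytesPerMB;
percentChange = 100 * (memPrunedMB - memOriginalMB) / memOriginalMB;

% Hypothetical: parameter memory if every learnable took 1 byte.
memOneByteMB = prunedLearnables / bytesPerMB;

fprintf("Original: %.4f MB, pruned: %.4f MB, change: %.2f%%\n", ...
    memOriginalMB,memPrunedMB,percentChange)
```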
Quantize the Pruned Network
To quantize the pruned network using the dlquantizer
function, specify the network you want to calibrate and the execution environment, and then calibrate with calibration data.
clear r
r = raspi;
quantOpts = dlquantizationOptions('Target',r);
quantObj = dlquantizer(prunedDAGNet,'ExecutionEnvironment','CPU');
Use the calibrate function to exercise the network with the calibration data and collect range statistics for the weights, biases, and activations at each layer.
calResults = calibrate(quantObj,augimdsTrain,'UseGPU','off')
### Host application produced the following standard output (stdout) and standard error (stderr) messages:
calResults=122×5 table
Optimized Layer Name Network Layer Name Learnables / Activations MinValue MaxValue
____________________________ ____________________ ________________________ _________ ________
{'new_firstconv_Weights' } {'new_firstconv' } "Weights" -0.53081 0.50032
{'new_firstconv_Bias' } {'new_firstconv' } "Bias" -0.13664 0.2061
{'fire2-squeeze1x1_Weights'} {'fire2-squeeze1x1'} "Weights" -1.3348 1.1903
{'fire2-squeeze1x1_Bias' } {'fire2-squeeze1x1'} "Bias" -0.12888 0.25519
{'fire2-expand1x1_Weights' } {'fire2-expand1x1' } "Weights" -0.71728 0.87709
{'fire2-expand1x1_Bias' } {'fire2-expand1x1' } "Bias" -0.065638 0.14888
{'fire2-expand3x3_Weights' } {'fire2-expand3x3' } "Weights" -0.71899 0.6452
{'fire2-expand3x3_Bias' } {'fire2-expand3x3' } "Bias" -0.062058 0.08805
{'fire3-squeeze1x1_Weights'} {'fire3-squeeze1x1'} "Weights" -0.72677 0.67948
{'fire3-squeeze1x1_Bias' } {'fire3-squeeze1x1'} "Bias" -0.11343 0.33745
{'fire3-expand1x1_Weights' } {'fire3-expand1x1' } "Weights" -0.68734 0.93931
{'fire3-expand1x1_Bias' } {'fire3-expand1x1' } "Bias" -0.075568 0.31345
{'fire3-expand3x3_Weights' } {'fire3-expand3x3' } "Weights" -0.5874 0.72577
{'fire3-expand3x3_Bias' } {'fire3-expand3x3' } "Bias" -0.066463 0.12058
{'fire4-squeeze1x1_Weights'} {'fire4-squeeze1x1'} "Weights" -0.70607 1.0569
{'fire4-squeeze1x1_Bias' } {'fire4-squeeze1x1'} "Bias" -0.11843 0.14643
⋮
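The calibrated ranges determine how floating-point values are mapped to 8-bit scaled integers. The sketch below illustrates the idea for the new_firstconv weights range from the table, assuming a symmetric power-of-two scale chosen so that the largest magnitude fits in the int8 range. This scaling rule and the sample values in w are assumptions for illustration; the scheme that dlquantizer and the generated code use may differ.

```matlab
% Calibrated range of the new_firstconv weights from the table above.
minVal = -0.53081;
maxVal = 0.50032;

% Assumption: pick the smallest power-of-two scale such that the largest
% magnitude maps inside the int8 range [-128, 127].
absMax = max(abs([minVal maxVal]));
exponent = ceil(log2(absMax / 127));
scale = 2^exponent;

% Quantize a few hypothetical weights and map them back to real values.
w = [-0.53081 0.1 0.50032];
q = int8(round(w / scale));      % stored 8-bit integers
wHat = double(q) * scale;        % values the quantized network works with

quantizationError = abs(wHat - w);
```

With round-to-nearest, the error for in-range values is at most half the scale, which is why tight calibration ranges help preserve accuracy.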
Save the dlquantizer
object containing the network to quantize.
save('squeezenetCalResults.mat','calResults'); save('squeezenetQuantObj.mat','quantObj');
You can use the Deep Network Quantizer app to further visualize the dynamic ranges of the calibrated layers.
Use the validate function to compare the results of the network before and after quantization using the validation data set. Examine the MetricResults.Result
field of the validation output to see the accuracy of the quantized network.
validationMetricsC = validate(quantObj,augimdsValidation,quantOpts);
### Starting application: 'codegen/lib/validate_predict_int8/pil/validate_predict_int8.elf' To terminate execution: clear validate_predict_int8_pil ### Launching application validate_predict_int8.elf... ### Host application produced the following standard output (stdout) and standard error (stderr) messages:
validationMetricsC.MetricResults.Result
ans=2×2 table
NetworkImplementation MetricOutput
_____________________ ____________
{'Floating-Point'} 0.765
{'Quantized' } 0.7641
Generate and Deploy INT8 C++ Code to Raspberry Pi
The predictResponses.m
entry-point function takes an image input and runs prediction on the image using the specified network. The function uses a persistent object mynet
to load the network object and reuses the persistent object for prediction on subsequent calls.
type predictResponses.m
function out = predictResponses(net,in)
% Load the network once and reuse it on subsequent calls.
persistent mynet;
if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork(net);
end
out = predict(mynet, in);
end
To generate a PIL MEX function, create a code configuration object for a static library and set the verification mode to 'PIL'
. Set the target language to C++. Create a coder.Hardware
object for Raspberry Pi and attach it to the code generation configuration object.
cfg = coder.config('lib', 'ecoder', true); cfg.VerificationMode = 'PIL'; cfg.TargetLang = 'C++'; cfg.Hardware = coder.hardware('Raspberry Pi');
Create a deep learning configuration object for the ARM Compute Library. Specify the library version and the ARM architecture. For this example, suppose that the ARM Compute Library on the Raspberry Pi hardware is version 20.02.1.
dlcfg = coder.DeepLearningConfig('arm-compute'); dlcfg.ArmComputeVersion = '20.02.1'; dlcfg.ArmArchitecture = 'armv7';
Set the properties of dlcfg
to generate code for INT8 inference.
dlcfg.CalibrationResultFile = 'squeezenetQuantObj.mat'; dlcfg.DataType = 'int8'; cfg.DeepLearningConfig = dlcfg; inputs = {coder.Constant('prunedDAGNet.mat'),ones(32,32,3,'uint8')};
Generate a PIL MEX function by using the codegen
command.
codegen -config cfg predictResponses -args inputs
Deploying code. This may take a few minutes. ### Connectivity configuration for function 'predictResponses': 'Raspberry Pi' Location of the generated elf : /home/pi/MATLAB_ws/R2023a/home/lnarasim/Documents/MATLAB/ExampleManager/lnarasim.Bdoc23a.j2174901/deeplearning_shared-ex40890309/codegen/lib/predictResponses/pil Code generation successful.
Compare Classification Accuracy of the Transfer Learned, Pruned, and Quantized Networks
Evaluate the impact of quantization on the classification accuracy of the pruned network.
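testImages = read(augimdsValidation);
testImage = table2array(testImages(4,1));

% predictResponses caches the network in a persistent variable, so clear the
% function before each call that should load a different MAT file.
clear predictResponses
predictScores(:,1) = predictResponses('transferNet.mat', testImage{1});
clear predictResponses
predictScores(:,2) = predictResponses('prunedDAGNet.mat', testImage{1});
predictScores(:,3) = predictResponses_pil('prunedDAGNet.mat',testImage{1});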
### Starting application: 'codegen/lib/predictResponses/pil/predictResponses.elf' To terminate execution: clear predictResponses_pil ### Launching application predictResponses.elf...
barh(predictScores)
xlabel('Probability')
yticklabels(classes)
% Set the axes properties on the current axes.
ax = gca;
ax.XLim = [0 1.1];
ax.YAxisLocation = 'left';
legend('Trained Network (Single)','Pruned Network (Single)','ARM-Compute (8-bit integer)');
sgtitle('Network Predictions')
Helper Functions
Download CIFAR-10 Dataset
The downloadCIFARData
function downloads the CIFAR-10 dataset from the external website. The download is approximately 175MB in size.
function downloadCIFARData(destination)

url = 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz';

% Download and unpack only if the data is not already present.
unpackedData = fullfile(destination,'cifar-10-batches-mat');
if ~exist(unpackedData,'dir')
    fprintf('Downloading CIFAR-10 dataset (175 MB). This can take a while...');
    untar(url,destination);
    fprintf('done.\n\n');
end

end
Process CIFAR-10 Dataset
Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images. Use the CIFAR-10 test images for network validation.
function [XTrain,YTrain,XTest,YTest] = loadCIFARData(location)

location = fullfile(location,'cifar-10-batches-mat');

[XTrain1,YTrain1] = loadBatchAsFourDimensionalArray(location,'data_batch_1.mat');
[XTrain2,YTrain2] = loadBatchAsFourDimensionalArray(location,'data_batch_2.mat');
[XTrain3,YTrain3] = loadBatchAsFourDimensionalArray(location,'data_batch_3.mat');
[XTrain4,YTrain4] = loadBatchAsFourDimensionalArray(location,'data_batch_4.mat');
[XTrain5,YTrain5] = loadBatchAsFourDimensionalArray(location,'data_batch_5.mat');
XTrain = cat(4,XTrain1,XTrain2,XTrain3,XTrain4,XTrain5);
YTrain = [YTrain1;YTrain2;YTrain3;YTrain4;YTrain5];

[XTest,YTest] = loadBatchAsFourDimensionalArray(location,'test_batch.mat');
end

function [XBatch,YBatch] = loadBatchAsFourDimensionalArray(location,batchFileName)
s = load(fullfile(location,batchFileName));
% Each row of s.data is one image; convert to height-by-width-by-channel-by-batch.
XBatch = s.data';
XBatch = reshape(XBatch,32,32,3,[]);
XBatch = permute(XBatch,[2 1 3 4]);
YBatch = convertLabelsToCategorical(location,s.labels);
end

function categoricalLabels = convertLabelsToCategorical(location,integerLabels)
s = load(fullfile(location,'batches.meta.mat'));
categoricalLabels = categorical(integerLabels,0:9,s.label_names);
end
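The reshape followed by permute in loadBatchAsFourDimensionalArray is needed because the raw data is stored row by row, whereas MATLAB fills arrays column by column. A toy sketch with a hypothetical 2-by-3 single-channel image shows the effect:

```matlab
% Toy image with height 2 and width 3, flattened row by row.
rowMajor = uint8([1 2 3 4 5 6]);
H = 2;
W = 3;

% reshape fills columns first, so the first dimension comes out as the width.
widthFirst = reshape(rowMajor,W,H);

% Swap the first two dimensions to recover a height-by-width image.
img = permute(widthFirst,[2 1]);

disp(img)
```

Without the permute step, every image would be transposed.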
Mini-Batch Preprocessing Function
The preprocessMiniBatchTraining
function preprocesses a mini-batch of predictors and labels for loss computation during training.
function [X,T] = preprocessMiniBatchTraining(XCell,TCell)
% Concatenate.
X = cat(4,XCell{1:end});

% Extract label data from cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);
end
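For readers unfamiliar with one-hot encoding, the sketch below builds the same kind of classes-by-observations matrix by hand for a few hypothetical labels. It uses only base MATLAB functions and is meant to illustrate the layout, not to replace onehotencode.

```matlab
% Hypothetical labels drawn from three of the CIFAR-10 class names.
classNames = ["airplane","automobile","bird"];
labels = categorical(["bird","airplane","bird","automobile"],classNames);

numClasses = numel(classNames);
numObservations = numel(labels);

% double(categorical) returns the category index of each label.
classIdx = double(labels);

% Set a single 1 in each column, in the row of that observation's class.
T = zeros(numClasses,numObservations);
T(sub2ind(size(T),classIdx,1:numObservations)) = 1;

disp(T)
```

Each column sums to 1, which is what the cross-entropy loss in modelLossPruning expects from its targets.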
Model Gradients Function for Fine-Tuning and Pruning
The modelLossPruning function takes as input a deep.prune.TaylorPrunableNetwork object prunableNet and a mini-batch of input data X with corresponding labels T. The function returns the loss, the gradients of the loss with respect to the pruning activations, the pruning activations, the gradients of the loss with respect to the learnable parameters in prunableNet, and the network state. To compute the gradients automatically, use the dlgradient function.
function [loss,pruningGradient,pruningActivations,netGradients,state] = modelLossPruning(prunableNet, X, T)

% Forward pass; also returns the activations used for Taylor scoring.
[dlYPred,state,pruningActivations] = forward(prunableNet,X);
dlYPred = squeeze(dlYPred);

loss = crossentropy(dlYPred,T);

% Gradients with respect to the pruning activations and the learnables.
[pruningGradient,netGradients] = dlgradient(loss,pruningActivations,prunableNet.Learnables);
end
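As a sanity check on what the loss measures, the sketch below evaluates a cross-entropy loss by hand for hypothetical softmax outputs and one-hot targets, averaging over the observations in the mini-batch. The values of Y and T are made up, and the normalization used here (dividing by the number of observations) is an assumption about the default behavior.

```matlab
% Hypothetical softmax outputs (classes-by-observations) and one-hot targets.
Y = [0.7 0.2;
     0.2 0.5;
     0.1 0.3];
T = [1 0;
     0 1;
     0 0];

% Only the predicted probability of the true class contributes to the loss.
numObservations = size(Y,2);
loss = -sum(T .* log(Y),"all") / numObservations;

fprintf("Cross-entropy loss: %.4f\n",loss)
```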
Evaluate Model Accuracy
The modelAccuracy function takes as input the network (a dlnetwork object), a minibatchqueue object, the classes, and the number of observations, and returns the accuracy as a percentage.
function accuracy = modelAccuracy(net, mbq, classes, numObservations)
% This function computes the model accuracy of a net (dlnetwork) on the minibatchqueue 'mbq'.

totalCorrect = 0;
classes = int32(categorical(classes));

reset(mbq);
while hasdata(mbq)
    [dlX, Y] = next(mbq);

    dlYPred = extractdata(predict(net, dlX));
    dlYPred = squeeze(dlYPred);
    YPred = onehotdecode(dlYPred,classes,1)';

    YReal = onehotdecode(Y,classes,1)';

    miniBatchCorrect = nnz(YPred == YReal);

    totalCorrect = totalCorrect + miniBatchCorrect;
end

accuracy = totalCorrect / numObservations * 100;
end
Evaluate Number of Filters in Convolution Layers
The numConvLayerFilters
function returns the number of filters in each convolution layer.
function [nFilters, convNames] = numConvLayerFilters(net)
numLayers = numel(net.Layers);
convNames = [];
nFilters = [];

% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
        sizeW = size(net.Layers(cnt).Weights);
        nFilters = [nFilters; sizeW(end)]; %#ok<AGROW>
        convNames = [convNames; string(net.Layers(cnt).Name)]; %#ok<AGROW>
    end
end
end
Evaluate Network Statistics of Original Network and Pruned Network
The analyzeNetworkMetrics function takes as input the original network, the pruned network, the accuracy of the original network, and the accuracy of the pruned network. It returns a table of statistics: the number of network learnables, the approximate network memory, and the accuracy on the test data.
function [statistics] = analyzeNetworkMetrics(originalNet,prunedNet,accuracyOriginal,accuracyPruned)

originalNetMetrics = estimateNetworkMetrics(originalNet);
prunedNetMetrics = estimateNetworkMetrics(prunedNet);

% Accuracy of original network and pruned network
perChangeAccu = 100*(accuracyPruned - accuracyOriginal)/accuracyOriginal;
AccuracyForNetworks = [accuracyOriginal;accuracyPruned;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
prunedNetLearnables = sum(prunedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(prunedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;prunedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxPrunedMemory = sum(prunedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxPrunedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [approxOriginalMemory; approxPrunedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory,AccuracyForNetworks, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)","Accuracy"], ...
    'RowNames',{'Original Network','Pruned Network','Percentage Change'});

end

function [statistics] = analyzeQuantizedNetworkMetrics(originalNet,quantizedNet)

originalNetMetrics = estimateNetworkMetrics(originalNet);
quantizedNetMetrics = estimateNetworkMetrics(quantizedNet);

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
quantizedNetLearnables = sum(quantizedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(quantizedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;quantizedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxQuantizedMemory = sum(quantizedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxQuantizedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [approxOriginalMemory; approxQuantizedMemory; percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)"], ...
    'RowNames',{'Original Network','Pruned & Quantized Network','Percentage Change'});

end
References
[1] Krizhevsky, Alex. "Learning Multiple Layers of Features from Tiny Images." 2009. https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep Residual Learning for Image Recognition." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778. 2016.
[3] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. "Importance Estimation for Neural Network Pruning." In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01152.
See Also
Functions
forward (Deep Learning Toolbox) | predict (Deep Learning Toolbox) | updatePrunables (Deep Learning Toolbox) | updateScore (Deep Learning Toolbox) | TaylorPrunableNetwork (Deep Learning Toolbox) | dlnetwork (Deep Learning Toolbox)