Object Detection Using YOLO v3 Deep Learning
9 views (last 30 days)
Show older comments
% Train and evaluate a YOLO v3 vehicle detector on the vehicle ground-truth
% dataset (based on the MathWorks "Object Detection Using YOLO v3 Deep
% Learning" example).
%
% Requires: Deep Learning Toolbox, Computer Vision Toolbox, and the
% "Computer Vision Toolbox Model for YOLO v3 Object Detection" support
% package (provides yolov3ObjectDetector and its preprocess method).

% Set to false to skip training and use a downloaded pretrained detector.
doTraining = true;
if ~doTraining
    preTrainedDetector = downloadPretrainedYOLOv3Detector();
end

% --- Load the labeled dataset and make the image paths absolute. ---
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

% --- 60/40 train/test split (seeded for reproducibility). ---
rng(0);
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));
trainingDataTbl = vehicleDataset(shuffledIndices(1:idx), :);
testDataTbl = vehicleDataset(shuffledIndices(idx+1:end), :);

% Pair each image datastore with its box-label datastore.
imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
imdsTest = imageDatastore(testDataTbl.imageFilename);
bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));
trainingData = combine(imdsTrain, bldsTrain);
testData = combine(imdsTest, bldsTest);

% Check for invalid images and bounding boxes before training.
validateInputData(trainingData);
validateInputData(testData);

% --- Augmentation; preview four augmented versions of the first sample. ---
augmentedTrainingData = transform(trainingData, @augmentData);
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1,1}, 'Rectangle', data{1,2});
    reset(augmentedTrainingData);   % re-read the same sample; augmentation is random
end
figure
montage(augmentedData, 'BorderSize', 10)

% --- Estimate anchor boxes from the training data. ---
networkInputSize = [227 227 3];    % squeezenet input size
rng(0)
trainingDataForEstimation = transform(trainingData, @(data)preprocessData(data, networkInputSize));
numAnchors = 6;
[anchors, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)
% Sort anchors by area (descending) and split them across the two detection
% heads: larger anchors go to the coarser feature map, smaller to the finer.
area = anchors(:, 1).*anchors(:, 2);
[~, idx] = sort(area, 'descend');
anchors = anchors(idx, :);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    };

% --- Build the YOLO v3 detector on a squeezenet backbone. ---
baseNetwork = squeezenet;
classNames = trainingDataTbl.Properties.VariableNames(2:end);
% FIX: the original called yolov2ObjectDetector here, but everything that
% follows (preprocess, forward, Learnables, State, the two-element
% anchorBoxes cell) requires a yolov3ObjectDetector from the YOLO v3
% support package.
yolov3Detector = yolov3ObjectDetector(baseNetwork, classNames, anchorBoxes, ...
    'DetectionNetworkSource', {'fire9-concat', 'fire5-concat'});

% --- Preview one preprocessed training sample with its boxes. ---
preprocessedTrainingData = transform(augmentedTrainingData, @(data)preprocess(yolov3Detector, data));
data = read(preprocessedTrainingData);
reset(preprocessedTrainingData);
I = data{1,1};
bbox = data{1,2};
annotatedImage = insertShape(I, 'Rectangle', bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
reset(preprocessedTrainingData);

% --- Training hyperparameters. ---
numEpochs = 80;
miniBatchSize = 8;
learningRate = 0.001;
warmupPeriod = 1000;       % iterations of learning-rate ramp-up
l2Regularization = 0.0005;
penaltyThreshold = 0.5;    % IoU threshold used when generating targets
velocity = [];             % SGDM optimizer state, initialized empty
if canUseParallelPool
    dispatchInBackground = true;
else
    dispatchInBackground = false;
end

% Mini-batch queue: output 1 = images (SSCB), output 2 = responses (double).
mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
    "MiniBatchSize", miniBatchSize,...
    "MiniBatchFcn", @(images, boxes, labels) createBatchData(images, boxes, labels, classNames), ...
    "MiniBatchFormat", ["SSCB", ""],...
    "DispatchInBackground", dispatchInBackground,...
    "OutputCast", ["", "double"]);

if doTraining
    % Create subplots for the learning rate and mini-batch loss.
    fig = figure;
    [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(fig);
    iteration = 0;
    % Custom training loop.
    for epoch = 1:numEpochs
        reset(mbqTrain);
        shuffle(mbqTrain);
        while(hasdata(mbqTrain))
            iteration = iteration + 1;
            [XTrain, YTrain] = next(mbqTrain);
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, state, lossInfo] = dlfeval(@modelGradients, yolov3Detector, XTrain, YTrain, penaltyThreshold);
            % Apply L2 regularization.
            gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, yolov3Detector.Learnables);
            % Determine the current learning rate value.
            currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
            % Update the detector learnable parameters using the SGDM optimizer.
            [yolov3Detector.Learnables, velocity] = sgdmupdate(yolov3Detector.Learnables, gradients, velocity, currentLR);
            % Update the state parameters of dlnetwork.
            yolov3Detector.State = state;
            % Display progress.
            displayLossInfo(epoch, iteration, currentLR, lossInfo);
            % Update training plot with new points.
            updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
        end
    end
else
    yolov3Detector = preTrainedDetector;
end

% --- Evaluate on the test set. ---
results = detect(yolov3Detector, testData, 'MiniBatchSize', 8);
% Evaluate the object detector using the Average Precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(results, testData);
% Plot the precision-recall curve.
figure
plot(recall, precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))

% --- Run the detector on one test image and show the detections. ---
data = read(testData);
I = data{1};
[bboxes, scores, labels] = detect(yolov3Detector, I);
I = insertObjectAnnotation(I, 'rectangle', bboxes, scores);
figure
imshow(I)
function [gradients, state, info] = modelGradients(detector, XTrain, YTrain, penaltyThreshold)
% modelGradients  Compute loss, gradients, and updated state for one batch.
% Intended to be called through dlfeval so dlgradient can trace the forward
% pass. Inputs: the yolov3 detector, a dlarray image batch XTrain, the
% response batch YTrain, and the IoU penalty threshold for target generation.
% Outputs: gradients w.r.t. detector.Learnables, the dlnetwork state, and an
% info struct with the individual and total loss values.
inputImageSize = size(XTrain,1:2);
% Gather the ground truths in the CPU for post processing
YTrain = gather(extractdata(YTrain));
% Extract the predictions from the detector.
[gatheredPredictions, YPredCell, state] = forward(detector, XTrain);
% Generate target for predictions from the ground truth data.
% NOTE(review): the column indexing below ([2 3 7 8], 1, 6) must stay in
% sync with the column layout produced by forward/generateTargets.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions,...
    YTrain, inputImageSize, detector.AnchorBoxes, penaltyThreshold);
% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;
info.boxLoss = boxLoss;
info.objLoss = objLoss;
info.clsLoss = clsLoss;
info.totalLoss = totalLoss;
% Compute gradients of learnables with regard to loss.
gradients = dlgradient(totalLoss, detector.Learnables);
end
function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
% Mean squared error loss for the bounding box offsets. Sums the masked,
% error-scaled MSE over the four offset components (x, y, w, h); the box
% mask in column 1 gates which anchors contribute.
boxLoss = 0;
for component = 1:4
    componentLoss = cellfun(@(pred, target, mask, scale) ...
        mse(pred.*mask.*scale, target.*mask.*scale), ...
        boxPredCell(:,component), boxDeltaTarget(:,component), ...
        boxMaskTarget(:,1), boxErrorScaleTarget);
    boxLoss = boxLoss + sum(componentLoss);
end
end
function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
% Binary cross-entropy loss for the objectness score. Predictions and
% targets are gated by mask column 2 so only selected anchors contribute.
perCellLoss = cellfun(@(pred, target, mask) ...
    crossentropy(pred.*mask, target.*mask, 'TargetCategories', 'independent'), ...
    objectnessPredCell, objectnessDeltaTarget, boxMaskTarget(:,2));
objLoss = sum(perCellLoss);
end
function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
% Binary cross-entropy loss for the class confidence scores. Predictions
% and targets are gated by mask column 3 so only selected anchors contribute.
perCellLoss = cellfun(@(pred, target, mask) ...
    crossentropy(pred.*mask, target.*mask, 'TargetCategories', 'independent'), ...
    classPredCell, classTarget, boxMaskTarget(:,3));
clsLoss = sum(perCellLoss);
end
function data = augmentData(A)
% Randomly jitter color (RGB images only), flip horizontally, and scale
% each sample. Boxes warped outside the image bounds are kept only when
% their overlap with the output view is at least 0.25; if warping removes
% every box, the sample falls back to its original, unaugmented form.
data = cell(size(A));
for row = 1:size(A,1)
    img = A{row,1};
    boxes = A{row,2};
    lbls = A{row,3};
    imgSize = size(img);
    % Color jitter only makes sense for 3-channel (RGB) images.
    if numel(imgSize) == 3 && imgSize(3) == 3
        img = jitterColorHSV(img,...
            'Contrast',0.0,...
            'Hue',0.1,...
            'Saturation',0.2,...
            'Brightness',0.2);
    end
    % Random horizontal reflection plus mild (up to 10%) scaling.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    outView = affineOutputView(imgSize,tform,'BoundsStyle','centerOutput');
    img = imwarp(img,tform,'OutputView',outView);
    % Apply the same geometric transform to the boxes; keep surviving labels.
    [boxes,kept] = bboxwarp(boxes,tform,outView,'OverlapThreshold',0.25);
    lbls = lbls(kept);
    if isempty(kept)
        % All boxes were removed by warping - keep the original sample.
        data(row,:) = A(row,:);
    else
        data(row,:) = {img, boxes, lbls};
    end
end
end
function data = preprocessData(data, targetSize)
% Resize each image to targetSize, convert pixel values to single in
% [0, 1], and rescale the corresponding bounding boxes to match.
for row = 1:size(data,1)
    img = data{row,1};
    origSize = size(img);
    % Promote single-channel (grayscale) images to 3 channels so the
    % network always receives RGB-shaped input.
    if numel(origSize) < 3
        img = repmat(img,1,1,3);
    end
    boxes = data{row,2};
    img = im2single(imresize(img, targetSize(1:2)));
    % Boxes are scaled by the same per-axis factor as the image resize.
    resizeScale = targetSize(1:2)./origSize(1:2);
    boxes = bboxresize(boxes, resizeScale);
    data(row, 1:2) = {img, boxes};
end
end
function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames)
% createBatchData  Assemble one mini-batch for the custom training loop.
% Returns images combined along the batch dimension in XTrain and
% normalized bounding boxes concatenated with classIDs in YTrain
% Concatenate images along the batch dimension.
XTrain = cat(4, data{:,1});
% Get class IDs from the class names.
% Each sample's labels are matched against the categorical class list;
% ismember's second output gives the numeric class index per box.
classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
[~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);
% Append the label indexes and training image size to scaled bounding boxes
% and create a single cell array of responses.
combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
% Zero-pad every response matrix to the largest box count in the batch so
% the per-sample responses can be stacked along the 4th (batch) dimension.
len = max( cellfun(@(x)size(x,1), combinedResponses ) );
paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
YTrain = cat(4, paddedBBoxes{:,1});
end
function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs)
% The piecewiseLearningRateWithWarmup function computes the current
% learning rate based on the iteration number.
% Schedule: quartic ramp-up during the warmup period, then the base rate,
% then 0.1x after 60% of the post-warmup epochs and 0.01x after 90%.
% NOTE(review): warmUpEpoch is persistent - it remembers the epoch at which
% warmup ended across calls, so this function assumes it is called in
% iteration order within a single training run.
persistent warmUpEpoch;
if iteration <= warmupPeriod
    % Increase the learning rate for number of iterations in warmup period.
    currentLR = learningRate * ((iteration/warmupPeriod)^4);
    warmUpEpoch = epoch;
elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch))
    % After warm up period, keep the learning rate constant if the remaining number of epochs is less than 60 percent.
    currentLR = learningRate;
elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch))
    % If the remaining number of epochs is more than 60 percent but less
    % than 90 percent multiply the learning rate by 0.1.
    currentLR = learningRate*0.1;
else
    % If remaining epochs are more than 90 percent multiply the learning
    % rate by 0.01.
    currentLR = learningRate*0.01;
end
end
function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(trainingFigure)
% Build a two-row progress figure: learning rate on top, total loss below.
% Returns the animated lines that the training loop appends points to.
figure(trainingFigure);
clf
% Top subplot tracks the learning-rate schedule.
subplot(2,1,1);
ylabel('Learning Rate');
xlabel('Iteration');
learningRatePlotter = animatedline;
% Bottom subplot tracks the total training loss.
subplot(2,1,2);
ylabel('Total Loss');
xlabel('Iteration');
lossPlotter = animatedline;
end
function displayLossInfo(epoch, iteration, currentLR, lossInfo)
% Print per-iteration training progress to the command window.
% toHost converts a (possibly GPU-resident) dlarray loss to a plain double.
toHost = @(x) double(gather(extractdata(x)));
disp("Epoch : " + epoch + " | Iteration : " + iteration + ...
    " | Learning Rate : " + currentLR + ...
    " | Total Loss : " + toHost(lossInfo.totalLoss) + ...
    " | Box Loss : " + toHost(lossInfo.boxLoss) + ...
    " | Object Loss : " + toHost(lossInfo.objLoss) + ...
    " | Class Loss : " + toHost(lossInfo.clsLoss));
end
function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss)
% Append the latest total-loss and learning-rate values to the progress plots.
lossValue = double(extractdata(gather(totalLoss)));
addpoints(lossPlotter, iteration, lossValue);
addpoints(learningRatePlotter, iteration, currentLR);
drawnow
end
function detector = downloadPretrainedYOLOv3Detector()
% Fetch (if needed), unzip, and load the pretrained YOLO v3 vehicle detector.
matFile = 'yolov3SqueezeNetVehicleExample_21aSPKG.mat';
zipFile = 'yolov3SqueezeNetVehicleExample_21aSPKG.zip';
if ~exist(matFile, 'file')
    % Download the archive only when neither the MAT nor ZIP file is present.
    if ~exist(zipFile, 'file')
        disp('Downloading pretrained detector (8.9 MB)...');
        pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/yolov3SqueezeNetVehicleExample_21aSPKG.zip';
        websave(zipFile, pretrainedURL);
    end
    unzip(zipFile);
end
pretrained = load(matFile);
detector = pretrained.detector;
end
function this = yolov2ObjectDetector(varargin)
% NOTE(review): this looks like a fragment copied from the toolbox's
% yolov2ObjectDetector internals, pasted into this file as a local function.
% A local function with this name SHADOWS the real yolov2ObjectDetector
% class, so any call to it from the script resolves here - likely the
% source of the reported error. As written it only handles copy
% construction from an existing yolov2ObjectDetector (the general
% configuration path is commented out below). Consider deleting it and
% using yolov3ObjectDetector from the YOLO v3 support package instead.
narginchk(1,5);
if (isa(varargin{1,1},'yolov2ObjectDetector'))
    % Copy-construction path: duplicate the properties of an existing detector.
    clsname = 'yolov2ObjectDetector';
    validateattributes(varargin{1,1},{clsname}, ...
        {'scalar'}, mfilename);
    this.ModelName = varargin{1,1}.ModelName;
    this.Network = varargin{1,1}.Network;
    this.TrainingImageSize = varargin{1,1}.TrainingImageSize;
    this.FractionDownsampling = varargin{1,1}.FractionDownsampling;
    this.WH2HW = varargin{1,1}.WH2HW;
    %else
    % Configure detector.
    % this.ModelName = 'importedNetwork';
    % params = this.parseDetectorInputs(varargin{:});
    % this.TrainingImageSize = params.TrainingImageSize;
    % this.Network = params.Network;
end
% NOTE(review): depends on an undocumented internal package
% (vision.internal.cnn.utils) - TODO confirm it exists in the target release.
this.FilterBboxesFunctor = vision.internal.cnn.utils.FilterBboxesFunctor;
end
Answers (1)
T.Nikhil kumar
on 9 Jul 2022
Hey Dilara!
As per my understanding, you are facing an error with the preprocess function.
I assume you want to create a yolov3ObjectDetector and not a yolov2ObjectDetector.
The yolov3ObjectDetector object that is used in this example creates an object detector for detecting objects in the image. It has an object function called preprocess, that is used to preprocess the training and testing images according to the network requirement.
I suspect that you have not installed the Computer Vision Toolbox Model for YOLOv3 Object Detection.
(Link: https://www.mathworks.com/matlabcentral/fileexchange/87959-computer-vision-toolbox-model-for-yolo-v3-object-detection)
To use this object and its functions in our code, we must install the Computer Vision Toolbox Model for YOLOv3 Object Detection. For installing it, follow these steps:
-In MATLAB, go to the home tab and click the Add-Ons button and select Get Add-Ons .
-A dialog box called Add-On Explorer opens up. In the search tab, search for “Computer Vision Toolbox Model for YOLOv3 Object Detection” and install it.
This should resolve your error.
0 Comments
See Also
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!