- First, you need to train your YOLOv4 model on your custom dataset. You can use the MATLAB example code as a starting point and modify it according to your dataset.
- After training the model, you can use the evaluateDetectionPrecision function to obtain the precision metrics. This function takes the detection results and the ground truth data as inputs (in that order) and returns the average precision and, optionally, the recall and precision values.
- To obtain the precision plot, plot the precision values against the recall values returned by evaluateDetectionPrecision using the plot function. For a multiclass detector these outputs are cell arrays with one element per class, so plot each class separately.
I have a custom dataset with 4 classes and I'm using the example given in the MATLAB documentation for YOLO v4 object detection. I am not able to obtain the precision plots.
4 views (last 30 days)
Show older comments
% Load the labelled ground truth. The code below treats gTruth as a table
% with an imageFilename column plus one bounding-box column per class.
data = load("train.mat");
defectDataset = data.gTruth;

% Resolve the image paths BEFORE splitting. Tables are value types, so
% updating defectDataset after the subsets have been copied out of it would
% leave the training/validation/test tables with the unresolved paths.
defectDataset.imageFilename = fullfile(pwd,defectDataset.imageFilename);

% Columns of the ground truth table that hold the boxes of each class.
labelColumns = ["hole","slub","stain","knot"];

% Shuffle reproducibly, then split 60% training / ~10% validation / rest test.
rng("default");
shuffledIndices = randperm(height(defectDataset));
idx = floor(0.6 * length(shuffledIndices) );

trainingIdx = 1:idx;
trainingDataTbl = defectDataset(shuffledIndices(trainingIdx),:);

validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = defectDataset(shuffledIndices(validationIdx),:);

testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = defectDataset(shuffledIndices(testIdx),:);

% Build image and box-label datastores for each subset and pair them up so
% that every read returns {image, boxes, labels}.
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,labelColumns));
imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"});
bldsValidation = boxLabelDatastore(validationDataTbl(:,labelColumns));
imdsTest = imageDatastore(testDataTbl{:,"imageFilename"});
bldsTest = boxLabelDatastore(testDataTbl(:,labelColumns));

trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
% Preview one training sample with its ground-truth boxes and labels.
sample = read(trainingData);
[I, bbox, labels] = sample{1:3};
annotatedImage = imresize(insertObjectAnnotation(I,"Rectangle",bbox,labels),2);
figure
imshow(annotatedImage)

% Network input size and the classes the detector has to learn.
inputSize = [608 608 3];
className = ["hole", "slub", "stain", "knot"];

% Estimate anchor boxes on data resized to the network input size.
rng("default")
resizeToInput = @(data) preprocessData(data, inputSize);
trainingDataForEstimation = transform(trainingData, resizeToInput);
numAnchors = 9;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);

% Order the anchors from largest to smallest area and hand three of them
% to each of the three detection heads.
[~,order] = sort(prod(anchors(:,1:2),2),"descend");
anchors = anchors(order,:);
anchorBoxes = {anchors(1:3,:); anchors(4:6,:); anchors(7:9,:)};

detector = yolov4ObjectDetector("csp-darknet53-coco",className,anchorBoxes,InputSize=inputSize);

% Augment the training data and show a few augmented variants of the first
% sample; the datastore is rewound after every read so each pass starts
% from the same image.
augmentedTrainingData = transform(trainingData,@augmentData);
numPreviews = 4;
augmentedData = cell(numPreviews,1);
for k = 1:numPreviews
    sample = read(augmentedTrainingData);
    augmentedData{k} = insertShape(sample{1},"rectangle",sample{2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData,BorderSize=10)
% Training configuration (Adam optimizer, constant learning rate).
options = trainingOptions("adam",...
    GradientDecayFactor=0.9,...
    SquaredGradientDecayFactor=0.999,...
    InitialLearnRate=0.001,...
    LearnRateSchedule="none",...
    MiniBatchSize=8,...
    L2Regularization=0.0005,...
    MaxEpochs=20,...
    BatchNormalizationStatistics="moving",...
    DispatchInBackground=true,...
    ResetInputNormalization=false,...
    Shuffle="every-epoch",...
    VerboseFrequency=20,...
    ValidationFrequency=1000,...
    CheckpointPath=tempdir,...
    ValidationData=validationData);

% The detector must be trained on the custom data. The pretrained detector
% offered by the documentation example is a vehicle detector; it does not
% know the classes configured above, so evaluating it against this dataset
% yields meaningless precision values.
doTraining = true;
if doTraining
    % Train the YOLO v4 detector.
    [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options);
else
    % Load pretrained detector for the example.
    detector = downloadPretrainedYOLOv4Detector();
    if ~all(ismember(string(className),string(detector.ClassNames)))
        warning("The pretrained detector does not cover the classes of " + ...
            "this dataset; its precision results will not be meaningful.");
    end
end
% Detect objects in the test data using the trained or pretrained detector.
% For datastore input, detect returns a table whose Boxes, Scores and Labels
% variables hold one cell per image.
detectionResults = detect(detector,testData);

% Evaluate the detection performance. For a multiclass detector ap is a
% vector and recall/precision are cell arrays with one entry per class, so
% they cannot be handed to plot() directly.
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);
if iscell(recall)
    recallCurves = recall;
    precisionCurves = precision;
else
    % Single-class result: wrap the vectors so both cases share one code path.
    recallCurves = {recall};
    precisionCurves = {precision};
end
numClasses = numel(recallCurves);

% Legend labels. NOTE(review): assumes the per-class outputs follow the
% order of className (the order of the label columns) - confirm.
if exist("className","var") && numel(className) == numClasses
    classLabels = string(className);
else
    classLabels = "class " + (1:numClasses);
end

% Plot one precision-recall curve per class.
figure
hold on
for c = 1:numClasses
    plot(recallCurves{c},precisionCurves{c}, ...
        "DisplayName",sprintf("%s (AP = %.2f)",classLabels(c),ap(c)))
end
hold off
xlabel("Recall")
ylabel("Precision")
grid on
legend("Location","southwest")
if numClasses == 1
    title(sprintf("Average Precision = %.2f",ap))
else
    title(sprintf("Mean Average Precision = %.2f",mean(ap,"omitnan")))
end

% Show detection results on test images
resultsDir = fullfile(pwd, 'detection_results');
if ~isfolder(resultsDir)
    % Avoid the "directory already exists" warning on repeated runs.
    mkdir(resultsDir);
end
for i = 1:numel(testDataTbl.imageFilename)
    I = imread(testDataTbl.imageFilename{i});
    % detectionResults is a table: index the cell stored in row i.
    boxes = detectionResults.Boxes{i};
    scores = detectionResults.Scores{i};
    if ~isempty(boxes)
        % Only annotate when something was detected; otherwise the image
        % is saved unchanged.
        I = insertObjectAnnotation(I, 'rectangle', boxes, scores);
    end
    imwrite(I, fullfile(resultsDir, [num2str(i) '.jpg']));
end
function data = preprocessData(data,targetSize)
% preprocessData  Resize images to the target size and rescale their boxes.
%   Every row of the cell array DATA holds an image in column 1 and its
%   bounding boxes in column 2. The image is resized to TARGETSIZE(1:2) and
%   converted to single precision in [0,1]; the boxes are scaled by the same
%   factors. Any further columns are left untouched.
outputHW = targetSize(1:2);
numRows = size(data,1);
for row = 1:numRows
    img = data{row,1};
    originalHW = size(img);
    originalHW = originalHW(1:2);

    % Scale factors must come from the size before resizing.
    scale = outputHW./originalHW;

    data{row,1} = im2single(imresize(img,outputHW));
    data{row,2} = bboxresize(data{row,2},scale);
end
end
function detector = downloadPretrainedYOLOv4Detector()
% downloadPretrainedYOLOv4Detector  Fetch and load the pretrained detector.
%   Uses the MAT-file if it is already available; otherwise unzips the
%   archive, downloading it first when it is missing as well. Returns the
%   detector variable stored in the MAT-file.
matFile = "yolov4CSPDarknet53VehicleExample_22a.mat";
zipFile = "yolov4CSPDarknet53VehicleExample_22a.zip";

haveMatFile = exist(matFile, "file");
if ~haveMatFile
    haveZipFile = exist(zipFile, "file");
    if ~haveZipFile
        disp("Downloading pretrained detector...");
        pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/yolov4CSPDarknet53VehicleExample_22a.zip";
        websave(zipFile, pretrainedURL);
    end
    unzip(zipFile);
end

contents = load(matFile);
detector = contents.detector;
end
function data = augmentData(A)
% augmentData  Randomly augment images together with their boxes.
%   Each row of the cell array A is {image, boxes, labels}. RGB images get
%   random HSV color jitter, then every image is randomly reflected
%   horizontally and scaled by a factor in [1, 1.1]. Boxes are warped with
%   the same transform and kept only when their overlap with the output
%   view is above 0.25. If no box survives, the row is returned unchanged.
numRows = size(A,1);
data = cell(size(A));
for row = 1:numRows
    img = A{row,1};
    imgSize = size(img);

    % Color jitter only makes sense for three-channel images.
    isRGB = numel(imgSize) == 3 && imgSize(3) == 3;
    if isRGB
        img = jitterColorHSV(img,...
            contrast=0.0,...
            Hue=0.1,...
            Saturation=0.2,...
            Brightness=0.2);
    end

    % Random horizontal reflection plus mild up-scaling.
    tform = randomAffine2d(XReflection=true,Scale=[1 1.1]);
    outputView = affineOutputView(imgSize,tform,BoundsStyle="centerOutput");
    img = imwarp(img,tform,OutputView=outputView);

    % Warp the boxes with the same transform.
    [warpedBoxes,kept] = bboxwarp(A{row,2},tform,outputView,OverlapThreshold=0.25);

    if isempty(kept)
        % Every box was removed by the warp: fall back to the original row.
        data(row,:) = A(row,:);
    else
        rowLabels = A{row,3};
        data(row,:) = {img, warpedBoxes, rowLabels(kept)};
    end
end
end
0 Comments
Answers (1)
Shubham
on 3 Apr 2023
Hi Atishay,
To obtain the precision plot in MATLAB for your custom YOLOv4 object detection model, you can follow these steps:
Here's an example code snippet that shows how to obtain the precision plot:
% Load the ground truth data and detected data
load('groundTruthData.mat');
load('detectedData.mat');
% Evaluate the detection precision. Request all three outputs: the recall
% and precision values are what the curve is drawn from.
[ap, recall, precision] = evaluateDetectionPrecision(detectedData, groundTruthData);
% Plot the precision-recall curve(s) with the standard plot function.
% For a multiclass detector recall/precision are cell arrays with one
% element per class, so each class is plotted separately.
if ~iscell(recall)
    recall = {recall};
    precision = {precision};
end
figure
hold on
for c = 1:numel(recall)
    plot(recall{c}, precision{c}, ...
        "DisplayName", sprintf("Class %d (AP = %.2f)", c, ap(c)))
end
hold off
xlabel("Recall")
ylabel("Precision")
grid on
legend show
Note that you need to replace groundTruthData.mat and detectedData.mat with the file names of your ground truth and detected data files, respectively.
0 Comments
See Also
Categories
Find more on Recognition, Object Detection, and Semantic Segmentation in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!