trainSSDObjectDetector
Syntax
Description
Train a Detector
trainedDetector = trainSSDObjectDetector(trainingData,detector,options) trains a single shot multibox detector (SSD) using deep learning. You can train an SSD detector to detect multiple object classes. Use this syntax to train either an untrained or pretrained SSD object detection network. You can also use this syntax to fine-tune a network with additional training data or to perform more training iterations to improve detector accuracy.
This function requires Deep Learning Toolbox™. We recommend that you also have Parallel Computing Toolbox™ so that you can train on a CUDA®-enabled NVIDIA® GPU. For information about the supported compute capabilities, see GPU Computing Requirements (Parallel Computing Toolbox).
Resume Training a Detector
trainedDetector = trainSSDObjectDetector(trainingData,checkpoint,options) resumes training from a detector checkpoint.
Additional Properties
trainedDetector = trainSSDObjectDetector(___,Name=Value) uses additional options specified by one or more name-value arguments and any of the previous inputs.
[trainedDetector,info] = trainSSDObjectDetector(___) also returns information on the training progress, such as training loss and accuracy, for each iteration.
Examples
Train SSD Object Detector
This example shows how to train an SSD object detector on a vehicle data set. The example then uses the trained detector for detecting vehicles in an image.
Load the training data into the workspace.
data = load("vehicleTrainingData.mat");
trainingData = data.vehicleTrainingData;
Specify the directory in which the training samples are stored. Add the full path to the file names in the training data.
dataDir = fullfile(toolboxdir("vision"),"visiondata");
trainingData.imageFilename = fullfile(dataDir,trainingData.imageFilename);
Create an image datastore using the files from the table.
imds = imageDatastore(trainingData.imageFilename);
Create a box label datastore using the label columns from the table.
blds = boxLabelDatastore(trainingData(:,2:end));
Combine the datastores.
ds = combine(imds,blds);
Specify a base network.
baseNetwork = imagePretrainedNetwork("resnet50");
Specify the names of the classes to detect.
classNames = "vehicle";
Specify the anchor boxes to use for training the network.
anchorBoxes = { ...
    [30 60; 60 30; 50 50; 100 100], ...
    [40 70; 70 40; 60 60; 120 120]};
Specify the names of the feature extraction layers to connect to the detection subnetwork.
layersToConnect = ["activation_22_relu" "activation_40_relu"];
Create an SSD object detector by using the ssdObjectDetector function.
detector = ssdObjectDetector(baseNetwork,classNames,anchorBoxes, ...
DetectionNetworkSource=layersToConnect);
Specify the training options.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MiniBatchSize=16, ...
    Verbose=true, ...
    MaxEpochs=30, ...
    Shuffle="never", ...
    VerboseFrequency=10);
Train the SSD object detector.
[detector,info] = trainSSDObjectDetector(ds,detector,options);
*************************************************************************
Training an SSD Object Detector for the following object classes:

* vehicle

Training on single GPU.
Initializing input data normalization.
|=========================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Mini-batch | Base Learning |
|       |           |  (hh:mm:ss)  |    Loss    |  Accuracy  |    RMSE    |     Rate      |
|=========================================================================================|
|     1 |         1 |     00:00:00 |    54.8573 |     43.62% |       2.56 |        0.0010 |
|     1 |        10 |     00:00:04 |     3.5709 |     98.90% |       1.65 |        0.0010 |
|     2 |        20 |     00:00:07 |     3.2212 |     99.93% |       0.99 |        0.0010 |
|     2 |        30 |     00:00:11 |     5.3296 |     98.89% |       1.51 |        0.0010 |
|     3 |        40 |     00:00:15 |     4.5304 |     99.76% |       1.15 |        0.0010 |
|     3 |        50 |     00:00:18 |     6.2441 |     98.47% |       1.28 |        0.0010 |
|     4 |        60 |     00:00:22 |     3.6098 |     98.95% |       1.26 |        0.0010 |
|     4 |        70 |     00:00:25 |     5.0838 |     98.81% |       1.23 |        0.0010 |
|     5 |        80 |     00:00:29 |     4.9031 |     98.84% |       1.47 |        0.0010 |
|     5 |        90 |     00:00:32 |     4.5612 |     99.91% |       0.99 |        0.0010 |
|     6 |       100 |     00:00:36 |     3.0531 |     98.90% |       1.13 |        0.0010 |
|     7 |       110 |     00:00:40 |     3.4591 |     99.95% |       0.74 |        0.0010 |
|     7 |       120 |     00:00:43 |     3.1390 |     99.10% |       1.02 |        0.0010 |
|     8 |       130 |     00:00:47 |     3.7369 |     99.77% |       1.14 |        0.0010 |
|     8 |       140 |     00:00:50 |     3.3740 |     99.07% |       0.99 |        0.0010 |
|     9 |       150 |     00:00:54 |     3.9861 |     98.99% |       1.35 |        0.0010 |
|     9 |       160 |     00:00:57 |     3.1168 |     99.24% |       0.90 |        0.0010 |
|    10 |       170 |     00:01:01 |     4.1509 |     98.92% |       1.40 |        0.0010 |
|    10 |       180 |     00:01:04 |     1.8926 |     99.97% |       0.55 |        0.0010 |
|    11 |       190 |     00:01:08 |     1.9993 |     99.13% |       0.88 |        0.0010 |
|    12 |       200 |     00:01:11 |     1.0826 |     99.98% |       0.41 |        0.0010 |
|    12 |       210 |     00:01:15 |     2.2213 |     99.37% |       0.78 |        0.0010 |
|    13 |       220 |     00:01:18 |     2.4177 |     99.84% |       0.99 |        0.0010 |
|    13 |       230 |     00:01:22 |     2.7282 |     99.32% |       0.81 |        0.0010 |
|    14 |       240 |     00:01:26 |     3.7748 |     99.20% |       1.20 |        0.0010 |
|    14 |       250 |     00:01:29 |     2.5727 |     99.33% |       0.77 |        0.0010 |
|    15 |       260 |     00:01:33 |     3.8397 |     99.15% |       1.23 |        0.0010 |
|    15 |       270 |     00:01:36 |     0.7277 |     99.99% |       0.43 |        0.0010 |
|    16 |       280 |     00:01:40 |     2.9342 |     99.22% |       0.78 |        0.0010 |
|    17 |       290 |     00:01:43 |     0.8218 |     99.99% |       0.39 |        0.0010 |
|    17 |       300 |     00:01:47 |     1.9056 |     99.53% |       0.70 |        0.0010 |
|    18 |       310 |     00:01:50 |     1.3405 |     99.93% |       0.71 |        0.0010 |
|    18 |       320 |     00:01:54 |     1.5981 |     99.42% |       0.79 |        0.0010 |
|    19 |       330 |     00:01:57 |     2.9586 |     99.30% |       1.02 |        0.0010 |
|    19 |       340 |     00:02:01 |     1.2590 |     99.43% |       0.66 |        0.0010 |
|    20 |       350 |     00:02:05 |     2.8346 |     99.27% |       1.04 |        0.0010 |
|    20 |       360 |     00:02:08 |     0.9523 |     99.98% |       0.49 |        0.0010 |
|    21 |       370 |     00:02:12 |     1.5545 |     99.26% |       0.74 |        0.0010 |
|    22 |       380 |     00:02:16 |     0.7113 |     99.99% |       0.50 |        0.0010 |
|    22 |       390 |     00:02:19 |     2.0585 |     99.58% |       0.61 |        0.0010 |
|    23 |       400 |     00:02:23 |     1.0087 |     99.93% |       0.55 |        0.0010 |
|    23 |       410 |     00:02:26 |     2.1091 |     99.48% |       0.71 |        0.0010 |
|    24 |       420 |     00:02:30 |     2.9050 |     99.34% |       0.95 |        0.0010 |
|    24 |       430 |     00:02:33 |     2.1874 |     99.44% |       0.64 |        0.0010 |
|    25 |       440 |     00:02:37 |     2.9181 |     99.30% |       0.91 |        0.0010 |
|    25 |       450 |     00:02:40 |     1.9445 |     99.94% |       0.39 |        0.0010 |
|    26 |       460 |     00:02:44 |     3.2013 |     99.20% |       0.71 |        0.0010 |
|    27 |       470 |     00:02:48 |     0.3986 |     99.98% |       0.37 |        0.0010 |
|    27 |       480 |     00:02:51 |     1.3009 |     99.57% |       0.58 |        0.0010 |
|    28 |       490 |     00:02:55 |     1.0920 |     99.96% |       0.60 |        0.0010 |
|    28 |       500 |     00:02:58 |     1.7258 |     99.58% |       0.65 |        0.0010 |
|    29 |       510 |     00:03:02 |     2.7426 |     99.42% |       0.86 |        0.0010 |
|    29 |       520 |     00:03:05 |     1.4956 |     99.63% |       0.59 |        0.0010 |
|    30 |       530 |     00:03:09 |     2.0561 |     99.39% |       0.85 |        0.0010 |
|    30 |       540 |     00:03:12 |     1.0817 |     99.98% |       0.78 |        0.0010 |
|=========================================================================================|
Training finished: Max epochs completed.

Detector training complete.
*************************************************************************
Check the training progress by plotting the training loss at each iteration.
figure
plot(info.TrainingLoss)
grid on
xlabel("Number of Iterations")
ylabel("Training Loss for Each Iteration")
Read a test image.
img = imread("detectcars.png");
Detect vehicles in the test image by using the trained SSD object detector.
[bboxes,scores] = detect(detector,img);
Display the detection results.
if ~isempty(bboxes)
    img = insertObjectAnnotation(img,"rectangle",bboxes,scores);
end
figure
imshow(img)
Input Arguments
trainingData
— Labeled ground truth images
datastore
Labeled ground truth images, specified as a datastore or a table.
If you use a datastore, your data must be set up so that calling the datastore with the read and readall functions returns a cell array or table with three columns:
- data: The first column must be images.
- boxes: The second column must be a cell array that contains M-by-4 matrices of bounding boxes of the form [x, y, width, height], where [x,y] represent the top-left coordinates of the bounding box.
- labels: The third column must be a cell array that contains M-by-1 categorical vectors containing object class names. All categorical data returned by the datastore must contain the same categories.
For more information, see Datastores for Deep Learning (Deep Learning Toolbox).
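Before training, it can help to confirm that one read from your datastore has this three-column layout. The sketch below uses a hypothetical anonymous function, isValidSample (not a toolbox function), and a synthetic all-black image with two made-up boxes, so it runs without any data files. It only checks shapes and types, not whether the boxes lie inside the image.

```matlab
% Hypothetical check (not part of any toolbox): does one read have the
% {image, boxes, labels} layout described above?
isValidSample = @(s) iscell(s) && numel(s) == 3 ...
    && isnumeric(s{1}) ...                           % image data
    && isnumeric(s{2}) && size(s{2},2) == 4 ...      % M-by-4 [x y width height]
    && iscategorical(s{3}) && iscolumn(s{3}) ...     % M-by-1 class labels
    && size(s{2},1) == numel(s{3});                  % one label per box

% Synthetic sample: blank RGB image with two labeled boxes (made-up values).
img    = zeros(224,224,3,'uint8');
boxes  = [10 20 50 40; 100 80 60 60];
labels = categorical({'vehicle';'vehicle'});
sample = {img, boxes, labels};

ok = isValidSample(sample);
```

With the combined datastore ds from the example, you could call sample = read(ds), pass it to isValidSample, and then call reset(ds) so that training starts from the first observation.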
detector
— Untrained or pretrained SSD object detector
ssdObjectDetector
object
Untrained or pretrained SSD object detector, specified as an ssdObjectDetector object.
options
— Training options
TrainingOptionsSGDM
object | TrainingOptionsRMSProp
object | TrainingOptionsADAM
object
Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. To specify the solver name and other options for network training, use the trainingOptions (Deep Learning Toolbox) function.
Note
The trainSSDObjectDetector function does not support these training option settings:
- Datastore inputs are not supported when you set the DispatchInBackground training option to true.
- Datastore inputs are not supported if the Shuffle training option is set to "never" when the ExecutionEnvironment training option is "multi-gpu". For more information about using datastores for parallel training, see Preprocess Data in the Background or in Parallel (Deep Learning Toolbox).
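As a sketch of options that avoid both unsupported settings when you train from a datastore on multiple GPUs, the code below sets Shuffle to "every-epoch" instead of "never" and keeps DispatchInBackground at false. The mini-batch size and epoch count are illustrative values, not recommendations. This assumes Deep Learning Toolbox is installed; the multiple GPUs themselves are needed when you train, and this sketch only builds the options object.

```matlab
% Sketch: datastore-compatible options for multi-GPU training.
% Assumes Deep Learning Toolbox; values other than the two settings
% discussed above are arbitrary.
options = trainingOptions("sgdm", ...
    ExecutionEnvironment="multi-gpu", ...
    Shuffle="every-epoch", ...      % any value other than "never"
    DispatchInBackground=false, ... % true is not supported with datastores
    MiniBatchSize=16, ...
    MaxEpochs=30);
```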
checkpoint
— Saved detector checkpoint
ssdObjectDetector
object
Saved detector checkpoint, specified as an ssdObjectDetector object. To periodically save a detector checkpoint during training, specify the CheckpointPath training option. To control how frequently checkpoints are saved, see the CheckpointFrequency and CheckpointFrequencyUnit training options.
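A minimal sketch of training options that save a checkpoint every five epochs into a folder under tempdir. The folder name ssdCheckpoints and the frequency are arbitrary choices for illustration. trainingOptions expects the checkpoint folder to exist already, so the code creates it first. This assumes Deep Learning Toolbox is installed.

```matlab
% Sketch: save a detector checkpoint every 5 epochs (illustrative values).
checkpointDir = fullfile(tempdir,"ssdCheckpoints");   % hypothetical folder name
if ~isfolder(checkpointDir)
    mkdir(checkpointDir);        % the checkpoint folder must exist beforehand
end

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MaxEpochs=30, ...
    CheckpointPath=checkpointDir, ...
    CheckpointFrequency=5, ...
    CheckpointFrequencyUnit="epoch");
```

Passing these options to trainSSDObjectDetector writes checkpoint MAT files into checkpointDir during training.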
To load a checkpoint for a previously trained detector, load the MAT file from the checkpoint path. For example, if the CheckpointPath property of the object specified by options is "/checkpath", you can load a checkpoint MAT file by using this code.
data = load("/checkpath/ssd_checkpoint__216__2018_11_16__13_34_30.mat");
checkpoint = data.detector;
The name of the MAT file includes the iteration number and timestamp of when the detector checkpoint was saved. The detector is saved in the detector variable of the file. Pass the loaded checkpoint back into the trainSSDObjectDetector function to resume training:
ssdDetector = trainSSDObjectDetector(trainingData,checkpoint,options);
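Because the file name encodes the iteration number and the save time, you can recover both with a regular expression, for example to pick the most recent checkpoint. This sketch assumes the name layout matches the example file above (ssd_checkpoint__<iteration>__<yyyy_MM_dd__HH_mm_ss>.mat); that layout is inferred from a single example file name, not from a documented specification.

```matlab
% Assumption: checkpoint names follow the layout of the example above.
fileName = 'ssd_checkpoint__216__2018_11_16__13_34_30.mat';
pattern  = 'ssd_checkpoint__(\d+)__(\d{4}_\d{2}_\d{2}__\d{2}_\d{2}_\d{2})\.mat';

% With char input and 'once', tok is a 1-by-2 cell array of char vectors.
tok = regexp(fileName,pattern,'tokens','once');

iteration = str2double(tok{1});                                % iteration number
savedAt   = datetime(tok{2},'InputFormat','yyyy_MM_dd__HH_mm_ss'); % save time
```

To choose among several files, apply the same pattern to the names returned by dir(fullfile(checkpointPath,"ssd_checkpoint__*.mat")), where checkpointPath is a hypothetical variable holding your CheckpointPath value, and keep the file with the largest iteration number.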
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: "PositiveOverlapRange",[0.5 1] uses anchor boxes whose overlap ratio with a ground truth bounding box is between 0.5 and 1 as positive training samples.
PositiveOverlapRange
— Range of bounding box overlap ratios
[0.5 1]
(default) | two-element vector
Range of bounding box overlap ratios between 0 and 1, specified as a two-element vector. Anchor boxes that overlap with ground truth bounding boxes within the specified range are used as positive training samples. The function computes the overlap ratio using the intersection-over-union between two bounding boxes.
NegativeOverlapRange
— Range of bounding box overlap ratios
[0 0.5]
(default) | two-element vector
Range of bounding box overlap ratios between 0 and 1, specified as a two-element vector. Anchor boxes that overlap with ground truth bounding boxes within the specified range are used as negative training samples. The function computes the overlap ratio using the intersection-over-union between two bounding boxes.
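To make the two ranges concrete, the sketch below computes the intersection-over-union (IoU) between four made-up anchor boxes and one ground truth box using the base MATLAB rectint function, then labels each anchor with the default ranges. It illustrates the overlap rule described above; it is not the toolbox's internal sampling code, and which side of the shared 0.5 boundary a tie falls on is an assumption of this sketch.

```matlab
% Boxes use [x y width height]. rectint returns pairwise intersection areas
% (rows = anchors, columns = ground truth boxes).
iou = @(a,b) rectint(a,b) ./ (a(:,3).*a(:,4) + (b(:,3).*b(:,4))' - rectint(a,b));

anchors = [  0   0 40 40     % partial overlap with the ground truth
            10  10 40 40     % identical to the ground truth
            30  30 40 40     % small overlap
           100 100 20 20];   % no overlap
groundTruth = [10 10 40 40];

positiveRange = [0.5 1];     % default PositiveOverlapRange
negativeRange = [0 0.5];     % default NegativeOverlapRange

% Best IoU of each anchor with any ground truth box.
bestOverlap = max(iou(anchors,groundTruth),[],2);

% Assumption: an IoU of exactly 0.5 counts as positive in this sketch.
isPositive = bestOverlap >= positiveRange(1) & bestOverlap <= positiveRange(2);
isNegative = bestOverlap >= negativeRange(1) & bestOverlap <  negativeRange(2);
```

Here the first anchor has an IoU of 900/2300 (about 0.39), so it is a negative sample, and only the second anchor is positive. To change the ranges for training, pass them as name-value arguments, for example trainSSDObjectDetector(ds,detector,options,PositiveOverlapRange=[0.6 1],NegativeOverlapRange=[0 0.4]).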
ExperimentManager
— Detector training experiment monitoring
"none"
(default) | experiments.Monitor
object
Detector training experiment monitoring, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record values of the metrics used by the training, and produce training plots. For an example using this app, see Train Object Detectors in Experiment Manager.
Information monitored during training:
- Training loss at each iteration.
- Training accuracy at each iteration.
- Training root mean square error (RMSE) for the box regression layer.
- Learning rate at each iteration.
Validation information when the training options input contains validation data:
- Validation loss at each iteration.
- Validation accuracy at each iteration.
- Validation RMSE at each iteration.
Output Arguments
trainedDetector
— Trained SSD multibox object detector
ssdObjectDetector
object
Trained SSD object detector, returned as an ssdObjectDetector object. You can train an SSD object detector to detect multiple object classes.
info
— Training progress information
structure array
Training progress information, returned as a structure array. Each field contains one training or validation metric:
- TrainingLoss: Training loss at each iteration, calculated as the sum of the regression loss and the classification loss. To compute the regression loss, the trainSSDObjectDetector function uses the smooth L1 loss function. To compute the classification loss, the trainSSDObjectDetector function uses the softmax and binary cross-entropy loss functions.
- TrainingAccuracy: Training set accuracy at each iteration.
- TrainingRMSE: Training root mean squared error (RMSE), calculated from the training loss at each iteration.
- BaseLearnRate: Learning rate at each iteration.
- ValidationLoss: Validation loss at each iteration.
- ValidationAccuracy: Validation accuracy at each iteration.
- ValidationRMSE: Validation RMSE at each iteration.
- FinalValidationLoss: Final validation loss at the end of training.
- FinalValidationRMSE: Final validation RMSE at the end of training.
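The two loss terms that make up TrainingLoss can be sketched in their textbook forms. This assumes the standard smooth L1 definition with a threshold of 1 and uses a plain softmax cross-entropy for a single anchor; the exact formulation inside trainSSDObjectDetector (its binary cross-entropy variant, per-anchor weighting, and normalization) is not reproduced here. The box errors and class scores are made-up inputs.

```matlab
% Smooth L1, elementwise: 0.5*x^2 where |x| < 1, and |x| - 0.5 elsewhere.
smoothL1 = @(x) (abs(x) < 1) .* (0.5*x.^2) + (abs(x) >= 1) .* (abs(x) - 0.5);

% Softmax cross-entropy for one anchor: scores are raw class scores and
% target is the index of the true class.
softmaxCE = @(scores,target) -log(exp(scores(target)) ./ sum(exp(scores)));

boxError = [0 0.5 2 -3];                   % made-up offsets (predicted - target)
regressionLoss = sum(smoothL1(boxError));  % 0 + 0.125 + 1.5 + 2.5 = 4.125
classificationLoss = softmaxCE([2 0],1);   % log(1 + exp(-2)), approx. 0.1269
totalLoss = regressionLoss + classificationLoss;
```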
Each field is a numeric vector with one element per training iteration. Values that have not been calculated at a specific iteration are set to NaN.
The structure contains the ValidationLoss, ValidationAccuracy, ValidationRMSE, FinalValidationLoss, and FinalValidationRMSE fields only when options specifies validation data.
Version History
Introduced in R2020a
R2022a: LayerGraph input is not recommended
Starting in R2022a, using a LayerGraph (Deep Learning Toolbox) object to specify the SSD object detection network as input to the trainSSDObjectDetector function is not recommended.
The syntax trainSSDObjectDetector(trainingData,net,options), specifying the input SSD detection network net as a LayerGraph (Deep Learning Toolbox) object, will be removed in a future release.
If your SSD object detection network is a LayerGraph (Deep Learning Toolbox) object, configure the network as an ssdObjectDetector object by using the ssdObjectDetector function. Then, use the ssdObjectDetector object as input to the trainSSDObjectDetector function for training.
See Also
Apps
Functions
trainingOptions (Deep Learning Toolbox) | objectDetectorTrainingData
Objects
Topics
- Object Detection Using SSD Deep Learning
- Estimate Anchor Boxes From Training Data
- Code Generation for Object Detection by Using Single Shot Multibox Detector
- Train Object Detectors in Experiment Manager
- Getting Started with SSD Multibox Detection
- Get Started with Object Detection Using Deep Learning
- Choose an Object Detector
- Anchor Boxes for Object Detection
- Datastores for Deep Learning (Deep Learning Toolbox)