trainPoseMaskRCNN

Train Pose Mask R-CNN network to perform pose estimation

Since R2024a

    Description

    net = trainPoseMaskRCNN(trainingData,network,trainingMode,options) trains a Pose Mask R-CNN network to perform six-degrees-of-freedom (6-DoF) pose estimation for multiple object classes.

    Note

    This functionality requires Deep Learning Toolbox™ and the Computer Vision Toolbox™ Model for Pose Mask R-CNN 6-DoF Object Pose Estimation. To use this functionality in parallel, you must have a license for Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU. For information about the supported compute capabilities, see GPU Computing Requirements (Parallel Computing Toolbox).

    [net,info] = trainPoseMaskRCNN(trainingData,network,trainingMode,options) also returns information on the training progress, such as training loss and accuracy, for each iteration.

    [___] = trainPoseMaskRCNN(___,Name=Value) specifies options using name-value arguments in addition to any combination of arguments from previous syntaxes. For example, NumRegionsToSample=64 specifies for the trainPoseMaskRCNN function to sample 64 region proposals from each training image.
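
    For example, this sketch shows the syntaxes above. The datastore dsTrain, the network net, and the options object are assumptions, created as described in Input Arguments.

        % Train on the instance segmentation task and return the trained network.
        net = trainPoseMaskRCNN(dsTrain,net,"mask",options);

        % Also return training progress information, and sample 64 region
        % proposals from each training image.
        [net,info] = trainPoseMaskRCNN(dsTrain,net,"mask",options, ...
            NumRegionsToSample=64);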

    Input Arguments

    Labeled training data, specified as a datastore. Your data must be set up so that reading the datastore using the read or readall function returns a 7-by-B cell array, where B is the number of images in the training data. This list describes the required data, in order:

    • Input Image Data — RGB image, stored as an H-by-W-by-3 numeric array, or grayscale image, stored as an H-by-W numeric matrix.

    • Depth Image Data — Depth image, stored as an H-by-W numeric matrix.

    • Bounding Boxes — Bounding boxes, stored as an M-by-4 matrix, where M is the number of objects detected in the image. Each row of the matrix specifies the bounding box of the corresponding object in the form [x y width height], where x and y are the position of the top-left corner of the bounding box.

    • Labels — Object class names, stored as an M-by-1 categorical vector. All categorical data returned by the datastore must contain the same categories.

    • Masks — Binary masks, stored as an H-by-W-by-M logical array. Each page of the array specifies the binary mask of the corresponding detected object.

    • 6-DoF Object Poses — 6-DoF object poses, stored as an M-by-1 vector of rigidtform3d objects.

    • Camera Intrinsic Parameters — Camera intrinsic parameters, stored as a scalar cameraIntrinsics object. All images in the training data set must have the same camera intrinsic parameters.
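
    For example, this is a minimal sketch of a custom read function that returns one labeled sample in the documented order. The MAT-file layout and field names (rgbImage, depthImage, boxes, labels, masks, poses, focalLength, principalPoint, imageSize) are assumptions about how your own ground truth might be stored.

        % Wrap one-sample-per-file ground truth in a datastore (sketch).
        dsTrain = fileDatastore("trainingSamples",ReadFcn=@readPoseSample, ...
            FileExtensions=".mat");

        function out = readPoseSample(filename)
            s = load(filename);    % assumed: each MAT-file stores one labeled sample
            intrinsics = cameraIntrinsics(s.focalLength,s.principalPoint,s.imageSize);
            % Seven entries, matching the documented 7-by-B layout for one image.
            out = {s.rgbImage; ...             % H-by-W-by-3 RGB image
                   s.depthImage; ...           % H-by-W depth image
                   s.boxes; ...                % M-by-4 boxes, [x y width height]
                   categorical(s.labels); ...  % M-by-1 class names
                   s.masks; ...                % H-by-W-by-M binary masks
                   s.poses; ...                % M-by-1 vector of rigidtform3d objects
                   intrinsics};                % scalar cameraIntrinsics object
        end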

    Pose Mask R-CNN pose estimation network to train, specified as a posemaskrcnn object.

    Training mode that specifies the losses to apply during training, specified as one of these options:

    • "mask" — Applies classification, bounding box regression, and segmentation mask losses.

    • "pose-and-mask" — Applies classification, bounding box regression, segmentation mask, rotation, and translation losses.

    You must train a Pose Mask R-CNN pose estimation network in two stages. First, train the network on only the instance segmentation task in "mask" mode. Then, train the output network on the pose estimation task in "pose-and-mask" mode.
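
    For example, this is a sketch of the two-stage schedule, assuming a datastore dsTrain in the format described above. The pretrained model name "resnet50-pvc-parts" and the specific solver settings are assumptions, not recommendations.

        % Load a Pose Mask R-CNN network to fine-tune (assumed pretrained model name).
        net = posemaskrcnn("resnet50-pvc-parts");

        % Stage 1: train on the instance segmentation task only.
        optionsMask = trainingOptions("adam",InitialLearnRate=1e-4, ...
            MaxEpochs=10,MiniBatchSize=1);
        netMask = trainPoseMaskRCNN(dsTrain,net,"mask",optionsMask);

        % Stage 2: train on the pose estimation task, starting from the stage 1 network.
        optionsPose = trainingOptions("adam",InitialLearnRate=1e-5, ...
            MaxEpochs=10,MiniBatchSize=1);
        netPose = trainPoseMaskRCNN(dsTrain,netMask,"pose-and-mask",optionsPose);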

    Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. To specify the solver name and other options for network training, use the trainingOptions function.
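
    For example, this sketch creates an options object that includes a validation datastore. The datastore dsVal and the specific values are assumptions.

        options = trainingOptions("adam", ...
            InitialLearnRate=1e-4, ...
            MaxEpochs=5, ...
            MiniBatchSize=2, ...
            ValidationData=dsVal, ...
            ValidationFrequency=50);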

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: trainPoseMaskRCNN(trainingData,network,trainingMode,options,NumRegionsToSample=64) samples 64 region proposals from each training image.

    Bounding box overlap ratios for positive training samples, specified as a two-element numeric vector with values in the range [0, 1]. The function compares region proposals to ground truth bounding boxes, and uses the proposals with overlap ratios within the specified range as positive training samples.

    Given two bounding boxes, A and B, the overlap ratio is:

    area(A ∩ B) / area(A ∪ B)
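
    For example, this small sketch computes the overlap ratio of two boxes using the bboxOverlapRatio function.

        A = [10 10 40 40];             % [x y width height]
        B = [30 30 40 40];
        iou = bboxOverlapRatio(A,B)    % area(A∩B)/area(A∪B) = 400/2800 ≈ 0.14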

    Bounding box overlap ratios for negative training samples, specified as a two-element numeric vector with values in the range [0, 1]. The function compares region proposals to ground truth bounding boxes, and uses the proposals with overlap ratios within the specified range as negative training samples.

    Given two bounding boxes, A and B, the overlap ratio is:

    area(A ∩ B) / area(A ∪ B)

    Maximum number of strongest region proposals to use for generating training samples, specified as a positive integer. Reduce this value to speed up processing time at the cost of training accuracy. To use all region proposals, set this value to Inf.

    Number of region proposals to randomly sample from each training image, specified as a positive integer. Specifying a smaller number of regions to sample can reduce memory usage and increase training speed, but can also decrease training accuracy.

    Subnetworks to freeze during training, specified as one of these values:

    • "none" — Do not freeze subnetworks.

    • "backbone" — Freeze the feature extraction subnetwork, including the layers following the ROI align layer.

    • "rpn" — Freeze the region proposal subnetwork.

    • ["backbone" "rpn"] — Freeze both the feature extraction and the region proposal subnetworks.

    The weights of layers in frozen subnetworks do not change during training.
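
    For example, this sketch freezes the feature extraction subnetwork during the second training stage. The name-value argument name FreezeSubNetwork is an assumption based on related Mask R-CNN training functions.

        % Freeze the backbone and fine-tune only the remaining subnetworks
        % (FreezeSubNetwork is an assumed argument name).
        net = trainPoseMaskRCNN(dsTrain,net,"pose-and-mask",options, ...
            FreezeSubNetwork="backbone");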

    Training experiment monitor, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record values of the metrics used by the training, and produce training plots.

    Information monitored during training:

    • Training loss at each iteration

    • Training accuracy at each iteration

    • Training root mean square error (RMSE) for the box regression layer

    • Training loss for the mask segmentation branch

    • Training rotation loss at each iteration

    • Training translation loss at each iteration

    • Learning rate at each iteration

    Validation information when the training options input argument contains validation data:

    • Validation loss at each iteration

    • Validation accuracy at each iteration

    • Validation RMSE at each iteration

    • Validation loss for the mask segmentation branch

    • Validation rotation loss at each iteration

    • Validation translation loss at each iteration
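
    For example, this is a sketch of an Experiment Manager training function that forwards the monitor to training. The name-value argument name ExperimentMonitor is an assumption based on related training functions, and createTrainingDatastore is a hypothetical helper.

        function net = trainPoseExperiment(params,monitor)
            dsTrain = createTrainingDatastore(params);    % hypothetical helper
            net = posemaskrcnn("resnet50-pvc-parts");     % assumed pretrained model name
            options = trainingOptions("adam", ...
                InitialLearnRate=params.InitialLearnRate, ... % assumed hyperparameter field
                MaxEpochs=5,MiniBatchSize=1);
            net = trainPoseMaskRCNN(dsTrain,net,"pose-and-mask",options, ...
                ExperimentMonitor=monitor);               % assumed argument name
        end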

    Translation loss weight, specified as a positive scalar. Adjust the TranslationLossWeight value so that the translation loss TrainingTranslationLoss after weighting is the same order of magnitude as the values of TrainingRPNLoss, TrainingRMSE, TrainingClassLoss, and TrainingMaskLoss.

    Rotation loss weight, specified as a positive scalar. Adjust the RotationLossWeight value so that the rotation loss TrainingRotationLoss after weighting is the same order of magnitude as the values of TrainingRPNLoss, TrainingRMSE, TrainingClassLoss, and TrainingMaskLoss.
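
    For example, this sketch scales the pose losses so that, after weighting, they are on the same order of magnitude as the other training losses. The values 10 and 50 are illustrative assumptions; inspect the returned training losses to choose values for your data.

        net = trainPoseMaskRCNN(dsTrain,net,"pose-and-mask",options, ...
            RotationLossWeight=10,TranslationLossWeight=50);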

    Bounding box regression loss weight, specified as a positive scalar. In most cases, train the Pose Mask R-CNN network with the default value.

    Mask loss weight to use for instance segmentation, specified as a positive scalar. In most cases, train the Pose Mask R-CNN network with the default value.

    Region proposal network (RPN) training loss weight, specified as a positive scalar. In most cases, train the Pose Mask R-CNN network with the default value.

    Output Arguments

    Trained Pose Mask R-CNN pose estimation network, returned as a posemaskrcnn object.

    Training progress information, returned as a structure. Each field corresponds to a stage of training.

    • TrainingLoss — Training loss at each iteration. The loss is the combination of the RPN, classification, regression, and mask segmentation losses used to train the Pose Mask R-CNN pose estimation network.

    • TrainingRPNLoss — Total RPN loss at the end of each iteration.

    • TrainingRMSE — Training RMSE for the box regression layer at the end of each iteration.

    • TrainingClassLoss — Training cross-entropy loss for the classification layer at the end of each iteration.

    • TrainingMaskLoss — Training cross-entropy loss for the mask segmentation branch at the end of each iteration.

    • TrainingRotationLoss — Training rotation prediction branch loss at the end of each iteration.

    • TrainingTranslationLoss — Training regression loss for the translation prediction branch at the end of each iteration.

    • LearnRate — Learning rate at each iteration.

    • ValidationLoss — Validation loss at each iteration.

    • ValidationRPNLoss — Validation RPN loss at each iteration.

    • ValidationRMSE — Validation RMSE at each iteration.

    • ValidationClassLoss — Validation classification loss at each iteration.

    • ValidationMaskLoss — Validation cross-entropy loss for the mask segmentation branch at each iteration.

    • ValidationRotationLoss — Validation rotation loss at each iteration.

    • ValidationTranslationLoss — Validation translation loss at each iteration.

    Each field contains a numeric vector with one element per training iteration. If the function does not calculate a metric for a specific iteration, it assigns a value of NaN for that iteration. The structure contains the validation fields only when options specifies validation data.
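
    For example, this sketch plots some of the returned training losses, assuming inputs created as described in Input Arguments.

        [net,info] = trainPoseMaskRCNN(dsTrain,net,"pose-and-mask",options);
        figure
        plot(info.TrainingLoss)
        hold on
        plot(info.TrainingRotationLoss)
        plot(info.TrainingTranslationLoss)
        hold off
        legend("Total loss","Rotation loss","Translation loss")
        xlabel("Iteration")
        ylabel("Loss")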

    Tips

    • The trainPoseMaskRCNN function has a high GPU memory requirement. Train a Pose Mask R-CNN network on a GPU with at least 12 GB of available memory.

    • To reduce memory consumption during training, you can decrease the value of the NumRegionsToSample name-value argument to limit the number of proposals from the region proposal stage. Note that reducing the number of sampled regions can also decrease training accuracy and increase convergence time.

    Version History

    Introduced in R2024a