
Define Custom Metric Object

Note

This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.

In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.


If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.

To define a custom deep learning metric class, you can use the template in this example, which takes you through these steps:

  1. Name the metric — Give the metric a name so that you can use it in MATLAB®.

  2. Declare the metric properties — Specify the public and private properties of the metric.

  3. Create a constructor function — Specify how to construct the metric and set default values.

  4. Create an initialization function (optional) — Specify how to initialize variables and run validation checks.

  5. Create a reset function — Specify how to reset the metric properties between iterations.

  6. Create an update function — Specify how to update metric properties between iterations.

  7. Create an aggregation function — Specify how to aggregate the metric values across multiple instances of the metric object.

  8. Create an evaluation function — Specify how to calculate the metric value for each iteration.

This example shows how to create a custom false positive rate (FPR) metric. This equation defines the metric:

$$\text{FPR} = \frac{\text{False Positive}}{\text{False Positive} + \text{True Negative}}$$
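As a quick check of the equation, this MATLAB sketch computes the FPR for a single class from hypothetical counts (8 false positives and 92 true negatives, chosen only for illustration).

```matlab
% Hypothetical counts for a single class (illustrative values only).
falsePositives = 8;
trueNegatives  = 92;

% FPR = FP / (FP + TN): the fraction of actual negatives that the
% model incorrectly labels as positive.
fpr = falsePositives / (falsePositives + trueNegatives)
```

Here 8 of the 100 actual negatives are misclassified, so the FPR is 0.08. A lower value is better, which is why the metric later sets Maximize to 0.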

To see the completed metric class definition, see Completed Metric.

Metric Template

Copy the metric template into a new file in MATLAB. This template gives the structure of a metric class definition. It outlines:

  • The properties block for public metric properties. This block must contain the Name property.

  • The properties block for private metric properties. This block is optional.

  • The metric constructor function.

  • The optional initialize function.

  • The required reset, update, aggregate, and evaluate functions.

classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties
            %
            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %           evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with val1,...
            % valN.

            % Define metric evaluation function here.
        end
    end
end

Metric Name

First, give the metric a name. In the first line of the class file, replace the existing name myMetric with fprMetric.

classdef fprMetric < deep.Metric
    ... 
end

Next, rename the myMetric constructor function (the first function in the methods section) so that it has the same name as the metric.

    methods
        function metric = fprMetric(args)
            ...
        end
    ...
    end

Save Metric

Save the metric class file in a new file with the name fprMetric and the .m extension. The file name must match the metric name. To use the metric, you must save the file in the current folder or in a folder on the MATLAB path.

Declare Properties

Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.

  • properties — Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.

  • properties (Access = private) — Only members of the defining class can access the property.

Declare Public Properties

Declare public properties by listing them in the properties section. This section must contain the Name property.

    properties
        % (Required) Metric name.
        Name
    end

Declare Private Properties

Declare private properties by listing them in the properties (Access = private) section. This metric requires two properties to evaluate the value: true negatives (TNs) and false positives (FPs). Only the functions within the metric class require access to these values.

    properties (Access = private)
        % Define true negatives (TNs) and false positives (FPs).
        TrueNegatives
        FalsePositives
    end

Create Constructor Function

Create the function that constructs the metric and initializes the metric properties. If the software requires any variables to evaluate the metric value, then these variables must be inputs to the constructor function.

The FPR metric constructor function accepts the Name, NetworkOutput, and Maximize arguments. These arguments are optional when you use the constructor to create a metric object. Specify an args input to the fprMetric function that corresponds to the optional name-value arguments. Add a comment to explain the syntax of the function.

        function metric = fprMetric(args)
            % metric = fprMetric creates an fprMetric metric object.

            % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
            % also specifies the optional Name, NetworkOutput, and Maximize
            % options. By default, the metric name is "FPR" and the
            % NetworkOutput is [], which corresponds to using all of the
            % network outputs. By default, Maximize is 0 because the optimal
            % value occurs when the FPR is minimized.

            ...            
        end

Next, set the default values for the metric properties. Parse the input arguments using an arguments block. Specify the default metric name as "FPR", the default network output as [], and the Maximize property as 0. The metric name appears in plots and verbose output.

        function metric = fprMetric(args)
            ...
         
            arguments
                args.Name = "FPR"
                args.NetworkOutput = []
                args.Maximize = 0
            end
            ...
        end

Set the properties of the metric.

        function metric = fprMetric(args)
            ...
       
            % Set the metric name.
            metric.Name = args.Name;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the maximize property. 
            metric.Maximize = args.Maximize;
        end

View the completed constructor function. With this constructor function, the command fprMetric(Name="fpr") creates an FPR metric object with the name "fpr".

        function metric = fprMetric(args)
            % metric = fprMetric creates an fprMetric metric object.

            % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
            % also specifies the optional Name, NetworkOutput, and Maximize
            % options. By default, the metric name is "FPR" and the
            % NetworkOutput is [], which corresponds to using all of the
            % network outputs. By default, Maximize is 0 because the optimal
            % value occurs when the FPR is minimized.

            arguments
                args.Name = "FPR"
                args.NetworkOutput = []
                args.Maximize = 0
            end

            % Set the metric name.
            metric.Name = args.Name;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the maximize property. 
            metric.Maximize = args.Maximize;
        end

Create Initialization Function

Create the optional function that initializes variables and runs validation checks. For this example, the metric does not need the initialize function, so you can delete it. For an example of an initialize function, see Initialization Function.

Create Reset Function

Create the function that resets the metric properties. The software calls this function before each iteration. For the FPR score metric, reset the TN and FP values to zero at the start of each iteration.

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TrueNegatives  = 0;
            metric.FalsePositives = 0;
        end

Create Update Function

Create the function that updates the metric properties that you use to compute the FPR score value. The software calls this function in each training and validation mini-batch.

In the update function, define these steps:

  1. Find the maximum score for each observation. The maximum score corresponds to the predicted class for each observation.

  2. Find the TN and FP values.

  3. Add the batch TN and FP values to the running total number of TNs and FPs.

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the TN and FP values for this batch.
            batchTrueNegatives = sum(~batchY & ~batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
        end
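To see what the update function computes, this sketch reproduces its logic on plain numeric arrays instead of formatted dlarray objects. It assumes a channel-by-batch ("CB") layout, so the class dimension is 1 and the batch dimension is 2, and it uses made-up scores and one-hot targets for three classes and four observations.

```matlab
% Hypothetical softmax scores: 3 classes (rows) x 4 observations (columns).
scores = [0.7 0.1 0.2 0.3;
          0.2 0.8 0.3 0.3;
          0.1 0.1 0.5 0.4];

% Hypothetical one-hot targets with the same layout.
targets = logical([1 0 0 0;
                   0 1 0 1;
                   0 0 1 0]);

% Class dimension; for "CB"-formatted dlarray data, finddim(batchY,"C") returns 1.
cDim = 1;

% Mark the highest-scoring class of each observation as the prediction.
predicted = scores == max(scores,[],cDim);

% Per-class counts, summed over the batch dimension (2).
batchTrueNegatives  = sum(~predicted & ~targets, 2)
batchFalsePositives = sum(predicted & ~targets, 2)

% Running totals start at scalar 0 (as set by reset); implicit expansion
% turns them into per-class column vectors on the first update.
trueNegatives  = 0 + batchTrueNegatives;
falsePositives = 0 + batchFalsePositives;
```

With these scores, observation 4 is predicted as class 3 but belongs to class 2, so class 3 receives one false positive. If two classes tie for the maximum score, the == comparison marks both of them as predicted.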

For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.

  • When using the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.

  • When using the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
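The update function shown earlier compares batchT element-wise with batchY, so it implicitly expects one-hot targets. If you want a metric to accept numeric class indices as well, one option is to convert the indices to one-hot form before counting. This is a hypothetical sketch, not part of the documented metric: it assumes the indices arrive as a 1-by-B row of class numbers, that the number of classes (the hypothetical variable numClasses) is known, and it uses plain arrays, so formatted dlarray inputs may need extra handling.

```matlab
numClasses   = 3;           % assumed known, for example from the class dimension of batchY
indexTargets = [1 2 3 2];   % hypothetical 1-by-B class indices

% Compare a column of class numbers with the row of indices; implicit
% expansion yields a numClasses-by-B logical one-hot matrix.
oneHotTargets = (1:numClasses)' == indexTargets
```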

Create Aggregation Function

Create the function that specifies how to combine the metric values and properties across multiple instances of the metric. For example, the aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training.

For this example, to combine the TN and FP values, add the values from each metric instance.

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.

            metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
        end
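Because TNs and FPs are counts, adding them across instances is equivalent to having counted all the observations in a single instance. This sketch, using hypothetical per-class counts from two instances, evaluates the FPR on the summed counts and contrasts it with averaging each instance's final FPR.

```matlab
% Hypothetical per-class counts accumulated by two metric instances.
tn1 = [40; 38; 41];  fp1 = [2; 3; 1];
tn2 = [35; 37; 36];  fp2 = [1; 0; 4];

% Aggregate by adding counts, as the aggregate method does.
tn = tn1 + tn2;
fp = fp1 + fp2;

% Evaluate once on the aggregated counts (macro average over classes).
fprAggregated = mean(fp./(fp + tn + eps))

% Averaging each instance's final FPR generally gives a different value.
fprAveraged = mean([mean(fp1./(fp1 + tn1 + eps)), mean(fp2./(fp2 + tn2 + eps))])
```

The two results differ (about 0.0460 versus 0.0454 here), which is one reason to aggregate the underlying counts rather than the evaluated metric values.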

Create Evaluation Function

Create the function that specifies how to compute the metric value in each iteration. This equation defines the FPR metric:

$$\text{FPR} = \frac{\text{False Positive}}{\text{False Positive} + \text{True Negative}}$$

Implement this equation in the evaluate function. Find the macro average by taking the average across all the classes.

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TN and FP values.
            tn = metric.TrueNegatives;
            fp = metric.FalsePositives;

            % Compute the FPR value.
            val = mean(fp./(fp+tn+eps));
        end

Because the denominator of this metric can be zero, add eps to it to prevent the metric from returning a NaN value.
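This sketch reproduces the evaluate computation on plain numeric vectors of per-class counts (hypothetical values, with class 3 having no negatives at all). It shows the per-class ratios that the macro average is taken over, and how eps keeps the empty class from producing NaN.

```matlab
% Hypothetical per-class counts; class 3 has FP + TN = 0.
tn = [90; 85; 0];
fp = [10; 15; 0];

% Element-wise division gives one FPR per class; mean gives the macro average.
perClassFPR = fp./(fp + tn + eps)
macroFPR    = mean(perClassFPR)

% Without eps, the empty class produces 0/0 = NaN, which propagates to the mean.
withoutEps  = mean(fp./(fp + tn))
```

The element-wise ./ operator matters here: with /, MATLAB performs matrix right division, which for two column vectors returns a matrix rather than the per-class ratios.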

Completed Metric

View the completed metric class file.

Note

For more information about when the software calls each function in the class, see Function Call Order.

classdef fprMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name
    end

    properties (Access = private)
        % Define true negatives (TNs) and false positives (FPs).
        TrueNegatives
        FalsePositives
    end

    methods
        function metric = fprMetric(args)
            % metric = fprMetric creates an fprMetric metric object.

            % metric = fprMetric(Name=name,NetworkOutput="out1",Maximize=0)
            % also specifies the optional Name, NetworkOutput, and Maximize
            % options. By default, the metric name is "FPR" and the
            % NetworkOutput is [], which corresponds to using all of the
            % network outputs. By default, Maximize is 0 because the optimal
            % value occurs when the FPR value is minimized.

            arguments
                args.Name = "FPR"
                args.NetworkOutput = []
                args.Maximize = false
            end

            % Set the metric name value.
            metric.Name = args.Name;

            % To support this metric for use with multi-output networks, set
            % the network output.
            metric.NetworkOutput = args.NetworkOutput;

            % To support this metric for early stopping and returning the
            % best network, set the maximize property.
            metric.Maximize = args.Maximize;
        end

        function metric = reset(metric)
            % metric = reset(metric) resets the metric properties.
            metric.TrueNegatives  = 0;
            metric.FalsePositives = 0;
        end

        function metric = update(metric,batchY,batchT)
            % metric = update(metric,batchY,batchT) updates the metric
            % properties.

            % Find the channel (class) dimension.
            cDim = finddim(batchY,"C");

            % Find the maximum score, which corresponds to the predicted
            % class. Set the predicted class to 1 and all other classes to 0.
            batchY = batchY == max(batchY,[],cDim);

            % Find the TN and FP values for this batch.
            batchTrueNegatives = sum(~batchY & ~batchT, 2);
            batchFalsePositives = sum(batchY & ~batchT, 2);

            % Add the batch values to the running totals and update the metric
            % properties.
            metric.TrueNegatives = metric.TrueNegatives + batchTrueNegatives;
            metric.FalsePositives = metric.FalsePositives + batchFalsePositives;
        end

        function metric = aggregate(metric,metric2)
            % metric = aggregate(metric,metric2) aggregates the metric
            % properties across two instances of the metric.

            metric.TrueNegatives = metric.TrueNegatives + metric2.TrueNegatives;
            metric.FalsePositives = metric.FalsePositives + metric2.FalsePositives;
        end

        function val = evaluate(metric)
            % val = evaluate(metric) uses the properties in metric to return the
            % evaluated metric value.

            % Extract TN and FP values.
            tn = metric.TrueNegatives;
            fp = metric.FalsePositives;

            % Compute the FPR value.
            val = mean(fp./(fp+tn+eps));
        end
    end
end

Use Custom Metric During Training

You can use a custom metric in the same way as any other metric in Deep Learning Toolbox™. This section shows how to create and train a network for digit classification and track the FPR value.

Unzip the digit sample data and create an image datastore. The imageDatastore function automatically labels the images based on folder names.

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

Use 750 images per label for training and the remaining images as the validation set.

numTrainingFiles = 750;
[imdsTrain,imdsVal] = splitEachLabel(imds,numTrainingFiles,"randomize");

Define the network architecture.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    fullyConnectedLayer(10)
    softmaxLayer];

Create an fprMetric object.

metric = fprMetric(Name="FalsePositiveRate")
metric = 
  fprMetric with properties:

             Name: "FalsePositiveRate"
    NetworkOutput: []
         Maximize: 0

Specify the FPR metric in the training options. To plot the metric during training, set Plots to "training-progress". To output the values during training, set Verbose to true. Return the network that achieves the best FPR value.

options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    Metrics=metric, ...
    ValidationData=imdsVal, ...
    ValidationFrequency=50, ...
    Verbose=true, ...
    Plots="training-progress", ...
    ObjectiveMetricName="FalsePositiveRate", ...
    OutputNetwork="best-validation");

Train the network using the trainnet function. The values for the training and validation sets appear in the plot.

net = trainnet(imdsTrain,layers,"crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    ValidationLoss    TrainingFalsePositiveRate    ValidationFalsePositiveRate
    _________    _____    ___________    _________    ____________    ______________    _________________________    ___________________________
            0        0       00:00:06        0.001                            13.488                                                     0.10018
            1        1       00:00:06        0.001          13.974                                        0.10322                               
           50        1       00:00:24        0.001          2.7424            2.7448                     0.037368                       0.038889
          100        2       00:00:33        0.001          1.2965            1.2235                     0.027008                       0.023333
          150        3       00:00:41        0.001         0.64661           0.80412                     0.013953                       0.017867
          200        4       00:00:49        0.001         0.18627           0.53273                     0.006153                       0.012311
          250        5       00:00:57        0.001         0.16763           0.49371                    0.0060146                       0.012267
          290        5       00:01:05        0.001         0.25976           0.39347                    0.0062093                      0.0098222
Training stopped: Max epochs completed
