
Define Custom Learning Rate Schedule

When you train a neural network using the trainnet function, the LearnRateSchedule argument of the trainingOptions function provides several options for customizing the learning rate schedule. It provides built-in schedules such as "piecewise" and "warmup". You can also use learning rate schedule objects that you can customize further, such as piecewiseLearnRate and warmupLearnRate objects. If these built-in options do not provide the functionality that you need, you can specify the learning rate schedule as a function of the epoch number by using a function handle.
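For example, a step-decay rule that halves the learning rate every five epochs fits naturally in an anonymous function. This sketch assumes that the function handle receives the initial learning rate and the epoch number and returns the learning rate; check the trainingOptions reference page for the exact syntax that your release expects. The halving factor and the five-epoch step are arbitrary values chosen for illustration.

```matlab
% Hypothetical step-decay rule: halve the learning rate every 5 epochs.
% Assumption: the handle takes (initialLearnRate,epoch) and returns the rate.
stepDecay = @(initialLearnRate,epoch) initialLearnRate * 0.5^floor((epoch-1)/5);

% Evaluate the rule for the first 10 epochs with an initial rate of 0.01.
initialLearnRate = 0.01;
epochs = 1:10;
learnRates = arrayfun(@(e) stepDecay(initialLearnRate,e), epochs);
disp(learnRates)

% Under the same assumption, you could pass the handle to trainingOptions:
% options = trainingOptions("sgdm",LearnRateSchedule=stepDecay);
```

A function handle like this only sees the epoch number, which is why schedules that change every iteration or keep state need a custom schedule object instead.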

If you need additional flexibility, for example, to use a learning rate schedule that changes the learning rate between iterations or one that requires updating and maintaining a state, then you can define your own custom learning rate schedule object by using this example as a guide.

To define a custom learning rate schedule, you can use the template provided in this example, which takes you through these steps:

  • Name the schedule — Give the schedule a name so that you can use it in MATLAB®.

  • Declare the schedule properties (optional) — Specify the properties of the schedule.

  • Create the constructor function (optional) — Specify how to construct the schedule and initialize its properties.

  • Create the update function — Specify how the schedule calculates the learning rate.

This example shows how to define a time-based decay learning rate schedule and use it to train a neural network. A time-based decay learning rate schedule object updates the learning rate every iteration using a decay rule.

The time-based decay learning rate schedule uses this formula to calculate the learning rate:

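$$\alpha = \frac{\alpha_0}{1 + \lambda(k-1)},$$

where:

  • $\alpha$ is the learning rate.

  • $\alpha_0$ is the initial learning rate.

  • $\lambda$ is the decay.

  • $k$ is the iteration number.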

Custom Learning Rate Schedule Template

Copy the custom learning rate schedule template into a new file in MATLAB. This template gives the structure of a schedule class definition. It outlines:

  • The optional properties block for the schedule properties.

  • The optional schedule constructor function.

  • The update function.

classdef myLearnRateSchedule < deep.LearnRateSchedule

    properties
        % (Optional) Schedule properties.

        % Declare schedule properties here.
    end

    methods
        function schedule = myLearnRateSchedule()
            % (Optional) Create a myLearnRateSchedule.
            % This function must have the same name as the class.

            % Define schedule constructor function here.
        end

        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch)
            %UPDATE Update learning rate schedule

            % Define schedule update function here.
        end
    end
end

Name Schedule and Specify Superclass

First, give the schedule a name. In the first line of the class file, replace the existing name myLearnRateSchedule with timeBasedDecayLearnRate.

classdef timeBasedDecayLearnRate < deep.LearnRateSchedule

    ...
end

Next, rename the myLearnRateSchedule constructor function (the first function in the methods section) so that it has the same name as the schedule.

    methods
        function schedule = timeBasedDecayLearnRate()
            ...
        end

        ...
    end

Save the Schedule

Save the schedule class file in a new file named timeBasedDecayLearnRate.m. The file name must match the schedule name. To use the schedule, you must save the file in the current folder or in a folder on the MATLAB path.

Declare Properties

Declare the schedule properties in the properties section.

By default, custom learning rate schedules have these properties. Do not declare these properties in the properties section.

Property        Description

FrequencyUnit   How often the schedule updates the learning rate, specified as "epoch" (default) or "iteration".

                If FrequencyUnit is "epoch", then the software updates the learning rate every epoch and each iteration of the epoch uses the same learning rate. If FrequencyUnit is "iteration", then the software updates the learning rate every iteration.

NumSteps        Number of steps the learning rate schedule takes before it is complete, specified as a positive integer or Inf. For learning rate schedules that continue indefinitely (also known as infinite learning rate schedules), this property is Inf.
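To make the difference between the two FrequencyUnit values concrete, the following sketch computes the learning rate that each iteration would see under each setting. It is only a model of the behavior described in the table, not the software's internal logic. It assumes 39 iterations per epoch (390 iterations over 10 epochs, as in the training log later in this example) and a hypothetical rule rate = 0.01/(1 + 0.01*(step - 1)), where step is the epoch number in the "epoch" case and the iteration number in the "iteration" case.

```matlab
% Model of the FrequencyUnit behavior (illustration only).
itersPerEpoch = 39;   % assumed: 390 iterations over 10 epochs
numEpochs = 3;
rule = @(step) 0.01 ./ (1 + 0.01*(step - 1));   % hypothetical decay rule

% "epoch": one value per epoch, reused for every iteration of that epoch.
perEpoch = repelem(rule(1:numEpochs), itersPerEpoch);

% "iteration": a new value for every iteration.
perIteration = rule(1:numEpochs*itersPerEpoch);

% Both vectors hold one learning rate per iteration.
fprintf("Iteration 40: epoch-based %.6f, iteration-based %.6f\n", ...
    perEpoch(40), perIteration(40));
```

With "epoch", the rate stays flat for 39 iterations and then steps down; with "iteration", it decreases a little at every iteration.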

A time-based decay learning rate schedule requires one additional property: the decay value. Declare the decay value in the properties block.

    properties
        % Schedule properties
    
        Decay
    end

Create Constructor Function

Create the function that constructs the schedule and initializes the schedule properties. Specify any variables required to create the schedule as inputs to the constructor function.

The time-based decay learning rate schedule constructor function requires one argument (the decay). Specify one input argument named decay in the timeBasedDecayLearnRate function that corresponds to the decay. Add a comment to the top of the function that explains the syntax of the function.

        function schedule = timeBasedDecayLearnRate(decay)
            % timeBasedDecayLearnRate Time-based decay learning rate 
            % schedule
            %   schedule = timeBasedDecayLearnRate(decay) creates a
            %   time-based decay learning rate schedule with the specified
            %   decay.

            ...
        end

Initialize Schedule Properties

Initialize the schedule properties in the constructor function. Replace the comment % Define schedule constructor function here with code that initializes the schedule properties.

  • Because the time-based decay learning rate schedule updates the learning rate each iteration, set the FrequencyUnit property to "iteration".

  • Because the time-based decay learning rate schedule is infinite, set the NumSteps property to Inf.

  • Set the schedule Decay property to the decay argument.

            % Set schedule properties.
            schedule.FrequencyUnit = "iteration";
            schedule.NumSteps = Inf;
            schedule.Decay = decay;

View the completed constructor function.

        function schedule = timeBasedDecayLearnRate(decay) 
            % timeBasedDecayLearnRate Time-based decay learning rate
            % schedule
            %   schedule = timeBasedDecayLearnRate(decay) creates a
            %   time-based decay learning rate schedule with the specified
            %   decay.
    
    
            % Set schedule properties.
            schedule.FrequencyUnit = "iteration";
            schedule.NumSteps = Inf;
            schedule.Decay = decay;
        end

With this constructor function, the command timeBasedDecayLearnRate(0.01) creates a time-based decay learning rate schedule with a decay value of 0.01.
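The constructor in this example accepts any value for decay. A possible refinement, which is not part of the original example, is to validate the argument before storing it, for instance with the validateattributes function. The sketch below shows the check on its own; inside the class, you would place the validateattributes call at the top of the constructor. The attribute list (real, scalar, nonnegative, finite) is an assumption about which decay values make sense for this formula: a negative decay can drive the denominator 1 + decay*(k - 1) to zero.

```matlab
% Hypothetical argument check for the timeBasedDecayLearnRate constructor.
checkDecay = @(decay) validateattributes(decay, {'numeric'}, ...
    {'real','scalar','nonnegative','finite'}, 'timeBasedDecayLearnRate', 'decay');

checkDecay(0.01)        % valid value: returns without error

try
    checkDecay(-0.5)    % with decay = -0.5, 1 + decay*(k-1) is zero at k = 3
catch err
    disp(err.message)
end
```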

Create Update Function

Create a function named update that updates the learning rate schedule properties and returns the calculated learning rate value.

The update function has the syntax [schedule,learnRate] = update(schedule,initialLearnRate,iteration,epoch), where:

  • schedule is an instance of the learning rate schedule.

  • learnRate is the calculated learning rate value.

  • initialLearnRate is the initial learning rate.

  • iteration is the iteration number.

  • epoch is the epoch number.
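Because update returns the schedule as well as the learning rate, the returned copy is what carries any state forward to the next call. The following sketch mimics that calling pattern with a plain structure in place of a schedule object. The NumUpdates field and the stand-in function are hypothetical and exist only to show state being threaded through successive calls; this is not the code that trainnet runs.

```matlab
% Stand-in for a schedule object: a struct with a decay and a hypothetical counter.
schedule = struct('Decay',0.01,'NumUpdates',0);

% Stand-in with the same signature as update: it returns an updated copy of
% the schedule and the learning rate. The epoch input is unused here.
updateFcn = @(schedule,initialLearnRate,iteration,epoch) deal( ...
    setfield(schedule,'NumUpdates',schedule.NumUpdates + 1), ...
    initialLearnRate / (1 + schedule.Decay*(iteration - 1)));

initialLearnRate = 0.01;
learnRates = zeros(1,5);
for iteration = 1:5
    % Keep the returned schedule so that the next call sees the updated state.
    [schedule,learnRates(iteration)] = updateFcn(schedule,initialLearnRate,iteration,1);
end

disp(schedule.NumUpdates)   % number of calls recorded in the state
disp(learnRates)
```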

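The time-based decay learning rate schedule uses this formula to calculate the learning rate:

$$\alpha = \frac{\alpha_0}{1 + \lambda(k-1)},$$

where:

  • $\alpha$ is the learning rate.

  • $\alpha_0$ is the initial learning rate.

  • $\lambda$ is the decay.

  • $k$ is the iteration number.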

Implement this operation in update. The schedule does not require updating any state values, so the output schedule is unchanged.

Because a time-based decay learning rate schedule does not require the epoch number, the update function for this schedule ignores the epoch input and has the syntax [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~). Because the time-based decay learning rate schedule is not finite, there is no need to update the IsDone property.

        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            % UPDATE Update learning rate schedule
            %   [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            %   calculates the learning rate for the specified iteration
            %   and also returns the updated schedule object.

            % Calculate learning rate.
            decay = schedule.Decay;
            learnRate = initialLearnRate / (1 + decay*(iteration-1));
        end
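You can sanity-check the rule in update without the class by evaluating the same expression directly. With a decay of 0.01, the first iteration uses the initial learning rate unchanged, and the rate halves once $\lambda(k-1) = 1$, that is, at iteration $k = 1 + 1/\lambda = 101$. The sketch below checks both facts with an anonymous function that copies the expression from update.

```matlab
% Same expression as the update method, written as an anonymous function.
timeBasedDecay = @(initialLearnRate,decay,iteration) ...
    initialLearnRate ./ (1 + decay*(iteration - 1));

initialLearnRate = 0.01;
decay = 0.01;

first = timeBasedDecay(initialLearnRate,decay,1);   % equals initialLearnRate
halfPoint = 1 + 1/decay;                            % iteration where the rate halves
atHalf = timeBasedDecay(initialLearnRate,decay,halfPoint);

fprintf("Iteration 1: %g, iteration %g: %g\n", first, halfPoint, atHalf);
```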

Completed Learning Rate Schedule

View the completed learning rate schedule class file.

classdef timeBasedDecayLearnRate < deep.LearnRateSchedule
    % timeBasedDecayLearnRate Time-based decay learning rate schedule

    properties
        % Schedule properties
    
        Decay
    end

    methods
        function schedule = timeBasedDecayLearnRate(decay) 
            % timeBasedDecayLearnRate Time-based decay learning rate
            % schedule
            %   schedule = timeBasedDecayLearnRate(decay) creates a
            %   time-based decay learning rate schedule with the specified
            %   decay.
    
    
            % Set schedule properties.
            schedule.FrequencyUnit = "iteration";
            schedule.NumSteps = Inf;
            schedule.Decay = decay;
        end

        function [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            % UPDATE Update learning rate schedule
            %   [schedule,learnRate] = update(schedule,initialLearnRate,iteration,~)
            %   calculates the learning rate for the specified iteration
            %   and also returns the updated schedule object.

            % Calculate learning rate.
            decay = schedule.Decay;
            learnRate = initialLearnRate / (1 + decay*(iteration-1));
        end
    end
end

Train Using Custom Learning Rate Schedule Object

You can use a custom learning rate schedule object in the same way as any other learning rate schedule object in the trainingOptions function. This example shows how to create and train a network for digit classification using the time-based decay learning rate schedule object that you defined earlier.

Load the example training data.

load DigitsDataTrain

Create a layer array.

layers = [ 
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer];

Create an instance of a time-based decay learning rate schedule object with a decay value of 0.01.

schedule = timeBasedDecayLearnRate(0.01)
schedule = 
  timeBasedDecayLearnRate with properties:

            Decay: 0.0100
    FrequencyUnit: "iteration"
         NumSteps: Inf

Specify the training options. To train using the learning rate schedule object, set the LearnRateSchedule training option to the object.

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    LearnRateSchedule=schedule, ...
    Metrics="accuracy");

Train the neural network using the trainnet function. For classification, use index cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

To get information about the training, such as the learning rate value for each iteration, use the second output of the trainnet function.

[net,info] = trainnet(XTrain,labelsTrain,layers,"index-crossentropy",options);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss    TrainingAccuracy
    _________    _____    ___________    _________    ____________    ________________
            1        1       00:00:03         0.01          2.5434              10.938
           50        2       00:00:08    0.0067114         0.36849              89.062
          100        3       00:00:11    0.0050251         0.18073               93.75
          150        4       00:00:15    0.0040161        0.093009              97.656
          200        6       00:00:19    0.0033445        0.079959              99.219
          250        7       00:00:23    0.0028653        0.062385              99.219
          300        8       00:00:27    0.0025063        0.033808                 100
          350        9       00:00:31    0.0022272          0.0498                 100
          390       10       00:00:34     0.002045        0.047401                 100
Training stopped: Max epochs completed
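The LearnRate column follows the formula in update. As a quick check, the sketch below recomputes the rate for the iterations shown in the training log above, using the initial learning rate of 0.01 (the value logged at iteration 1) and the decay of 0.01 from this example, and compares the results with the logged values to the precision displayed.

```matlab
% Iterations and LearnRate values copied from the training log above.
iterations = [1 50 100 150 200 250 300 350 390];
loggedRates = [0.01 0.0067114 0.0050251 0.0040161 0.0033445 ...
    0.0028653 0.0025063 0.0022272 0.002045];

% Recompute with the time-based decay formula.
initialLearnRate = 0.01;   % matches the first logged value
decay = 0.01;
computedRates = initialLearnRate ./ (1 + decay*(iterations - 1));

% The log shows about five significant digits, so compare with a relative tolerance.
relErr = abs(computedRates - loggedRates) ./ loggedRates;
disp(max(relErr))
```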

Extract the learning rate information from the training information and visualize it in a plot.

figure
plot(info.TrainingHistory.LearnRate)
ylim([0 inf])
xlabel("Iteration")
ylabel("Learning Rate")

Figure: Learning rate (y-axis) versus iteration (x-axis).
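The plot shows how quickly the learning rate falls for a decay of 0.01. To compare candidate decay values before training, you can evaluate the same formula at the last of the 390 iterations in this example for a few values. The decay values below are arbitrary choices for illustration.

```matlab
% Learning rate at the last iteration (390) for several candidate decays.
initialLearnRate = 0.01;
numIterations = 390;                 % iterations in the training run above
decays = [0.001 0.01 0.1];
finalRates = initialLearnRate ./ (1 + decays*(numIterations - 1));

for i = 1:numel(decays)
    fprintf("decay = %g: final learning rate = %.3g (%.1f%% of initial)\n", ...
        decays(i), finalRates(i), 100*finalRates(i)/initialLearnRate);
end
```

A larger decay shrinks the learning rate faster, so the choice trades early progress against stability later in training.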

Test the neural network using the testnet function. For single-label classification, evaluate the accuracy. The accuracy is the percentage of correct predictions. By default, the testnet function uses a GPU if one is available. To select the execution environment manually, use the ExecutionEnvironment argument of the testnet function.

load DigitsDataTest
classNames = categories(labelsTest);
accuracy = testnet(net,XTest,labelsTest,"accuracy")
accuracy = 
97.7000
