Define Model Gradients Function for Custom Training Loop

When training a deep learning model with a custom training loop, the software minimizes the loss with respect to the learnable parameters. To minimize the loss, the software uses the gradients of the loss with respect to the learnable parameters. To calculate these gradients using automatic differentiation, you must define a model gradients function.

For an example showing how to train a deep learning model with a dlnetwork object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.

Create Model Gradients Function for Models Defined as a dlnetwork Object

If you have a deep learning model defined as a dlnetwork object, then create a model gradients function that takes the dlnetwork object as input.

For models specified as a dlnetwork object, create a function of the form gradients = modelGradients(dlnet,dlX,T), where dlnet is the network, dlX contains the input predictors, T contains the targets, and gradients contains the returned gradients. Optionally, you can pass extra arguments to the gradients function (for example, if the loss function requires extra information), or return extra arguments (for example, metrics for plotting the training progress).

For example, this function returns the gradients and the cross entropy loss for the specified dlnetwork object dlnet, input data dlX, and targets T.

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet.Learnables);

end
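
As a sketch of the optional extra outputs, the following variant also returns the network state so that layers such as batch normalization can be updated after each iteration. The function name modelGradientsWithState, the layer array, and the synthetic data are assumptions made for illustration, and the gradients are taken with respect to dlnet.Learnables.

```matlab
% Small illustrative network; the batch normalization layer carries state.
layers = [
    imageInputLayer([28 28 1],'Normalization','none','Name','input')
    convolution2dLayer(3,8,'Padding','same','Name','conv')
    batchNormalizationLayer('Name','bn')
    reluLayer('Name','relu')
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')];
dlnet = dlnetwork(layerGraph(layers));

% Synthetic mini-batch of 100 images and one-hot encoded targets.
dlX = dlarray(rand([28 28 1 100],'single'),'SSCB');
T = repmat(eye(10,'single'),[1 10]);

% Request the same outputs from dlfeval as the function returns.
[gradients,state,loss] = dlfeval(@modelGradientsWithState,dlnet,dlX,T);

% Store the returned state in the network.
dlnet.State = state;

function [gradients,state,loss] = modelGradientsWithState(dlnet,dlX,T)
    % Forward data and also return the updated network state.
    [dlY,state] = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients with respect to the learnable parameters.
    gradients = dlgradient(loss,dlnet.Learnables);
end
```

The extra output is returned in the same position as in the function signature, so the calling code can use it without affecting the gradient computation.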

Create Model Gradients Function for Models Defined as a Function

If you have a deep learning model defined as a function of the form dlY = model(parameters,dlX), then create a function of the form gradients = modelGradients(parameters,dlX,T), where parameters is a struct containing the learnable parameters, dlX are the input predictors, T are the targets, and gradients are the returned gradients. Optionally, you can pass extra arguments to the gradients function (for example, if the loss function requires extra information), or return extra arguments (for example, metrics for plotting the training progress). For models defined as a function, you do not need to pass a network as an input argument.

For example, this function returns the gradients and the cross entropy loss for the deep learning model function model with the specified learnable parameters parameters, input data dlX, and targets T.

function [gradients, loss] = modelGradients(parameters, dlX, T)

    % Forward data through the model function.
    dlY = model(parameters,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,parameters);

end
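
For context, a model function of the form dlY = model(parameters,dlX) might look like the following minimal sketch. The single fully connected layer, the field names fc.Weights and fc.Bias, and the input sizes are assumptions chosen for illustration; a real model would typically have more operations.

```matlab
% Hypothetical learnable parameters for one fully connected layer.
numClasses = 10;
numFeatures = 784;
parameters.fc.Weights = dlarray(0.01*randn([numClasses numFeatures],'single'));
parameters.fc.Bias = dlarray(zeros([numClasses 1],'single'));

% Synthetic input: 784 features (C) by 100 observations (B).
dlX = dlarray(rand([numFeatures 100],'single'),'CB');

% Forward the data through the model function.
dlY = model(parameters,dlX);

function dlY = model(parameters,dlX)
    % Fully connected operation followed by softmax.
    dlY = fullyconnect(dlX,parameters.fc.Weights,parameters.fc.Bias);
    dlY = softmax(dlY);
end
```

Because the learnable parameters are stored in a struct, the gradients returned by dlgradient(loss,parameters) have the same field layout as parameters.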

Evaluate Model Gradients Function

To evaluate the model gradients using automatic differentiation, use the dlfeval function, which evaluates a function with automatic differentiation enabled. For the first input of dlfeval, pass the model gradients function specified as a function handle. For the following inputs, pass the variables required by the model gradients function. For the outputs of the dlfeval function, specify the same outputs as the model gradients function.

For example, to evaluate the model gradients function modelGradients with a dlnetwork object dlnet, input data dlX, and targets T, and return the model gradients and loss, use the command:

[gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

Similarly, to evaluate the model gradients function modelGradients using a model function with learnable parameters specified by the struct parameters, input data dlX, and targets T, and return the model gradients and loss, use the command:

[gradients, loss] = dlfeval(@modelGradients,parameters,dlX,T);
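
To see the calling pattern in isolation, here is a small sketch that uses a toy function instead of a deep learning model. The function name sumSquares is hypothetical; it computes y = sum(x.^2) and the gradient of y with respect to x, which is 2*x.

```matlab
% Input must be a dlarray so that operations on it can be traced.
x = dlarray([1 2 3]);

% First input: function handle. Remaining inputs: arguments of the function.
% Outputs: the same outputs that the function returns.
[y,grad] = dlfeval(@sumSquares,x);

function [y,grad] = sumSquares(x)
    % Scalar value to differentiate.
    y = sum(x.^2,2);

    % Gradient of y with respect to x.
    grad = dlgradient(y,x);
end
```

Here y is 14 and grad is [2 4 6], both returned as dlarray objects.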

Update Learnable Parameters Using Gradients

To update the learnable parameters using the gradients, you can use the following functions:

Function        Description
adamupdate      Update parameters using adaptive moment estimation (Adam)
rmspropupdate   Update parameters using root mean squared propagation (RMSProp)
sgdmupdate      Update parameters using stochastic gradient descent with momentum (SGDM)
dlupdate        Update parameters using custom function

For example, to update the learnable parameters of a dlnetwork object dlnet using the adamupdate function, use the command:

[dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
    trailingAvg,trailingAvgSq,iteration);

where gradients is the output of the model gradients function, trailingAvg and trailingAvgSq are the moving averages of the gradients and squared gradients maintained by the adamupdate function (specify them as empty arrays [] for the first update), and iteration is the iteration number.

Similarly, to update the learnable parameters of a model function, specified by the struct parameters, using the adamupdate function, use the command:

[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAvgSq,iteration);

where gradients is the output of the model gradients function, trailingAvg and trailingAvgSq are the moving averages of the gradients and squared gradients maintained by the adamupdate function (specify them as empty arrays [] for the first update), and iteration is the iteration number.
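
The dlupdate function applies a custom update rule to each learnable parameter. The following sketch uses a plain gradient descent step as that rule. The parameter and gradient values, the field names, and the learning rate are made up for illustration; in a training loop the gradients would come from the model gradients function.

```matlab
% Hypothetical parameters and gradients with matching struct fields.
parameters.Weights = dlarray([1 2; 3 4]);
parameters.Bias = dlarray([0.5; 0.5]);
gradients.Weights = dlarray(ones(2));
gradients.Bias = dlarray([1; 1]);

% Custom update rule: one gradient descent step.
learnRate = 0.1;
sgdStep = @(p,g) p - learnRate*g;

% Apply the rule to every parameter and its corresponding gradient.
parameters = dlupdate(sgdStep,parameters,gradients);
```

After the update, each weight and bias has decreased by learnRate times its gradient, for example parameters.Bias is [0.4; 0.4].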

Use Model Gradients Function in Custom Training Loop

When training a deep learning model using a custom training loop, evaluate the model gradients and update the learnable parameters for each mini-batch.

This code snippet shows an example of using the dlfeval and adamupdate functions in a custom training loop.

iteration = 0;

% Initialize the Adam moving averages.
trailingAvg = [];
trailingAvgSq = [];

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model gradients.
        [gradients, loss] = dlfeval(@modelGradients,dlnet,dlX,T);

        % Update learnable parameters.
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

    end
end

For an example showing how to train a deep learning model with a dlnetwork object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.

Debugging Model Gradients Functions

If there is an issue in the implementation of the model gradients function, the call to dlfeval may throw an error. Sometimes, when using the dlfeval function, it is not clear which line of code is throwing the error. To help locate the error, you can try the following:

Call Model Gradients Function Directly

Try calling the model gradients function directly (that is, without using the dlfeval function) with generated inputs of the expected sizes. If a line of code throws an error, the error message identifies that line. Note that when you do not use the dlfeval function, any call to the dlgradient function is expected to throw an error, so an error on that line alone does not indicate a problem.

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(X,'SSCB');

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[gradients, loss] = modelGradients(dlnet,dlX,T);
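
One way to inspect where the direct call fails is to wrap it in a try/catch block and print the error report, as in this sketch. The small network here is an assumption used only to make the snippet self-contained; substitute your own dlnet and model gradients function.

```matlab
% Illustrative network; replace with the network you are debugging.
layers = [
    imageInputLayer([28 28 1],'Normalization','none','Name','input')
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')];
dlnet = dlnetwork(layerGraph(layers));

% Generated inputs of the expected sizes.
dlX = dlarray(rand([28 28 1 100],'single'),'SSCB');
T = repmat(eye(10,'single'),[1 10]);

try
    [gradients,loss] = modelGradients(dlnet,dlX,T);
catch err
    % Print the message and the call stack of the first error.
    disp(getReport(err))
end

function [gradients,loss] = modelGradients(dlnet,dlX,T)
    dlY = forward(dlnet,dlX);
    loss = crossentropy(dlY,T);
    % Expected to error when called outside dlfeval.
    gradients = dlgradient(loss,dlnet.Learnables);
end
```

If the report points to the dlgradient line, the forward pass and loss calculation ran without error for these inputs.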

Run Model Gradients Code Manually

Run the code inside the model gradients function manually with generated inputs of the expected sizes and inspect the output and any thrown error messages.

For example, to check the model gradients function defined by:

function [gradients, loss] = modelGradients(dlnet, dlX, T)

    % Forward data through the dlnetwork object.
    dlY = forward(dlnet,dlX);

    % Compute loss.
    loss = crossentropy(dlY,T);

    % Compute gradients.
    gradients = dlgradient(loss,dlnet.Learnables);

end

run the code:

% Generate image input data.
X = rand([28 28 1 100],'single');
dlX = dlarray(X,'SSCB');

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
dlY = forward(dlnet,dlX);

% Check loss calculation.
loss = crossentropy(dlY,T)
