Deep Learning Function Acceleration for Custom Training Loops
When you use the dlfeval function in a custom training loop, the software traces each input dlarray object of the model loss function to determine the computation graph used for automatic differentiation. This tracing process can take some time and can spend time recomputing the same trace. By optimizing, caching, and reusing the traces, you can speed up gradient computation in deep learning functions. You can also optimize, cache, and reuse traces to accelerate other deep learning functions that do not require automatic differentiation, for example, model functions and functions used for prediction.
To speed up calls to deep learning functions, you can use the dlaccelerate
function to create an AcceleratedFunction
object that automatically optimizes, caches, and reuses
the traces. You can use the dlaccelerate
function to accelerate model functions and model loss functions directly.
The returned AcceleratedFunction
object caches the
traces of calls to the underlying function and reuses the cached
result when the same input pattern reoccurs.
Try using dlaccelerate for function calls that:
are long-running
have dlarray objects, structures of dlarray objects, or dlnetwork objects as inputs
do not have side effects like writing to files or displaying output
Invoke the accelerated function as you would invoke the underlying function. Note that the accelerated function is not a function handle.
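For example, this minimal sketch (assuming a model function named model and existing parameters and X variables) creates an accelerated function and invokes it directly:

accfun = dlaccelerate(@model);
Y = accfun(parameters,X);          % invoke the accelerated function like the underlying function
isa(accfun,"function_handle")      % returns 0 (false) because accfun is an AcceleratedFunction object, not a function handle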
Note
When using the dlfeval function, the software automatically accelerates the forward and predict functions for dlnetwork input. If you accelerate a deep learning function where the majority of the computation takes place in calls to the forward or predict functions for dlnetwork input, then you might not see an improvement in training time.
You can nest and recursively call accelerated functions. However, it is usually more efficient to have a single accelerated function.
Accelerate Deep Learning Function Directly
In most cases, you can accelerate deep learning functions directly. For example, you can accelerate the model loss function directly by replacing calls to the model loss function with calls to the corresponding accelerated function. Consider the following use of the dlfeval function in a custom training loop.
[loss,gradients,state] = dlfeval(@modelLoss,parameters,X,T,state)
To accelerate the model loss function, use the dlaccelerate function and evaluate the returned AcceleratedFunction object:

accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state)
Because the cached traces are not directly attached to the AcceleratedFunction object and are shared between AcceleratedFunction objects that use the same underlying function, you can create the AcceleratedFunction object either inside or before the custom training loop body.
Accelerate Parts of Deep Learning Function
If a deep learning function does not fully support acceleration, for example, a function that requires an if statement with a condition that depends on the value of a dlarray object, then you can accelerate parts of the function by creating a separate function that contains the supported function calls you want to accelerate.
For example, consider the following code snippet that calls different functions
depending on whether the sum of the dlarray
object X
is negative or
nonnegative.
if sum(X,"all") < 0
    Y = negFun1(parameters,X);
    Y = negFun2(parameters,Y);
else
    Y = posFun1(parameters,X);
    Y = posFun2(parameters,Y);
end
Because the if
statement depends on the value of a
dlarray
object, a function that contains this code snippet does not
support acceleration. However, if the blocks of code used inside the body of the
if
statement support acceleration, then you can accelerate these
parts separately by creating a new function containing those blocks and accelerating the
new functions instead.
For example, create the functions negFunAll
and
posFunAll
that contain the blocks of code used in the body of the
if
statement.
function Y = negFunAll(parameters,X)
    Y = negFun1(parameters,X);
    Y = negFun2(parameters,Y);
end

function Y = posFunAll(parameters,X)
    Y = posFun1(parameters,X);
    Y = posFun2(parameters,Y);
end
Accelerate these functions and use them in the body of the if statement instead.

accfunNeg = dlaccelerate(@negFunAll)
accfunPos = dlaccelerate(@posFunAll)

if sum(X,"all") < 0
    Y = accfunNeg(parameters,X);
else
    Y = accfunPos(parameters,X);
end
Reusing Caches
Reusing a cached trace depends on the function inputs and outputs:
For any dlarray object or structure of dlarray object inputs, the trace depends on the size, format, and underlying data type of the dlarray. That is, the accelerated function triggers a new trace for dlarray inputs with a size, format, or underlying data type not contained in the cache. Any dlarray inputs differing only by value from a previously cached trace do not trigger a new trace.
For any dlnetwork inputs, the trace depends on the size, format, and underlying data type of the dlnetwork state and learnable parameters. That is, the accelerated function triggers a new trace for dlnetwork inputs with learnable parameters or state with a size, format, or underlying data type not contained in the cache. Any dlnetwork inputs differing only by the value of the state and learnable parameters from a previously cached trace do not trigger a new trace.
For other types of input, the trace depends on the values of the input. That is, the accelerated function triggers a new trace for other types of input with a value not contained in the cache. Any other inputs that have the same value as a previously cached trace do not trigger a new trace.
The trace depends on the number of function outputs. That is, the accelerated function triggers a new trace for function calls with previously unseen numbers of output arguments. Any function calls with the same number of output arguments as a previously cached trace do not trigger a new trace.
When necessary, the software caches any new traces by evaluating the underlying function
and caching the resulting trace in the AcceleratedFunction
object.
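For example, this sketch (assuming a modelLoss function, existing parameters, T, and state variables, a model that accepts inputs of different sizes, and illustrative data sizes) shows which calls reuse a cached trace. The Occupancy and HitRate properties report cache usage as percentages.

accfun = dlaccelerate(@modelLoss);
clearCache(accfun)

X = dlarray(rand(28,28,1,128),"SSCB");
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);   % triggers a new trace

X = dlarray(rand(28,28,1,128),"SSCB");
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);   % same size, format, and data type: reuses the cached trace

X = dlarray(rand(14,14,1,128),"SSCB");
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);   % different size: triggers another new trace

accfun.Occupancy   % percentage of the cache in use
accfun.HitRate     % percentage of calls that reused a cached trace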
Caution
An AcceleratedFunction
object is not aware of updates to the underlying
function. If you modify the function associated with the accelerated function, then
clear the cache using the clearCache
object function or alternatively use the command
clear functions
.
Storing and Clearing Caches
AcceleratedFunction
objects store the cache in a queue:
The software adds new traces to the back of the queue.
When the cache is full, the software discards the cached item at the head of the queue.
When a cache is reused, the software moves the cached item towards the back of the queue. This helps prevent the software from discarding commonly reused cached items.
The AcceleratedFunction objects do not directly hold the cache. This means that:
Multiple AcceleratedFunction objects that have the same underlying function share the same cache.
Clearing or overwriting a variable containing an AcceleratedFunction object does not clear the cache.
Overwriting a variable containing an AcceleratedFunction object with another AcceleratedFunction object with the same underlying function does not clear the cache.
To clear the cache of an accelerated function, use the clearCache
object function. Alternatively, you can clear all functions in the current MATLAB® session using the commands clear functions
or
clear all
.
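For example, a brief sketch (assuming a modelLoss function as before):

accfun = dlaccelerate(@modelLoss);
% ... use accfun in the training loop ...
clearCache(accfun)   % clear the cached traces for this underlying function

% Alternatively, clear the caches of all accelerated functions in the session.
clear functions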
Note
Clearing the AcceleratedFunction variable does not clear the cache associated with the input function. To clear the cache for an AcceleratedFunction object that no longer exists in the workspace, create a new AcceleratedFunction object for the same function, and use the clearCache function on the new object. Alternatively, you can clear all functions in the current MATLAB session using the commands clear functions or clear all.
Acceleration Considerations
Because of the nature of caching traces, not all functions support acceleration.
The caching process can cache values that you might expect to change or that depend on external factors. You must take care when you accelerate functions that:
have inputs with random or frequently changing values
have outputs with frequently changing values
generate random numbers
use if statements and while loops with conditions that depend on the values of dlarray objects
have inputs that are handles or that depend on handles
read data from external sources (for example, by using a datastore or a minibatchqueue object)
Because the caching process requires extra computation, acceleration can lead to longer running code in some cases. This scenario can happen when the software spends time creating new caches that do not get reused often. For example, when you pass multiple mini-batches of different sequence lengths to the function, the software triggers a new trace for each unique sequence length.
Accelerated functions can do the following only when calculating a new trace:
modify the global state, such as the random number stream or global variables
use file input or output
display data using graphics or the command line display
When you use accelerated functions in parallel, such as inside a parfor loop, each worker maintains its own cache. The cache is not transferred to the host.
Functions and custom layers used in accelerated functions must also support acceleration.
Function Inputs with Random or Frequently Changing Values
You must take care when you accelerate functions that take random or frequently changing values as input, such as a model loss function that takes random noise as input and adds it to the input data. If any random or frequently changing inputs to an accelerated function are not dlarray objects, then the function triggers a new trace for each previously unseen value.
You can check for scenarios like this by inspecting the Occupancy
and HitRate
properties
of the AcceleratedFunction
object. If the Occupancy
property is high and the HitRate
is low, then this can indicate that the
AcceleratedFunction
object creates many new traces that it does
not reuse.
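For example, you might add a check like the following after several iterations (a rough sketch; the thresholds and the accfun variable name are illustrative assumptions):

% Warn if the cache is filling with traces that are rarely reused.
if accfun.Occupancy > 90 && accfun.HitRate < 10
    warning("Many cached traces are not being reused. " + ...
        "Check for non-dlarray inputs with frequently changing values.")
end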
For dlarray object input, changes in value do not trigger new traces. To prevent frequently changing input from triggering new traces for each evaluation, refactor your code such that the random inputs are dlarray inputs.
For example, consider the model loss function that accepts a random array of noise values:
function [loss,gradients,state] = modelLoss(parameters,X,T,state,noise)
    X = X + noise;
    [Y,state] = model(parameters,X,state);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,parameters);
end
To accelerate this model loss function, convert the input noise
to dlarray
before evaluating the accelerated function. Because the
modelLoss
function also supports dlarray
input for noise, you do not need to make changes to the
function.
noise = dlarray(noise,"SSCB");
accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state,noise);
Alternatively, you can accelerate the parts of the model loss function that do not require the random input.
Functions with Random Number Generation
You must take care when you accelerate functions that use random number generation, such as
functions that generate random noise to add to the input. When the software caches
the trace of a function that generates random numbers that are not
dlarray
objects, the software caches the resulting random
samples in the trace. When reusing the trace, the accelerated function uses the
cached random sample. The accelerated function does not generate new random
values.
Random number generation using the "like"
option of the rand
function with a dlarray
object supports acceleration. To use random number generation in an accelerated function, ensure that the function uses the rand
function with the "like"
option set to a traced dlarray
object (a dlarray
object that depends on an input dlarray
object).
For example, consider the following model loss function.
function [loss,gradients,state] = modelLoss(parameters,X,T,state)
    sz = size(X);
    noise = rand(sz);
    X = X + noise;
    [Y,state] = model(parameters,X,state);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,parameters);
end
To ensure that the rand
function generates a new value for
each evaluation, use the "like"
option with the traced
dlarray
object
X
.
function [loss,gradients,state] = modelLoss(parameters,X,T,state)
    sz = size(X);
    noise = rand(sz,"like",X);
    X = X + noise;
    [Y,state] = model(parameters,X,state);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,parameters);
end
Alternatively, you can accelerate the parts of the model loss function that do not require random number generation.
Using if Statements and while Loops
You must take care when you accelerate functions that use if
statements and
while
loops. In particular, you can get unexpected results when you
accelerate functions with if
statements or while
loops
that yield different code paths for function inputs of the same size and format.
Accelerating functions with if
statement or while
loop conditions that depend on the values of the function input or values from external
sources (for example, results of random number generation) can lead to unexpected behavior.
When the accelerated function caches a new trace, if the function contains an
if
statement or while
loop, then the software
caches the trace of the resulting code path given by the if
statement or
while
loop condition for that particular trace. Because changes in
the value of the dlarray
input do not trigger a new trace, when reusing the
trace with different values, the software uses the same cached trace (which contains the
same cached code path) even when a difference in value should result in a different code
path.
Usually, accelerating functions that contain if
statements or
while
loops with conditions that do not depend on the values of the
function input or external factors (for example, while
loops that iterate
over elements in an array) does not result in unexpected behavior. For example, because
changes in the size of a dlarray
input trigger a new trace, when reusing
the trace with inputs of the same size, the cached code path for inputs of that size remains consistent, even when there are differences in values.
To avoid unexpected behavior from caching code paths of if
statements,
you can refactor your code so that it determines the correct result by combining the results
of all branches and extracting the desired solution.
For example, consider this code.

if tf
    Y = funcA(X);
else
    Y = funcB(X);
end

You can combine the results of both branches and select the desired result using the condition value:

Y = tf*funcA(X) + ~tf*funcB(X);

Alternatively, you can concatenate the results of both branches and index into the concatenated result:

Y = cat(3,funcA(X),funcB(X));
Y = Y(:,:,[tf ~tf]);

Note that both refactored versions evaluate all branches of the if statement. To use if statements and while loops that depend on dlarray object values, accelerate the body of the if statement or while loop only.
Function Inputs that Depend on Handles
You must take care when you accelerate functions that take objects that depend on
handles as input, such as a minibatchqueue
object that has a
preprocessing function specified as a function handle. The
AcceleratedFunction
object throws an error when evaluating the
function with inputs depending on handles.
Instead, you can accelerate the parts of the model loss function that do not require inputs that depend on handles.
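For example, a sketch of this pattern (assuming a minibatchqueue object mbq whose preprocessing function is specified as a function handle, and a modelLoss function that takes only dlarray and structure inputs):

accfun = dlaccelerate(@modelLoss);

while hasdata(mbq)
    % The handle-dependent preprocessing runs here, outside the accelerated function.
    [X,T] = next(mbq);

    % The accelerated function receives only dlarray and structure inputs.
    [loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);
end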
Debugging
You must take care when you debug accelerated functions. Cached traces do not support breakpoints. When using accelerated functions, the software reaches breakpoints in the underlying function during the tracing process only.
To debug the code in the underlying function using breakpoints, disable the
acceleration by setting the Enabled
property to
false
.
To debug the cached traces, you can compare the outputs of the accelerated
functions with the outputs of the underlying function, by setting the CheckMode
property to
"tolerance"
.
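For example, a short sketch (assuming an accelerated function accfun created with dlaccelerate):

% Temporarily disable acceleration so that breakpoints in the underlying function are reached.
accfun.Enabled = false;

% Re-enable acceleration and compare accelerated outputs against the underlying function.
accfun.Enabled = true;
accfun.CheckMode = "tolerance";
accfun.CheckTolerance = 1e-4;   % comparison tolerance; the value shown is illustrative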
dlode45 Does Not Support Acceleration When GradientMode Is "direct"
The dlaccelerate
function does not support accelerating the
dlode45
function when the GradientMode
option is
"direct"
. To accelerate the code that calls the
dlode45
function, set the GradientMode
option to
"adjoint"
or accelerate parts of your code that do not call the
dlode45
function with the GradientMode
option
set to "direct"
.
See Also
dlaccelerate
| AcceleratedFunction
| clearCache
| dlarray
| dlgradient
| dlfeval