neuralODELayer
Description
A neural ODE layer outputs the solution of an ODE.
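Creation
layer = neuralODELayer(net,tspan) creates a neural ODE layer and sets the Network and TimeInterval properties to net and tspan, respectively.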
Properties
Network — Neural network characterizing neural ODE function
dlnetwork object
Neural network characterizing the neural ODE function, specified as a dlnetwork object.
If Network has one input, then predict(net,Y) defines the ODE system, where net is the network. If Network has two inputs, then predict(net,T,Y) defines the ODE system, where T is a time step repeated over the batch dimension.
The size and format of the network inputs and outputs must match.
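As a language-neutral illustration of how a one-input or two-input network becomes the right-hand side of an ODE, the following Python sketch wraps a plain function that stands in for predict. The helper make_ode_function and the stand-in networks are hypothetical, a batch is modeled as a list of floats, and this is not how MATLAB implements the layer.

```python
from typing import Callable, List

Batch = List[float]

def make_ode_function(net: Callable, num_inputs: int) -> Callable[[float, Batch], Batch]:
    """Build f(t, Y) from a stand-in for predict (hypothetical helper)."""
    if num_inputs == 1:
        # Autonomous system: the network ignores t, like predict(net,Y).
        call = lambda t, Y: net(Y)
    elif num_inputs == 2:
        # Time-dependent system: t is repeated over the batch, like predict(net,T,Y).
        call = lambda t, Y: net([t] * len(Y), Y)
    else:
        raise ValueError("the network must have one or two inputs")

    def f(t: float, Y: Batch) -> Batch:
        dY = call(t, Y)
        # The network output must have the same size as its input state.
        if len(dY) != len(Y):
            raise ValueError("network output size must match its input size")
        return dY

    return f
```

For example, make_ode_function(lambda Y: [-y for y in Y], 1) describes exponential decay of every batch entry, while a two-input stand-in receives the current time once per batch entry.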
When GradientMode is "adjoint", the network State property must be empty. To use a network with a nonempty State property, set GradientMode to "direct".
TimeInterval — Interval of integration
numeric vector
Interval of integration, specified as a numeric vector with two or more elements. The elements in TimeInterval must be all increasing or all decreasing.
The solver imposes the initial conditions given by Y0 (the layer input) at the initial time TimeInterval(1), then integrates the ODE function from TimeInterval(1) to TimeInterval(end).
If TimeInterval has two elements, [t0 tf], then the solver returns the solution evaluated at point tf.
If TimeInterval has more than two elements, [t0 t1 ... tf], then the solver returns the solution evaluated at the given points [t1 ... tf]. The solver does not step precisely to each point specified in TimeInterval. Instead, the solver uses its own internal steps to compute the solution, then evaluates the solution at the points specified in TimeInterval. The solutions produced at the specified points are of the same order of accuracy as the solutions computed at each internal step.
Specifying several intermediate points has little effect on the efficiency of computation, but for large systems it can increase memory usage.
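To illustrate what "internal steps, then evaluation at the requested points" means, here is a minimal Python sketch for a scalar ODE. For simplicity it swaps in fixed-step classical Runge-Kutta steps and a cubic Hermite interpolant; ode45 instead takes adaptive Dormand-Prince steps and uses its own interpolant, so solve_at and its h_internal parameter are hypothetical.

```python
import math
from typing import Callable, List, Sequence

def rk4_step(f: Callable[[float, float], float], t: float, y: float, h: float) -> float:
    """One classical Runge-Kutta step (h may be negative)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def hermite(t0, y0, d0, t1, y1, d1, t):
    """Cubic Hermite interpolant on [t0, t1] built from values and slopes."""
    h = t1 - t0
    s = (t - t0) / h
    return ((2 * s**3 - 3 * s**2 + 1) * y0 + (s**3 - 2 * s**2 + s) * h * d0
            + (-2 * s**3 + 3 * s**2) * y1 + (s**3 - s**2) * h * d1)

def solve_at(f, tspan: Sequence[float], y0: float, h_internal: float = 0.1) -> List[float]:
    """Return y at tspan[1:]: [tf] for [t0 tf], and [t1 ... tf] otherwise."""
    if len(tspan) < 2:
        raise ValueError("tspan must have two or more elements")
    steps = [b - a for a, b in zip(tspan, tspan[1:])]
    if not (all(d > 0 for d in steps) or all(d < 0 for d in steps)):
        raise ValueError("tspan must be all increasing or all decreasing")
    t0, tf = tspan[0], tspan[-1]
    # Internal grid chosen by the solver, independent of the requested points.
    n = max(1, math.ceil(abs(tf - t0) / h_internal))
    h = (tf - t0) / n
    ts = [t0 + i * h for i in range(n + 1)]
    ys = [y0]
    for i in range(n):
        ys.append(rk4_step(f, ts[i], ys[i], h))
    ds = [f(t, y) for t, y in zip(ts, ys)]
    out = []
    for t in tspan[1:]:
        # Locate the internal step that contains t, then interpolate inside it.
        i = min(int((t - t0) / h), n - 1)
        out.append(hermite(ts[i], ys[i], ds[i], ts[i + 1], ys[i + 1], ds[i + 1], t))
    return out
```

Requesting more points only adds interpolation work and output storage; the internal steps stay the same, which mirrors the efficiency remark above.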
GradientMode — Method to compute gradients
"direct" (default) | "adjoint"
Method to compute gradients with respect to the initial conditions and parameters when using the dlgradient function, specified as one of these values:
"direct" — Compute gradients by backpropagating through the operations undertaken by the numerical solver. This option is best suited to large mini-batch sizes or to a TimeInterval that contains many values.
"adjoint" — Compute gradients by solving the associated adjoint ODE system. This option is best suited to small mini-batch sizes or to a TimeInterval that contains a small number of values.
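The difference between the two modes can be seen on a toy problem. This Python sketch uses the scalar ODE y' = θy with loss L = y(T) and, as a simplifying assumption, forward Euler steps instead of the Dormand-Prince solver. direct_gradient backpropagates through every solver step and therefore stores the whole trajectory, while adjoint_gradient integrates y and the adjoint a = dL/dy backward in time and keeps only a few scalars. Both function names are hypothetical.

```python
import math

def direct_gradient(theta: float, y0: float, T: float, n: int):
    """Backpropagate through each forward-Euler step; returns (dL/dtheta, states stored)."""
    h = T / n
    ys = [y0]
    for _ in range(n):
        ys.append(ys[-1] + h * theta * ys[-1])   # y_{k+1} = y_k + h*theta*y_k
    a, grad = 1.0, 0.0                           # a = dL/dy_{k+1}, with L = y_n
    for k in range(n - 1, -1, -1):
        grad += a * h * ys[k]                    # d y_{k+1} / d theta = h*y_k
        a *= 1.0 + h * theta                     # d y_{k+1} / d y_k
    return grad, len(ys)

def adjoint_gradient(theta: float, y0: float, T: float, n: int) -> float:
    """Solve the adjoint ODE backward in time, recomputing y instead of storing it."""
    h = T / n
    y = y0
    for _ in range(n):
        y += h * theta * y                       # forward pass keeps only the final state
    a, grad = 1.0, 0.0                           # a(T) = dL/dy(T) = 1
    for _ in range(n):
        grad += h * a * y                        # dL/dtheta = integral of a * (df/dtheta = y)
        y -= h * theta * y                       # step y backward from T toward 0
        a += h * theta * a                       # a' = -theta*a, stepped backward in time
    return grad
```

For θ = 0.5, y0 = 1, and T = 1, both results approach the exact gradient T·y0·exp(θT) as the number of steps grows; the direct version holds n + 1 states in memory, the adjoint version does not.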
When GradientMode is "adjoint", the network State property must be empty. To use a network with a nonempty State property, set GradientMode to "direct".
The dlaccelerate function does not support accelerating the dlode45 function when the GradientMode option is "direct". To accelerate code that calls the dlode45 function, set the GradientMode option to "adjoint", or accelerate only the parts of your code that do not call the dlode45 function with the GradientMode option set to "direct".
The dlaccelerate function does not support accelerating networks that contain NeuralODELayer objects when the GradientMode option is "direct". To accelerate networks that contain NeuralODELayer objects, set the GradientMode option to "adjoint".
Warning
When GradientMode is "adjoint", all layers in the network must support acceleration. Otherwise, the software can return unexpected results.
When GradientMode is "adjoint", the software traces the ODE function input to determine the computation graph used for automatic differentiation. This tracing process can take some time and can end up recomputing the same trace. By optimizing, caching, and reusing the traces, the software can speed up the gradient computation.
For more information on deep learning function acceleration, see Deep Learning Function Acceleration for Custom Training Loops.
The NeuralODELayer object stores this property as a character vector.
RelativeTolerance — Relative error tolerance
1e-3 (default) | positive scalar
Relative error tolerance, specified as a positive scalar. The relative tolerance applies to all components of the solution.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
AbsoluteTolerance — Absolute error tolerance
1e-6 (default) | positive scalar
Absolute error tolerance, specified as a positive scalar. The absolute tolerance applies to all components of the solution.
Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
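The two tolerances are usually combined into a single per-component test. The MATLAB odeset documentation describes the ode45 criterion as accepting a step when the estimated local error e(i) satisfies |e(i)| <= max(RelTol*abs(y(i)),AbsTol); assuming the layer's solver follows the same convention, this Python sketch shows the test with the default values above. step_acceptable is a hypothetical name.

```python
from typing import Sequence

def step_acceptable(err: Sequence[float], y: Sequence[float],
                    rtol: float = 1e-3, atol: float = 1e-6) -> bool:
    """Mixed error test applied to every component of the solution (assumed convention)."""
    # The relative term dominates for large |y|; the absolute term takes over
    # when |y| is smaller than atol/rtol (1e-3 with the defaults).
    return all(abs(e) <= max(rtol * abs(v), atol) for e, v in zip(err, y))
```

A single component that fails the test rejects the whole step, which is why the tolerances apply to all components of the solution.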
Examples
Create Neural ODE Layer
Create a neural ODE layer. Specify an ODE network containing a convolution layer followed by a tanh layer. Specify a time interval of [0, 1].
inputSize = [14 14 8];
layersODE = [
imageInputLayer(inputSize)
convolution2dLayer(3,8,Padding="same")
tanhLayer];
netODE = dlnetwork(layersODE);
tspan = [0 1];
layer = neuralODELayer(netODE,tspan)
layer = 
  NeuralODELayer with properties:

                 Name: ''
         TimeInterval: [0 1]
         GradientMode: 'direct'
    RelativeTolerance: 1.0000e-03
    AbsoluteTolerance: 1.0000e-06

   Learnable Parameters
              Network: [1x1 dlnetwork]

   State Parameters
    No properties.

  Use properties method to see a list of all properties.
Create a neural network containing a neural ODE layer.
layers = [
imageInputLayer([28 28 1])
convolution2dLayer([3 3],8,Padding="same",Stride=2)
reluLayer
neuralODELayer(netODE,tspan)
fullyConnectedLayer(10)
softmaxLayer];
net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [6x1 nnet.cnn.layer.Layer]
    Connections: [5x2 table]
     Learnables: [6x3 table]
          State: [0x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.
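The example works because the sizes line up: the strided convolution in the outer network maps a 28-by-28-by-1 image to the 14-by-14-by-8 input that the ODE network expects, and the ODE network returns data of that same size, as the Network property requires. This Python sketch checks that arithmetic, assuming the convolution2dLayer rule that "same" padding produces a spatial output size of ceil(inputSize/stride). same_conv_output_size is a hypothetical helper.

```python
import math
from typing import List

def same_conv_output_size(input_size: List[int], num_filters: int, stride: int = 1) -> List[int]:
    """[H W C] output of a 2-D convolution with "same" padding (assumed rule)."""
    height, width = input_size[0], input_size[1]
    return [math.ceil(height / stride), math.ceil(width / stride), num_filters]

# Outer network: imageInputLayer([28 28 1]) -> convolution2dLayer([3 3],8,Padding="same",Stride=2)
ode_input = same_conv_output_size([28, 28, 1], 8, stride=2)
# ODE network: convolution2dLayer(3,8,Padding="same") -> tanhLayer (tanh keeps the size)
ode_output = same_conv_output_size(ode_input, 8)
print(ode_input, ode_output)
```

Both printed sizes are [14, 14, 8], matching inputSize in the first example.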
Tips
To apply the neural ODE operation in deep learning models defined as functions or in custom layer functions, use dlode45.
Algorithms
Neural Ordinary Differential Equation
The neural ordinary differential equation (ODE) operation returns the solution of a specified ODE. In particular, given an input, a neural ODE operation outputs the numerical solution of the ODE $y' = f(t, y, \theta)$ for the time horizon $(t_0, t_1)$ and with the initial condition $y(t_0) = y_0$, where $t$ and $y$ denote the ODE function inputs and $\theta$ is a set of learnable parameters. Typically, the initial condition $y_0$ is either the network input or the output of another deep learning operation.
To apply the operation, NeuralODELayer uses the ode45 function, which is based on an explicit Runge-Kutta (4,5) formula, the Dormand-Prince pair. It is a single-step solver: in computing $y(t_n)$, it needs only the solution at the immediately preceding time point, $y(t_{n-1})$ [2] [3].
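For reference, this Python sketch implements a single-step adaptive Dormand-Prince (4,5) integrator using the published Butcher tableau. It propagates the fifth-order solution, uses its difference from the embedded fourth-order solution as the error estimate, and applies a mixed relative/absolute error test. The step-size controller (safety factor 0.9, step changes limited to between 0.2 and 5 times) is a common textbook choice and an assumption here, not necessarily what ode45 does; the sketch also handles only increasing time.

```python
import math
from typing import Callable, List

C = [0.0, 1/5, 3/10, 4/5, 8/9, 1.0, 1.0]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84],
]
B5 = [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0]               # 5th order
B4 = [5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]  # 4th order

def dp45_step(f: Callable, t: float, y: List[float], h: float):
    """One Dormand-Prince step; returns the 5th-order solution and an error estimate."""
    n = len(y)
    k = []  # stage derivatives k_1 ... k_7 (each only needs earlier stages of this step)
    for i in range(7):
        yi = [y[j] + h * sum(A[i][m] * k[m][j] for m in range(i)) for j in range(n)]
        k.append(f(t + C[i] * h, yi))
    y5 = [y[j] + h * sum(B5[m] * k[m][j] for m in range(7)) for j in range(n)]
    y4 = [y[j] + h * sum(B4[m] * k[m][j] for m in range(7)) for j in range(n)]
    return y5, [a - b for a, b in zip(y5, y4)]

def dp45(f: Callable, t0: float, tf: float, y0: List[float],
         rtol: float = 1e-3, atol: float = 1e-6, h: float = 0.1) -> List[float]:
    """Integrate y' = f(t, y) from t0 to tf (tf > t0) and return y(tf)."""
    t, y = t0, list(y0)
    while t < tf:
        h = min(h, tf - t)                       # do not step past tf
        y_new, err = dp45_step(f, t, y, h)
        # Largest error relative to the mixed tolerance; <= 1 means acceptable.
        ratio = max(abs(e) / max(rtol * max(abs(a), abs(b)), atol)
                    for e, a, b in zip(err, y, y_new))
        if ratio <= 1.0:
            t, y = t + h, y_new                  # accept: only y(t_n) is carried forward
        # Grow or shrink h; the 1/5 exponent matches the O(h^5) local error estimate.
        factor = 5.0 if ratio == 0.0 else min(5.0, max(0.2, 0.9 * ratio ** -0.2))
        h *= factor
    return y
```

Each step uses only the current state, which is the single-step property described above. For example, dp45(lambda t, y: [-y[0]], 0.0, 1.0, [1.0]) approximates exp(-1).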
Layer Input and Output Formats
Layers in a layer array or layer graph pass data to subsequent layers as formatted dlarray objects.
The format of a dlarray object is a string of characters in which each character describes the corresponding dimension of the data. The formats consist of one or more of these characters:
"S" — Spatial
"C" — Channel
"B" — Batch
"T" — Time
"U" — Unspecified
For example, you can describe 2-D image data that is represented as a 4-D array, where the first two dimensions correspond to the spatial dimensions of the images, the third dimension corresponds to the channels of the images, and the fourth dimension corresponds to the batch dimension, as having the format "SSCB" (spatial, spatial, channel, batch).
You can interact with these dlarray objects in automatic differentiation workflows, such as those for developing a custom layer, using a functionLayer object, or using the forward and predict functions with dlnetwork objects.
This table shows the supported input formats of NeuralODELayer objects and the corresponding output format. If the software passes the output of the layer to a custom layer that does not inherit from the nnet.layer.Formattable class, or a FunctionLayer object with the Formattable property set to 0 (false), then the layer receives an unformatted dlarray object with dimensions ordered according to the formats in this table. The formats listed here are only a subset. The layer may support additional formats such as formats with additional "S" (spatial) or "U" (unspecified) dimensions.
If TimeInterval contains more than two elements, then the layer outputs data with a "T" (time) dimension.
Input Format | TimeInterval | Output Format |
---|---|---|
"SB" (spatial, batch) | [t0 tf] | "SB" (spatial, batch) |
"SB" (spatial, batch) | [t0 t1 ... tf] | "SBT" (spatial, batch, time) |
"SSB" (spatial, spatial, batch) | [t0 tf] | "SSB" (spatial, spatial, batch) |
"SSB" (spatial, spatial, batch) | [t0 t1 ... tf] | "SSBT" (spatial, spatial, batch, time) |
"SSSB" (spatial, spatial, spatial, batch) | [t0 tf] | "SSSB" (spatial, spatial, spatial, batch) |
"SSSB" (spatial, spatial, spatial, batch) | [t0 t1 ... tf] | "SSSBT" (spatial, spatial, spatial, batch, time) |
"SS" (spatial, spatial) | [t0 tf] | "SS" (spatial, spatial) |
"SS" (spatial, spatial) | [t0 t1 ... tf] | "SST" (spatial, spatial, time) |
"SSS" (spatial, spatial, spatial) | [t0 tf] | "SSS" (spatial, spatial, spatial) |
"SSS" (spatial, spatial, spatial) | [t0 t1 ... tf] | "SSST" (spatial, spatial, spatial, time) |
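For the rows listed in the table, the rule reduces to one line: the layer keeps the input format and appends a "T" dimension only when TimeInterval has more than two elements. This Python sketch expresses that rule; output_format and describe are hypothetical helpers, and rejecting inputs that already contain "T" is an assumption, since the table lists no such formats.

```python
from typing import List, Sequence

DIMENSION_LABELS = {"S": "spatial", "C": "channel", "B": "batch", "T": "time", "U": "unspecified"}

def describe(fmt: str) -> List[str]:
    """Spell out a dlarray-style format string, e.g. "SSB" -> spatial, spatial, batch."""
    return [DIMENSION_LABELS[ch] for ch in fmt]

def output_format(input_format: str, time_interval: Sequence[float]) -> str:
    """Output format for a given input format and TimeInterval (rows of the table above)."""
    if "T" in input_format:
        raise ValueError("this sketch only covers inputs without a time dimension")
    return input_format + ("T" if len(time_interval) > 2 else "")
```

For example, output_format("SSB", [0, 0.5, 1]) gives "SSBT", matching the corresponding table row.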
Version History
Introduced in R2023b