trainNetwork

Train neural network

Description

For classification and regression tasks, you can train various types of neural networks using the trainNetwork function.

For example, you can train:

  • a convolutional neural network (ConvNet, CNN) for image data

  • a recurrent neural network (RNN) such as a long short-term memory (LSTM) or a gated recurrent unit (GRU) neural network for sequence and time-series data

  • a multilayer perceptron (MLP) neural network for numeric feature data

You can train on either a CPU or a GPU. For image classification and image regression, you can train a single neural network in parallel using multiple GPUs or a local or remote parallel pool. Training on a GPU or in parallel requires Parallel Computing Toolbox™. To use a GPU for deep learning, you must also have a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify training options, including options for the execution environment, use the trainingOptions function.

When training a neural network, you can specify the predictors and responses as a single input or in two separate inputs.

net = trainNetwork(images,layers,options) trains the neural network specified by layers for image classification and regression tasks using the images and responses specified by images and the training options defined by options.

net = trainNetwork(images,responses,layers,options) trains using the images specified by images and responses specified by responses.

net = trainNetwork(sequences,layers,options) trains a neural network for sequence or time-series classification and regression tasks (for example, an LSTM or GRU neural network) using the sequences and responses specified by sequences.

net = trainNetwork(sequences,responses,layers,options) trains using the sequences specified by sequences and responses specified by responses.

net = trainNetwork(features,layers,options) trains a neural network for feature classification or regression tasks (for example, a multilayer perceptron (MLP) neural network) using the feature data and responses specified by features.

net = trainNetwork(features,responses,layers,options) trains using the feature data specified by features and responses specified by responses.

net = trainNetwork(mixed,layers,options) trains a neural network that has multiple inputs of mixed data types, using the data and responses specified by mixed.

[net,info] = trainNetwork(___) also returns information on the training using any of the previous syntaxes.
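
The difference between the single-input and two-input syntaxes can be sketched on toy feature data. This is a minimal sketch, not part of the original page: the data, the variable names, and the tiny network are hypothetical, and the code assumes Deep Learning Toolbox is installed.

```matlab
% Hypothetical toy data: 200 observations, 4 numeric features, 2 classes.
rng(0)
X = rand(200,4);
isHigh = X(:,1) + X(:,2) > 1;
labels = repmat("low",200,1);
labels(isHigh) = "high";
T = categorical(labels);

layers = [
    featureInputLayer(4)
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer];
options = trainingOptions('sgdm','MaxEpochs',5,'Verbose',false);

% Two separate inputs: predictors X and responses T.
% The second output returns information about the training run.
[net1,info1] = trainNetwork(X,T,layers,options);

% Single input: a table whose last column holds the responses.
tbl = array2table(X);
tbl.Label = T;
net2 = trainNetwork(tbl,layers,options);

% info1 is a structure; for classification it records, among other fields,
% the training loss at each iteration.
finalLoss = info1.TrainingLoss(end);
```

Both calls train the same architecture; only the packaging of the responses differs.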

Examples

Load the data as an ImageDatastore object.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

The datastore contains 10,000 synthetic images of digits from 0 to 9. The images are generated by applying random transformations to digit images created with different fonts. Each digit image is 28-by-28 pixels. The datastore contains an equal number of images per category.

Display some of the images in the datastore.

figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
    drawnow;
end

Figure: 20 axes objects, each containing an object of type image.

Divide the datastore so that each category in the training set has 750 images and the testing set has the remaining images from each label.

numTrainingFiles = 750;
[imdsTrain,imdsTest] = splitEachLabel(imds,numTrainingFiles,'randomize');

splitEachLabel splits the image files in imds into two new datastores, imdsTrain and imdsTest.
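
One way to confirm the split is to count the images per label in each datastore. The following is a minimal sketch that assumes the DigitDataset folder shipped with Deep Learning Toolbox is available. It recreates the datastore so that it runs on its own.

```matlab
% Recreate the datastore and the split so that this snippet is self-contained.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet', ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
[imdsTrain,imdsTest] = splitEachLabel(imds,750,'randomize');

% countEachLabel returns a table with the variables Label and Count.
trainCounts = countEachLabel(imdsTrain);
testCounts = countEachLabel(imdsTest);
disp(trainCounts)
disp(testCounts)
```

Each row of trainCounts should report 750 images, and the two tables together should account for every file in imds.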

Define the convolutional neural network architecture.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
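
To see how the image size changes through these layers, the standard output-size formula for convolution and pooling can be worked through by hand. This is an illustrative calculation, not output from the toolbox. It assumes the convolution layer uses a stride of 1 and no padding, because the call above specifies neither.

```matlab
% Output size along one spatial dimension of a convolution or pooling layer.
outSize = @(inSize,filterSize,stride,padding) ...
    floor((inSize + 2*padding - filterSize)/stride) + 1;

convOut = outSize(28,5,1,0);        % convolution2dLayer(5,20): 28 -> 24
poolOut = outSize(convOut,2,2,0);   % maxPooling2dLayer(2,'Stride',2): 24 -> 12

numFilters = 20;
numInputsToFC = poolOut^2*numFilters;   % 12*12*20 = 2880 values per image

fprintf('conv: %d-by-%d-by-%d, pool: %d-by-%d-by-%d, FC inputs: %d\n', ...
    convOut,convOut,numFilters,poolOut,poolOut,numFilters,numInputsToFC);
```

Under these assumptions, the fully connected layer maps 2880 values to the 10 class scores.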

Set the options to the default settings for stochastic gradient descent with momentum. Set the maximum number of epochs to 20, and start the training with an initial learning rate of 0.0001.

options = trainingOptions('sgdm', ...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
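
The number of training iterations implied by these options can be estimated with a short calculation. This sketch assumes the default MiniBatchSize of 128, because the options above do not set it, and assumes that observations that do not fill a complete mini-batch are dropped in each epoch.

```matlab
numTrainingImages = 750*10;   % 750 images for each of the 10 digit classes
miniBatchSize = 128;          % assumed default, not set in the options above
maxEpochs = 20;

% Only complete mini-batches count toward an epoch under the stated assumption.
iterationsPerEpoch = floor(numTrainingImages/miniBatchSize);
totalIterations = iterationsPerEpoch*maxEpochs;

fprintf('%d iterations per epoch, %d iterations in total\n', ...
    iterationsPerEpoch,totalIterations);
```

With these numbers the estimate is 58 iterations per epoch and 1160 iterations overall.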

Train the network.

net = trainNetwork(imdsTrain,layers,options);

Figure: Training Progress (19-Aug-2023 11:53:04). Axes 1 plots Loss against Iteration; axes 2 plots Accuracy (%) against Iteration.

Run the trained network on the test set, which was not used to train the network, and predict the image labels (digits).

YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

Calculate the accuracy. The accuracy is the fraction of test images for which the label predicted by classify matches the true label.

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9400
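
A single accuracy figure can hide classes that the network handles poorly. As an illustration, the same comparison can be broken down per class. The label vectors below are hypothetical stand-ins for YTest and YPred, chosen small enough to check by hand.

```matlab
% Hypothetical true and predicted labels for six test images.
YTest = categorical(["0"; "1"; "1"; "2"; "2"; "2"]);
YPred = categorical(["0"; "1"; "2"; "2"; "2"; "0"]);

% Overall accuracy: 4 of the 6 predictions match.
accuracy = sum(YPred == YTest)/numel(YTest);

% Per-class accuracy: fraction of images of each true class predicted correctly.
classes = categories(YTest);
perClass = zeros(numel(classes),1);
for k = 1:numel(classes)
    isClass = YTest == classes{k};
    perClass(k) = mean(YPred(isClass) == classes{k});
end

disp(table(classes,perClass))
```

Here class "0" is always correct, class "1" is correct half of the time, and class "2" is correct two times out of three.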

Train a convolutional neural network using augmented image data. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

Load the sample data, which consists of synthetic images of handwritten digits.

[XTrain,YTrain] = digitTrain4DArrayData;

digitTrain4DArrayData loads the digit training set as 4-D array data. XTrain is a 28-by-28-by-1-by-5000 array, where:

  • 28 is the height and width of the images.

  • 1 is the number of channels.

  • 5000 is the number of synthetic images of handwritten digits.

YTrain is a categorical vector containing the labels for each observation.

Set aside 1000 of the images for network validation.

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Create an imageDataAugmenter object that specifies preprocessing options for image augmentation, such as resizing, rotation, translation, and reflection. Randomly translate the images up to three pixels horizontally and vertically, and rotate the images with an angle up to 20 degrees.

imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3])
imageAugmenter = 
  imageDataAugmenter with properties:

           FillValue: 0
     RandXReflection: 0
     RandYReflection: 0
        RandRotation: [-20 20]
           RandScale: [1 1]
          RandXScale: [1 1]
          RandYScale: [1 1]
          RandXShear: [0 0]
          RandYShear: [0 0]
    RandXTranslation: [-3 3]
    RandYTranslation: [-3 3]

Create an augmentedImageDatastore object to use for network training and specify the image output size. During training, the datastore performs image augmentation and resizes the images. The datastore augments the images without saving any images to memory. trainNetwork updates the network parameters and then discards the augmented images.

imageSize = [28 28 1];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);
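
Because the augmented images are never stored, it can help to inspect a few of them before training. The following minimal sketch uses the preview function of the datastore. It assumes Deep Learning Toolbox is installed and recreates the augmenter on the full training array so that it runs on its own.

```matlab
% Recreate the augmenter and the datastore on the full training array.
[XTrain,YTrain] = digitTrain4DArrayData;
imageAugmenter = imageDataAugmenter( ...
    'RandRotation',[-20,20], ...
    'RandXTranslation',[-3 3], ...
    'RandYTranslation',[-3 3]);
augimds = augmentedImageDatastore([28 28 1],XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter);

% preview returns a table; the variable "input" holds the augmented images.
minibatch = preview(augimds);

% Tile the previewed images into one figure.
figure
imshow(imtile(minibatch.input))
```

Calling preview again typically shows different rotations and shifts, because the transformations are drawn at random on every read.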

Specify the convolutional neural network architecture.

layers = [
    imageInputLayer(imageSize)
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Specify training options for stochastic gradient descent with momentum.

opts = trainingOptions('sgdm', ...
    'MaxEpochs',15, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',{XValidation,YValidation});

Train the network. Because the validation images are not augmented, the validation accuracy is typically higher than the training accuracy.

net = trainNetwork(augimds,layers,opts);

Load the sample data, which consists of synthetic images of handwritten digits. The third output contains the corresponding angles in degrees by which each image has been rotated.

Load the training images as 4-D arrays using digitTrain4DArrayData. The output XTrain is a 28-by-28-by-1-by-5000 array, where:

  • 28 is the height and width of the images.

  • 1 is the number of channels.

  • 5000 is the number of synthetic images of handwritten digits.

YTrain contains the rotation angles in degrees.

[XTrain,~,YTrain] = digitTrain4DArrayData;

Display 20 random training images using imshow.

figure
numTrainImages = numel(YTrain);
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    drawnow;
end

Figure: 20 axes objects, each containing an object of type image.

Specify the convolutional neural network architecture. For regression problems, include a regression layer at the end of the network.

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

Specify the network training options. Set the initial learn rate to 0.001.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'Plots','training-progress');

Train the network.

net = trainNetwork(XTrain,YTrain,layers,options);

Figure: Training Progress (19-Aug-2023 11:41:24). Axes 1 plots Loss against Iteration; axes 2 plots RMSE against Iteration.

Test the performance of the network by evaluating the prediction accuracy on the test data. Use predict to predict the angles of rotation of the test images.

[XTest,~,YTest] = digitTest4DArrayData;
YPred = predict(net,XTest);

Evaluate the performance of the model by calculating the root-mean-square error (RMSE) of the predicted and actual angles of rotation.

rmse = sqrt(mean((YTest - YPred).^2))
rmse = single
    6.0516
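
The RMSE calculation can be checked by hand on a few values. The angles below are hypothetical, not taken from the digit data. The sketch also reports the fraction of predictions within a tolerance, which is an additional, assumed way to summarize regression error.

```matlab
% Hypothetical true and predicted rotation angles in degrees.
YTest = single([10; -5; 20; 0]);
YPred = single([12; -9; 17; 1]);

predictionError = YTest - YPred;          % -2, 4, 3, -1

% Root-mean-square error: sqrt((4 + 16 + 9 + 1)/4) = sqrt(7.5).
rmse = sqrt(mean(predictionError.^2));

% Fraction of predictions within an assumed tolerance of 3 degrees.
threshold = 3;
fractionWithin = mean(abs(predictionError) <= threshold);

fprintf('RMSE = %.4f, within %d degrees: %.2f\n',rmse,threshold,fractionWithin);
```

Because the inputs are single precision, rmse is also returned as single, which matches the class shown in the output above.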

Train a deep learning LSTM network for sequence-to-label classification.

Load the example data from WaveformData.mat. The data is a numObservations-by-1 cell array of sequences, where numObservations is the number of sequences. Each sequence is a numChannels-by-numTimeSteps numeric array, where numChannels is the number of channels of the sequence and numTimeSteps is the number of time steps of the sequence.

load WaveformData
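
To make the expected layout concrete, the same structure can be built from synthetic values. This sketch does not use WaveformData; the sequence lengths and class names are hypothetical.

```matlab
% Three hypothetical sequences with 3 channels and different lengths.
numChannels = 3;
numTimeSteps = [50; 80; 65];

data = cell(numel(numTimeSteps),1);    % numObservations-by-1 cell array
for i = 1:numel(data)
    % Each sequence is numChannels-by-numTimeSteps.
    data{i} = randn(numChannels,numTimeSteps(i));
end
labels = categorical(["Sine"; "Square"; "Sine"]);   % one label per sequence

% Inspect the layout.
sequenceLengths = cellfun(@(x) size(x,2),data);
channelCounts = cellfun(@(x) size(x,1),data);
disp(table(channelCounts,sequenceLengths,labels))
```

The number of rows is the same for every sequence, while the number of columns can vary from one sequence to the next.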

Visualize some of the sequences in a plot.

numChannels = size(data{1},1);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)}',DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

Set aside data for testing. Partition the data into a training set containing 90% of the data and a test set containing the remaining 10% of the data. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script.

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations, [0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);
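
The trainingPartitions supporting file is not reproduced on this page. If it is not available, a random 90/10 split can be approximated with randperm. The following is a hypothetical stand-in and not the supporting function's actual implementation; the observation count is also an assumed value.

```matlab
% Hypothetical stand-in for a 90/10 random partition of observation indices.
numObservations = 1000;               % assumed number of sequences
idx = randperm(numObservations);      % random ordering of 1..numObservations

numTrain = floor(0.9*numObservations);
idxTrain = idx(1:numTrain);
idxTest = idx(numTrain+1:end);

fprintf('%d training and %d test observations\n', ...
    numel(idxTrain),numel(idxTest));
```

The two index sets are disjoint and together cover every observation, so they can be used to index data and labels as in the code above.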

Define the LSTM network architecture. Specify the input size as the number of channels of the input data. Specify an LSTM layer to have 120 hidden units and to output the last element of the sequence. Finally, include a fully connected layer with an output size that matches the number of classes, followed by a softmax layer and a classification layer.

numHiddenUnits = 120;
numClasses = numel(categories(TTrain));

layers = [ ...
    sequenceInputLayer(numChannels)
    lstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5×1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 3 dimensions
     2   ''   LSTM                    LSTM with 120 hidden units
     3   ''   Fully Connected         4 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex
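
The size of this network can be estimated by counting learnable parameters. The calculation below is a sketch that assumes the usual LSTM parameter layout, with four gates that each have input weights, recurrent weights, and a bias. It takes the input size of 3 and the 4 classes from the layer display above.

```matlab
numChannels = 3;        % from "Sequence input with 3 dimensions"
numHiddenUnits = 120;
numClasses = 4;         % from "4 fully connected layer"

% LSTM layer: four gates, each with input weights, recurrent weights, and a bias.
lstmInputWeights = 4*numHiddenUnits*numChannels;          % 1440
lstmRecurrentWeights = 4*numHiddenUnits*numHiddenUnits;   % 57600
lstmBias = 4*numHiddenUnits;                              % 480
numLstmParams = lstmInputWeights + lstmRecurrentWeights + lstmBias;

% Fully connected layer: one weight per hidden unit per class, plus biases.
numFcParams = numClasses*numHiddenUnits + numClasses;     % 484

numLearnables = numLstmParams + numFcParams;
fprintf('LSTM: %d, fully connected: %d, total: %d\n', ...
    numLstmParams,numFcParams,numLearnables);
```

Almost all of the parameters sit in the recurrent weights, so the number of hidden units dominates the model size.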

Specify the training options. Train using the Adam solver with a learn rate of 0.01 and a gradient threshold of 1. Set the maximum number of epochs to 150 and shuffle every epoch. The software, by default, trains on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

options = trainingOptions("adam", ...
    MaxEpochs=150, ...
    InitialLearnRate=0.01,...
    Shuffle="every-epoch", ...
    GradientThreshold=1, ...
    Verbose=false, ...
    Plots="training-progress");
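
The effect of the gradient threshold can be illustrated on a single gradient vector. This sketch assumes the default threshold method, L2-norm clipping, in which a gradient whose L2 norm exceeds the threshold is rescaled so that its norm equals the threshold. The gradient values are hypothetical.

```matlab
gradientThreshold = 1;
g = [3; 4];                  % hypothetical gradient of one learnable parameter
gNorm = norm(g);             % L2 norm, here 5

% Rescale only when the norm exceeds the threshold; the direction is unchanged.
if gNorm > gradientThreshold
    g = g*(gradientThreshold/gNorm);
end

disp(g')                     % clipped gradient
disp(norm(g))                % norm after clipping
```

The clipped gradient points in the same direction as the original but has norm 1, which limits the size of any single update.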

Train the LSTM network with the specified training options.

net = trainNetwork(XTrain,TTrain,layers,options);

Classify the test data. Because neither the training options nor the call to classify sets a mini-batch size, both use the default value.

YTest = classify(net,XTest);

Calculate the classification accuracy of the predictions.

acc = mean(YTest == TTest)
acc = 0.8400

Display the classification results in a confusion chart.

figure
confusionchart(TTest,YTest)

If you have a data set of numeric features (for example, a collection of numeric data without spatial or time dimensions), then you can train a deep learning network using a feature input layer.

Read the transmission casing data from the CSV file "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,'TextType','String');

Convert the labels for prediction to categorical using the convertvars function.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,'categorical');

To train a network using categorical features, you must first convert the categorical features to numeric. As a first step, convert the categorical predictors to the categorical data type using the convertvars function, specifying a string array containing the names of all the categorical input variables. In this data set, there are two categorical features with names "SensorCondition" and "ShaftCondition".

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,'categorical');

Loop over the categorical input variables. For each variable:

  • Convert the categorical values to one-hot encoded vectors using the onehotencode function.

  • Add the one-hot vectors to the table using the addvars function. Specify to insert the vectors after the column containing the corresponding categorical data.

  • Remove the corresponding column containing the categorical data.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,'After',name);
    tbl(:,name) = [];
end
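
To see what the encoding step produces, onehotencode can be applied to a small categorical vector. This is a minimal sketch with hypothetical values that mirror the sensor condition names. It assumes Deep Learning Toolbox is installed.

```matlab
% Hypothetical categorical column with two categories.
condition = categorical(["Sensor Drift"; "No Sensor Drift"; "Sensor Drift"]);

% Categories are ordered alphabetically: "No Sensor Drift", then "Sensor Drift".
disp(categories(condition))

% Encode along the second dimension so that each row becomes a one-hot vector.
oh = onehotencode(condition,2);
disp(oh)
```

Each row contains a single 1 in the column of its category, which is why the table above later shows one column per categorical value.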

Split the vectors into separate columns using the splitvars function.

tbl = splitvars(tbl);

View the first few rows of the table. Notice that the categorical predictors have been split into multiple columns with the categorical values as the variable names.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

View the class names of the data set.

classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Next, partition the data set into training and test partitions. Set aside 15% of the data for testing.

Determine the number of observations for each partition.

numObservations = size(tbl,1);
numObservationsTrain = floor(0.85*numObservations);
numObservationsTest = numObservations - numObservationsTrain;

Create an array of random indices corresponding to the observations and partition it using the partition sizes.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxTest = idx(numObservationsTrain+1:end);

Partition the table of data into training and testing partitions using the indices.

tblTrain = tbl(idxTrain,:);
tblTest = tbl(idxTest,:);

Define a network with a feature input layer and specify the number of features. Also, configure the input layer to normalize the data using Z-score normalization.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,'Normalization', 'zscore')
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
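
Z-score normalization rescales each feature to zero mean and unit standard deviation. The feature input layer computes its statistics from the training data; the sketch below only illustrates the transformation on hypothetical values and assumes the sample standard deviation.

```matlab
% Hypothetical data: 4 observations (rows) of 2 features (columns).
X = [1 10; 2 20; 3 30; 4 40];

mu = mean(X,1);              % per-feature mean
sigma = std(X,0,1);          % per-feature sample standard deviation (assumption)

XNorm = (X - mu)./sigma;     % z-score each column

disp(mean(XNorm,1))          % approximately 0 for each feature
disp(std(XNorm,0,1))         % approximately 1 for each feature
```

After normalization both features share the same scale even though the second was ten times larger, which tends to make training less sensitive to feature units.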

Specify the training options.

miniBatchSize = 16;

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false);

Train the network using the architecture defined by layers, the training data, and the training options.

net = trainNetwork(tblTrain,layers,options);

Figure: Training Progress (19-Aug-2023 11:44:07). Axes 1 plots Loss against Iteration; axes 2 plots Accuracy (%) against Iteration.

Predict the labels of the test data using the trained network and calculate the accuracy. The accuracy is the proportion of the labels that the network predicts correctly.

YPred = classify(net,tblTest,'MiniBatchSize',miniBatchSize);
YTest = tblTest{:,labelName};

accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9688

Input Arguments

Image data, specified as one of the following:

Data type: Datastore (ImageDatastore)
Description: Datastore of images saved on disk.
Example usage: Train image classification neural network with images saved on disk, where the images are the same size.

When the images are different sizes, use an AugmentedImageDatastore object.

ImageDatastore objects support image classification tasks only. To use image datastores for regression neural networks, create a transformed or combined datastore that contains the images and responses using the transform and combine functions, respectively.

Data type: Datastore (AugmentedImageDatastore)
Description: Datastore that applies random affine geometric transformations, including resizing, rotation, reflection, shear, and translation.
Example usage:

  • Train image classification neural network with images saved on disk, where the images are different sizes.

  • Train image classification neural network and generate new data using augmentations.

Data type: Datastore (TransformedDatastore)
Description: Datastore that transforms batches of data read from an underlying datastore using a custom transformation function.
Example usage:

  • Train image regression neural network.

  • Train neural networks with multiple inputs.

  • Transform datastores with outputs not supported by trainNetwork.

  • Apply custom transformations to datastore output.

Data type: Datastore (CombinedDatastore)
Description: Datastore that reads from two or more underlying datastores.
Example usage:

  • Train image regression neural network.

  • Train neural networks with multiple inputs.

  • Combine predictors and responses from different data sources.

Data type: Datastore (PixelLabelImageDatastore (Computer Vision Toolbox))
Description: Datastore that applies identical affine geometric transformations to images and corresponding pixel labels.
Example usage: Train neural network for semantic segmentation.

Data type: Datastore (RandomPatchExtractionDatastore (Image Processing Toolbox))
Description: Datastore that extracts pairs of random patches from images or pixel label images and optionally applies identical random affine geometric transformations to the pairs.
Example usage: Train neural network for object detection.

Data type: Datastore (DenoisingImageDatastore (Image Processing Toolbox))
Description: Datastore that applies randomly generated Gaussian noise.
Example usage: Train neural network for image denoising.

Data type: Datastore (Custom mini-batch datastore)
Description: Custom datastore that returns mini-batches of data.
Example usage: Train neural network using data in a format that other datastores do not support. For details, see Develop Custom Mini-Batch Datastore.

Data type: Numeric array
Description: Images specified as numeric array. If you specify images as a numeric array, then you must also specify the responses argument.
Example usage: Train neural network using data that fits in memory and does not require additional processing like augmentation.

Data type: Table
Description: Images specified as a table. If you specify images as a table, then you can also specify which columns contain the responses using the responses argument.
Example usage: Train neural network using data stored in a table.

For neural networks with multiple inputs, the datastore must be a TransformedDatastore or CombinedDatastore object.

Tip

For sequences of images, for example video data, use the sequences input argument.

Datastore

Datastores read mini-batches of images and responses. Datastores are best suited when you have data that does not fit in memory or when you want to apply augmentations or transformations to the data.

The datastores listed in the table above are directly compatible with trainNetwork for image data.

For example, you can create an image datastore using the imageDatastore function and use the names of the folders containing the images as labels by setting the 'LabelSource' option to 'foldernames'. Alternatively, you can specify the labels manually using the Labels property of the image datastore.

Tip

Use augmentedImageDatastore for efficient preprocessing of images for deep learning, including image resizing. Do not use the ReadFcn option of ImageDatastore objects.

ImageDatastore allows batch reading of JPG or PNG image files using prefetching. If you set the ReadFcn option to a custom function, then ImageDatastore does not prefetch and is usually significantly slower.

You can use other built-in datastores for training deep learning neural networks by using the transform and combine functions. These functions can convert the data read from datastores to the format required by trainNetwork.
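
As an illustration of the combine approach, two arrayDatastore objects can be paired so that each read returns one image and its response. This is a minimal sketch with hypothetical in-memory data and names; real use cases would usually read the predictors from disk.

```matlab
% Hypothetical data: 100 grayscale images and one numeric response per image.
X = rand(28,28,1,100);
Y = 45*randn(100,1);

% Iterate over the 4th dimension so that each read returns one image.
dsX = arrayDatastore(X,'IterationDimension',4);
dsY = arrayDatastore(Y);

% combine pairs the outputs of the underlying datastores column by column.
dsTrain = combine(dsX,dsY);

% Each read returns a 1-by-2 cell array: {image, response}.
data = read(dsTrain);
disp(data)

reset(dsTrain)   % return to the start before using the datastore for training
```

The first column holds the predictors and the second holds the targets, so the combined datastore can be passed to trainNetwork in place of separate image and response arguments.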

For neural networks with multiple inputs, the datastore must be a TransformedDatastore or CombinedDatastore object.

The required format of the datastore output depends on the neural network architecture.

Neural Network Architecture | Datastore Output | Example Output
Single input layer

Table or cell array with two columns.

The first and second columns specify the predictors and targets, respectively.

Table elements must be scalars, row vectors, or 1-by-1 cell arrays containing a numeric array.

Custom mini-batch datastores must output tables.

Table for neural network with one input and one output:

data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {224×224×3 double}       2    
    {224×224×3 double}       7    
    {224×224×3 double}       9    
    {224×224×3 double}       9  

Cell array for neural network with one input and one output:

data = read(ds)
data =

  4×2 cell array

    {224×224×3 double}    {[2]}
    {224×224×3 double}    {[7]}
    {224×224×3 double}    {[9]}
    {224×224×3 double}    {[9]}

Multiple input layers

Cell array with (numInputs + 1) columns, where numInputs is the number of neural network inputs.

The first numInputs columns specify the predictors for each input and the last column specifies the targets.

The order of inputs is given by the InputNames property of the layer graph layers.

Cell array for neural network with two inputs and one output.

data = read(ds)
data =

  4×3 cell array

    {224×224×3 double}    {128×128×3 double}    {[2]}
    {224×224×3 double}    {128×128×3 double}    {[2]}
    {224×224×3 double}    {128×128×3 double}    {[9]}
    {224×224×3 double}    {128×128×3 double}    {[9]}

The format of the predictors depends on the type of data.

Data | Format
2-D images | h-by-w-by-c numeric array, where h, w, and c are the height, width, and number of channels of the images, respectively.
3-D images | h-by-w-by-d-by-c numeric array, where h, w, d, and c are the height, width, depth, and number of channels of the images, respectively.

For predictors returned in tables, the elements must contain a numeric scalar, a numeric row vector, or a 1-by-1 cell array containing the numeric array.

The format of the responses depends on the type of task.

Task | Response Format
Image classification | Categorical scalar
Image regression | One of the following:

  • Numeric scalar

  • Numeric vector

  • 3-D numeric array representing a 2-D image

  • 4-D numeric array representing a 3-D image

For responses returned in tables, the elements must be a categorical scalar, a numeric scalar, a numeric row vector, or a 1-by-1 cell array containing a numeric array.

For more information, see Datastores for Deep Learning.

Numeric Array

For data that fits in memory and does not require additional processing like augmentation, you can specify a data set of images as a numeric array. If you specify images as a numeric array, then you must also specify the responses argument.

The size and shape of the numeric array depends on the type of image data.

Data | Format
2-D images | h-by-w-by-c-by-N numeric array, where h, w, and c are the height, width, and number of channels of the images, respectively, and N is the number of images.
3-D images | h-by-w-by-d-by-c-by-N numeric array, where h, w, d, and c are the height, width, depth, and number of channels of the images, respectively, and N is the number of images.
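
When individual images are already in memory, one way to assemble the required array is to concatenate them along the fourth dimension. The images below are hypothetical random arrays used only to show the resulting sizes.

```matlab
% Three hypothetical 28-by-28 grayscale images.
grayImages = {rand(28,28), rand(28,28), rand(28,28)};
XGray = cat(4,grayImages{:});     % stack along the 4th (observation) dimension
disp(size(XGray))                 % 28 28 1 3

% Two hypothetical 32-by-32 RGB images.
colorImages = {rand(32,32,3), rand(32,32,3)};
XColor = cat(4,colorImages{:});
disp(size(XColor))                % 32 32 3 2
```

For grayscale images the channel dimension has size 1, so the number of images still appears in the fourth dimension.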

Table

As an alternative to datastores or numeric arrays, you can also specify images and responses in a table. If you specify images as a table, then you can also specify which columns contain the responses using the responses argument.

When specifying images and responses in a table, each row in the table corresponds to an observation.

For image input, the predictors must be in the first column of the table, specified as one of the following:

  • Absolute or relative file path to an image, specified as a character vector

  • 1-by-1 cell array containing an h-by-w-by-c numeric array representing a 2-D image, where h, w, and c correspond to the height, width, and number of channels of the image, respectively.

The format of the responses depends on the type of task.

Task | Response Format
Image classification | Categorical scalar
Image regression | One of the following:

  • Numeric scalar

  • Two or more columns of scalar values

  • 1-by-1 cell array containing an h-by-w-by-c numeric array representing a 2-D image

  • 1-by-1 cell array containing an h-by-w-by-d-by-c numeric array representing a 3-D image

For neural networks with image input, if you do not specify responses, then the function, by default, uses the first column of tbl for the predictors and the subsequent columns as responses.
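The table layout described above can be sketched as follows. This is a minimal example using in-memory images stored in 1-by-1 cells; the variable names Predictors and Response, the image size, and the labels are hypothetical.

```matlab
% Each row is one observation: a 1-by-1 cell holding an h-by-w-by-c image,
% followed by a categorical label.
numObservations = 4;
images = cell(numObservations,1);
for i = 1:numObservations
    images{i} = rand(28,28,1);
end
labels = categorical({'cat';'dog';'cat';'dog'});

% Predictors go in the first column; the responses follow.
tbl = table(images,labels,'VariableNames',{'Predictors','Response'});
disp(tbl)

% With layers and options defined elsewhere, the response column can be
% named explicitly (assumed usage of the responses argument):
% net = trainNetwork(tbl,'Response',layers,options);
```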

Tip

  • If the predictors or the responses contain NaNs, then they are propagated through the neural network during training. In these cases, the training usually fails to converge.

  • For regression tasks, normalizing the responses often helps to stabilize and speed up training. For more information, see Train Convolutional Neural Network for Regression.

  • To input complex-valued data into a neural network, the SplitComplexInputs option of the input layer must be 1.

Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | table
Complex Number Support: Yes

Sequence or time series data, specified as one of the following:

Data Type | Description | Example Usage
Datastore: TransformedDatastore | Datastore that transforms batches of data read from an underlying datastore using a custom transformation function. | Transform datastores with outputs not supported by trainNetwork. Apply custom transformations to datastore output.
Datastore: CombinedDatastore | Datastore that reads from two or more underlying datastores. | Combine predictors and responses from different data sources.
Datastore: Custom mini-batch datastore | Custom datastore that returns mini-batches of data. | Train neural network using data in a format that other datastores do not support. For details, see Develop Custom Mini-Batch Datastore.
Numeric or cell array | A single sequence specified as a numeric array or a data set of sequences specified as a cell array of numeric arrays. If you specify sequences as a numeric or cell array, then you must also specify the responses argument. | Train neural network using data that fits in memory and does not require additional processing like custom transformations.

Datastore

Datastores read mini-batches of sequences and responses. Datastores are best suited when you have data that does not fit in memory or when you want to apply transformations to the data.

The list below lists the datastores that are directly compatible with trainNetwork for sequence data.

You can use other built-in datastores for training deep learning neural networks by using the transform and combine functions. These functions can convert the data read from datastores to the table or cell array format required by trainNetwork. For example, you can transform and combine data read from in-memory arrays and CSV files using ArrayDatastore and TabularTextDatastore objects, respectively.

The datastore must return data in a table or cell array. Custom mini-batch datastores must output tables.

Datastore Output | Example Output
Table
data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {12×50 double}           2    
    {12×50 double}           7    
    {12×50 double}           9    
    {12×50 double}           9  
Cell array
data = read(ds)
data =

  4×2 cell array

    {12×50 double}        {[2]}
    {12×50 double}        {[7]}
    {12×50 double}        {[9]}
    {12×50 double}        {[9]}
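One way to obtain the cell array output shown above from in-memory data is to wrap the predictors and responses in ArrayDatastore objects and combine them. The sketch below is an illustration with synthetic 12-by-50 sequences and hypothetical labels. With the default read size of 1, each read returns a single 1-by-2 row rather than the four rows in the example output.

```matlab
% Synthetic predictors: four vector sequences, each c-by-s = 12-by-50.
numObservations = 4;
XTrain = cell(numObservations,1);
for i = 1:numObservations
    XTrain{i} = rand(12,50);
end
TTrain = categorical([2;7;9;9]);

% 'OutputType','same' keeps each read as a 1-by-1 cell holding the sequence.
dsX = arrayDatastore(XTrain,'OutputType','same');
% The default output type wraps each label in a cell.
dsT = arrayDatastore(TTrain);

% Combined reads concatenate horizontally: {predictors, response}.
ds = combine(dsX,dsT);
data = read(ds);
disp(data)
```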

The format of the predictors depends on the type of data.

Data | Format of Predictors
Vector sequence

c-by-s matrix, where c is the number of features of the sequence and s is the sequence length.

1-D image sequence

h-by-c-by-s array, where h and c correspond to the height and number of channels of the image, respectively, and s is the sequence length.

Each sequence in the mini-batch must have the same sequence length.

2-D image sequence

h-by-w-by-c-by-s array, where h, w, and c correspond to the height, width, and number of channels of the image, respectively, and s is the sequence length.

Each sequence in the mini-batch must have the same sequence length.

3-D image sequence

h-by-w-by-d-by-c-by-s array, where h, w, d, and c correspond to the height, width, depth, and number of channels of the image, respectively, and s is the sequence length.

Each sequence in the mini-batch must have the same sequence length.

For predictors returned in tables, the elements must contain a numeric scalar, a numeric row vector, or a 1-by-1 cell array containing a numeric array.

The format of the responses depends on the type of task.

Task | Format of Responses
Sequence-to-label classification | Categorical scalar
Sequence-to-one regression

Scalar

Sequence-to-vector regression

Numeric row vector

Sequence-to-sequence classification

  • 1-by-s sequence of categorical labels, where s is the sequence length of the corresponding predictor sequence.

  • h-by-w-by-s sequence of categorical labels, where h, w, and s are the height, width, and sequence length of the corresponding predictor sequence, respectively.

  • h-by-w-by-d-by-s sequence of categorical labels, where h, w, d, and s are the height, width, depth, and sequence length of the corresponding predictor sequence, respectively.

Each sequence in the mini-batch must have the same sequence length.

Sequence-to-sequence regression
  • R-by-s matrix, where R is the number of responses and s is the sequence length of the corresponding predictor sequence.

  • h-by-w-by-R-by-s sequence of numeric responses, where R is the number of responses, and h, w, and s are the height, width, and sequence length of the corresponding predictor sequence, respectively.

  • h-by-w-by-d-by-R-by-s sequence of numeric responses, where R is the number of responses, and h, w, d, and s are the height, width, depth, and sequence length of the corresponding predictor sequence, respectively.

Each sequence in the mini-batch must have the same sequence length.

For responses returned in tables, the elements must be a categorical scalar, a numeric scalar, a numeric row vector, or a 1-by-1 cell array containing a numeric array.

For more information, see Datastores for Deep Learning.

Numeric or Cell Array

For data that fits in memory and does not require additional processing like custom transformations, you can specify a single sequence as a numeric array or a data set of sequences as a cell array of numeric arrays. If you specify sequences as a cell or numeric array, then you must also specify the responses argument.

For cell array input, the cell array must be an N-by-1 cell array of numeric arrays, where N is the number of observations. The size and shape of the numeric array representing a sequence depends on the type of sequence data.

Input | Description
Vector sequences | c-by-s matrices, where c is the number of features of the sequences and s is the sequence length.
1-D image sequences | h-by-c-by-s arrays, where h and c correspond to the height and number of channels of the images, respectively, and s is the sequence length.
2-D image sequences | h-by-w-by-c-by-s arrays, where h, w, and c correspond to the height, width, and number of channels of the images, respectively, and s is the sequence length.
3-D image sequences | h-by-w-by-d-by-c-by-s arrays, where h, w, d, and c correspond to the height, width, depth, and number of channels of the 3-D images, respectively, and s is the sequence length.

The trainNetwork function supports neural networks with at most one sequence input layer.
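The cell array layout for vector sequences can be sketched as follows. The data is synthetic, and the sizes (12 features, 20 observations, 4 classes) and the small LSTM network are assumptions made only for this example. It requires Deep Learning Toolbox.

```matlab
% N-by-1 cell array of c-by-s matrices; sequence lengths may differ
% between observations.
numObservations = 20;
numFeatures = 12;
XTrain = cell(numObservations,1);
for i = 1:numObservations
    sequenceLength = 40 + i;
    XTrain{i} = rand(numFeatures,sequenceLength);
end

% Sequence-to-label classification: one categorical label per sequence.
numClasses = 4;
TTrain = categorical(repmat((1:numClasses)',numObservations/numClasses,1));

% Illustrative network with a single sequence input layer.
layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(16,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam','MaxEpochs',1,'MiniBatchSize',10,'Verbose',false);
net = trainNetwork(XTrain,TTrain,layers,options);
```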

Tip

  • If the predictors or the responses contain NaNs, then they are propagated through the neural network during training. In these cases, the training usually fails to converge.

  • For regression tasks, normalizing the responses often helps to stabilize and speed up training. For more information, see Train Convolutional Neural Network for Regression.

  • To input complex-valued data into a neural network, the SplitComplexInputs option of the input layer must be 1.

Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | cell
Complex Number Support: Yes

Feature data, specified as one of the following:

Data Type | Description | Example Usage
Datastore: TransformedDatastore | Datastore that transforms batches of data read from an underlying datastore using a custom transformation function. | Train neural networks with multiple inputs. Transform datastores with outputs not supported by trainNetwork. Apply custom transformations to datastore output.
Datastore: CombinedDatastore | Datastore that reads from two or more underlying datastores. | Train neural networks with multiple inputs. Combine predictors and responses from different data sources.
Datastore: Custom mini-batch datastore | Custom datastore that returns mini-batches of data. | Train neural network using data in a format that other datastores do not support. For details, see Develop Custom Mini-Batch Datastore.
Table | Feature data specified as a table. If you specify features as a table, then you can also specify which columns contain the responses using the responses argument. | Train neural network using data stored in a table.
Numeric array | Feature data specified as a numeric array. If you specify features as a numeric array, then you must also specify the responses argument. | Train neural network using data that fits in memory and does not require additional processing like custom transformations.

Datastore

Datastores read mini-batches of feature data and responses. Datastores are best suited when you have data that does not fit in memory or when you want to apply transformations to the data.

The list below lists the datastores that are directly compatible with trainNetwork for feature data.

You can use other built-in datastores for training deep learning neural networks by using the transform and combine functions. These functions can convert the data read from datastores to the table or cell array format required by trainNetwork. For more information, see Datastores for Deep Learning.

For neural networks with multiple inputs, the datastore must be a TransformedDatastore or CombinedDatastore object.

The datastore must return data in a table or a cell array. Custom mini-batch datastores must output tables. The format of the datastore output depends on the neural network architecture.

Neural Network Architecture | Datastore Output | Example Output
Single input layer

Table or cell array with two columns.

The first and second columns specify the predictors and responses, respectively.

Table elements must be scalars, row vectors, or 1-by-1 cell arrays containing a numeric array.

Custom mini-batch datastores must output tables.

Table for neural network with one input and one output:

data = read(ds)
data =

  4×2 table

        Predictors        Response
    __________________    ________

    {24×1 double}            2    
    {24×1 double}            7    
    {24×1 double}            9    
    {24×1 double}            9  

Cell array for neural network with one input and one output:

data = read(ds)
data =

  4×2 cell array

    {24×1 double}    {[2]}
    {24×1 double}    {[7]}
    {24×1 double}    {[9]}
    {24×1 double}    {[9]}

Multiple input layers

Cell array with (numInputs + 1) columns, where numInputs is the number of neural network inputs.

The first numInputs columns specify the predictors for each input and the last column specifies the responses.

The order of inputs is given by the InputNames property of the layer graph layers.

Cell array for neural network with two inputs and one output:

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[9]}
    {24×1 double}    {28×1 double}    {[9]}

The predictors must be c-by-1 column vectors, where c is the number of features.

The format of the responses depends on the type of task.

Task | Format of Responses
Classification | Categorical scalar
Regression | One of the following:

  • Scalar

  • Numeric vector

For more information, see Datastores for Deep Learning.

Table

For feature data that fits in memory and does not require additional processing like custom transformations, you can specify feature data and responses as a table.

Each row in the table corresponds to an observation. The arrangement of predictors and responses in the table columns depends on the type of task.

Task | Predictors | Responses
Feature classification | Features specified in one or more columns as scalars. If you do not specify the responses argument, then the predictors must be in the first numFeatures columns of the table, where numFeatures is the number of features of the input data. | Categorical label
Feature regression | Same as for feature classification. | One or more columns of scalar values

For classification neural networks with feature input, if you do not specify the responses argument, then the function, by default, uses the first (numColumns - 1) columns of tbl for the predictors and the last column for the labels, where numColumns is the number of columns in tbl.

For regression neural networks with feature input, if you do not specify the responses argument, then the function, by default, uses the first numFeatures columns for the predictors and the subsequent columns for the responses, where numFeatures is the number of features in the input data.
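The default column split for classification can be illustrated with plain table operations. This sketch only mirrors the rule stated above; the table contents and the variable names F1, F2, F3, and Label are hypothetical.

```matlab
% Feature table: three scalar feature columns followed by one label column.
tbl = table(rand(6,1),rand(6,1),rand(6,1), ...
    categorical({'a';'b';'a';'b';'a';'b'}), ...
    'VariableNames',{'F1','F2','F3','Label'});

% Default split for classification when responses is not specified:
% the first (numColumns - 1) columns are predictors, the last is the label.
numColumns = width(tbl);
predictorNames = tbl.Properties.VariableNames(1:numColumns-1);
responseName = tbl.Properties.VariableNames{numColumns};

fprintf('Predictors: %s\n',strjoin(predictorNames,', '));
fprintf('Response:   %s\n',responseName);
```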

Numeric Array

For feature data that fits in memory and does not require additional processing like custom transformations, you can specify feature data as a numeric array. If you specify feature data as a numeric array, then you must also specify the responses argument.

The numeric array must be an N-by-numFeatures numeric array, where N is the number of observations and numFeatures is the number of features of the input data.
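A minimal sketch of the numeric array layout for feature regression follows. The synthetic data, the sizes (200 observations, 5 features, 2 responses), and the small network are assumptions for illustration. It requires Deep Learning Toolbox.

```matlab
% Predictors: N-by-numFeatures, one row per observation.
numObservations = 200;
numFeatures = 5;
numResponses = 2;
XTrain = rand(numObservations,numFeatures);

% Responses: N-by-R matrix, required because the predictors are a numeric array.
TTrain = [sum(XTrain,2), XTrain(:,1) - XTrain(:,2)];

% Illustrative multilayer perceptron.
layers = [
    featureInputLayer(numFeatures)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(numResponses)
    regressionLayer];

options = trainingOptions('adam','MaxEpochs',2,'MiniBatchSize',50,'Verbose',false);
net = trainNetwork(XTrain,TTrain,layers,options);
```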

Tip

  • Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.

  • Responses must not contain NaNs. If the predictor data contains NaNs, then they are propagated through the training. However, in most cases, the training fails to converge.

  • To input complex-valued data into a neural network, the SplitComplexInputs option of the input layer must be 1.

Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | table
Complex Number Support: Yes

Mixed data and responses, specified as one of the following:

Data Type | Description | Example Usage
TransformedDatastore | Datastore that transforms batches of data read from an underlying datastore using a custom transformation function. | Train neural networks with multiple inputs. Transform outputs of datastores not supported by trainNetwork to have the required format. Apply custom transformations to datastore output.
CombinedDatastore | Datastore that reads from two or more underlying datastores. | Train neural networks with multiple inputs. Combine predictors and responses from different data sources.
Custom mini-batch datastore | Custom datastore that returns mini-batches of data. | Train neural network using data in a format that other datastores do not support. For details, see Develop Custom Mini-Batch Datastore.

You can use other built-in datastores for training deep learning neural networks by using the transform and combine functions. These functions can convert the data read from datastores to the table or cell array format required by trainNetwork. For more information, see Datastores for Deep Learning.

The datastore must return data in a table or a cell array. Custom mini-batch datastores must output tables. The format of the datastore output depends on the neural network architecture.

Datastore Output | Example Output

Cell array with (numInputs + 1) columns, where numInputs is the number of neural network inputs.

The first numInputs columns specify the predictors for each input and the last column specifies the responses.

The order of inputs is given by the InputNames property of the layer graph layers.

data = read(ds)
data =

  4×3 cell array

    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[2]}
    {24×1 double}    {28×1 double}    {[9]}
    {24×1 double}    {28×1 double}    {[9]}

For image, sequence, and feature predictor input, the format of the predictors must match the formats described in the images, sequences, or features argument descriptions, respectively. Similarly, the format of the responses must match the formats described in the images, sequences, or features argument descriptions that correspond to the type of task.

The trainNetwork function supports neural networks with at most one sequence input layer.

For an example showing how to train a neural network with multiple inputs, see Train Network on Image and Feature Data.
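The (numInputs + 1)-column cell array output described above can be produced from in-memory arrays by combining ArrayDatastore objects. The sketch below is an illustration with two synthetic feature inputs of 24 and 28 features; the data and labels are hypothetical. With the default read size of 1, each read returns a single row.

```matlab
% Two feature inputs stored with observations along the second dimension,
% so that each read yields a c-by-1 column vector.
numObservations = 4;
X1 = rand(24,numObservations);
X2 = rand(28,numObservations);
T  = categorical([2;2;9;9]);

dsX1 = arrayDatastore(X1,'IterationDimension',2);
dsX2 = arrayDatastore(X2,'IterationDimension',2);
dsT  = arrayDatastore(T);

% Column order: predictors for each input first, responses last.
ds = combine(dsX1,dsX2,dsT);
data = read(ds);
disp(data)
```

The order of dsX1 and dsX2 in the call to combine should follow the input order of the network.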

Tip

  • To convert a numeric array to a datastore, use ArrayDatastore.

  • When combining layers in a neural network with mixed types of data, you may need to reformat the data before passing it to a combination layer (such as a concatenation or an addition layer). To reformat the data, you can use a flatten layer to flatten the spatial dimensions into the channel dimension, or create a FunctionLayer object or custom layer that reformats and reshapes.

Responses.

When the input data is a numeric array or a cell array, specify the responses as one of the following.

  • categorical vector of labels

  • numeric array of numeric responses

  • cell array of categorical or numeric sequences

When the input data is a table, you can optionally specify which columns of the table contain the responses as one of the following:

  • character vector

  • cell array of character vectors

  • string array

When the input data is a numeric array or a cell array, then the format of the responses depends on the type of task.

Task | Format
Classification: image classification, feature classification, or sequence-to-label classification | N-by-1 categorical vector of labels, where N is the number of observations.
Classification: sequence-to-sequence classification | N-by-1 cell array of categorical sequences of labels, where N is the number of observations. Each sequence must have the same number of time steps as the corresponding predictor sequence. For sequence-to-sequence classification tasks with one observation, sequences can also be a vector. In this case, responses must be a categorical row vector of labels.
Regression: 2-D image regression | One of the following:

  • N-by-R matrix, where N is the number of images and R is the number of responses.

  • h-by-w-by-c-by-N numeric array, where h, w, and c are the height, width, and number of channels of the images, respectively, and N is the number of images.

Regression: 3-D image regression | One of the following:

  • N-by-R matrix, where N is the number of images and R is the number of responses.

  • h-by-w-by-d-by-c-by-N numeric array, where h, w, d, and c are the height, width, depth, and number of channels of the images, respectively, and N is the number of images.

Regression: feature regression | N-by-R matrix, where N is the number of observations and R is the number of responses.
Regression: sequence-to-one regression | N-by-R matrix, where N is the number of sequences and R is the number of responses.
Regression: sequence-to-sequence regression | N-by-1 cell array of numeric sequences, where N is the number of sequences, with sequences given by one of the following:

  • R-by-s matrix, where R is the number of responses and s is the sequence length of the corresponding predictor sequence.

  • h-by-w-by-R-by-s array, where h and w are the height and width of the output, respectively, R is the number of responses, and s is the sequence length of the corresponding predictor sequence.

  • h-by-w-by-d-by-R-by-s array, where h, w, and d are the height, width, and depth of the output, respectively, R is the number of responses, and s is the sequence length of the corresponding predictor sequence.

For sequence-to-sequence regression tasks with one observation, sequences can be a numeric array. In this case, responses must be a numeric array of responses.
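For sequence-to-sequence regression, each response sequence must match the length of its predictor sequence. The following sketch builds matching cell arrays and checks that condition before training. The sizes and sequence lengths are hypothetical.

```matlab
% Predictors: N-by-1 cell array of c-by-s matrices.
% Responses:  N-by-1 cell array of R-by-s matrices with matching s.
numObservations = 3;
numFeatures = 12;
numResponses = 2;
sequenceLengths = [50 64 37];

XTrain = cell(numObservations,1);
TTrain = cell(numObservations,1);
for i = 1:numObservations
    s = sequenceLengths(i);
    XTrain{i} = rand(numFeatures,s);
    TTrain{i} = rand(numResponses,s);
end

% Verify that each response has one time step per predictor time step.
lengthsMatch = cellfun(@(x,t) size(x,2) == size(t,2),XTrain,TTrain);
assert(all(lengthsMatch),'Response and predictor sequence lengths differ.');
```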

Tip

Normalizing the responses often helps to stabilize and speed up training of neural networks for regression. For more information, see Train Convolutional Neural Network for Regression.

Tip

Responses must not contain NaNs. If the predictor data contains NaNs, then they are propagated through the training. However, in most cases, the training fails to converge.

Neural network layers, specified as a Layer array or a LayerGraph object.

To create a neural network with all layers connected sequentially, you can use a Layer array as the input argument. In this case, the returned neural network is a SeriesNetwork object.

A directed acyclic graph (DAG) neural network has a complex structure in which layers can have multiple inputs and outputs. To create a DAG neural network, specify the neural network architecture as a LayerGraph object and then use that layer graph as the input argument to trainNetwork.
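As a sketch of the DAG case, the following code creates a LayerGraph object with one skip connection. The layer names and sizes are assumptions made for this example. It requires Deep Learning Toolbox.

```matlab
% Main branch. Layers in an array are connected sequentially, so conv_2
% feeds the first input of the addition layer.
layers = [
    imageInputLayer([28 28 1],'Name','input')
    convolution2dLayer(3,8,'Padding','same','Name','conv_1')
    reluLayer('Name','relu_1')
    convolution2dLayer(3,8,'Padding','same','Name','conv_2')
    additionLayer(2,'Name','add')
    fullyConnectedLayer(10,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')];

lgraph = layerGraph(layers);

% Skip connection: relu_1 also feeds the second input of the addition layer.
lgraph = connectLayers(lgraph,'relu_1','add/in2');

% Passing lgraph to trainNetwork returns a DAGNetwork object:
% net = trainNetwork(XTrain,TTrain,lgraph,options);
```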

The trainNetwork function supports neural networks with at most one sequence input layer.

For a list of built-in layers, see List of Deep Learning Layers.

Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions function.

Output Arguments

Trained neural network, returned as a SeriesNetwork object or a DAGNetwork object.

If you train the neural network using a Layer array, then net is a SeriesNetwork object. If you train the neural network using a LayerGraph object, then net is a DAGNetwork object.

Training information, returned as a structure, where each field is a scalar or a numeric vector with one element per training iteration.

For classification tasks, info contains the following fields:

  • TrainingLoss — Loss function values

  • TrainingAccuracy — Training accuracies

  • ValidationLoss — Loss function values

  • ValidationAccuracy — Validation accuracies

  • BaseLearnRate — Learning rates

  • FinalValidationLoss — Validation loss of returned neural network

  • FinalValidationAccuracy — Validation accuracy of returned neural network

  • OutputNetworkIteration — Iteration number of returned neural network

For regression tasks, info contains the following fields:

  • TrainingLoss — Loss function values

  • TrainingRMSE — Training RMSE values

  • ValidationLoss — Loss function values

  • ValidationRMSE — Validation RMSE values

  • BaseLearnRate — Learning rates

  • FinalValidationLoss — Validation loss of returned neural network

  • FinalValidationRMSE — Validation RMSE of returned neural network

  • OutputNetworkIteration — Iteration number of returned neural network

The structure contains the fields ValidationLoss, ValidationAccuracy, ValidationRMSE, FinalValidationLoss, FinalValidationAccuracy, and FinalValidationRMSE only when options specifies validation data. The ValidationFrequency training option determines the iterations at which the software calculates validation metrics. The final validation metrics are scalars. The other fields of the structure are row vectors, where each element corresponds to a training iteration. For iterations when the software does not calculate validation metrics, the corresponding values in the structure are NaN.

For neural networks containing batch normalization layers, if the BatchNormalizationStatistics training option is 'population', then the final validation metrics are often different from the validation metrics evaluated during training. This is because batch normalization layers in the final neural network perform different operations than during training. For more information, see batchNormalizationLayer.
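Because validation fields hold NaN for iterations without validation, it is often convenient to filter them before plotting or comparing. The sketch below uses a hypothetical info structure with made-up values shaped like the description above, not the output of an actual training run.

```matlab
% Hypothetical training information: validation every third iteration,
% plus the first and last iteration.
info.TrainingLoss   = [2.3 2.1 1.9 1.7 1.6 1.5];
info.ValidationLoss = [2.4 NaN NaN 1.8 NaN 1.6];

% Keep only the iterations at which validation metrics were calculated.
validated = ~isnan(info.ValidationLoss);
validationIterations = find(validated);
validationLoss = info.ValidationLoss(validated);

disp(validationIterations)
disp(validationLoss)
```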

More About

Save Checkpoint Neural Networks and Resume Training

Deep Learning Toolbox™ enables you to save neural networks as .mat files during training. This periodic saving is especially useful when you have a large neural network or a large data set, and training takes a long time. If the training is interrupted for some reason, you can resume training from the last saved checkpoint neural network. If you want the trainnet and trainNetwork functions to save checkpoint neural networks, then you must specify the name of the path by using the CheckpointPath option of trainingOptions. If the path that you specify does not exist, then trainingOptions returns an error.

The software automatically assigns unique names to checkpoint neural network files. In the example name, net_checkpoint__351__2018_04_12__18_09_52.mat, 351 is the iteration number, 2018_04_12 is the date, and 18_09_52 is the time at which the software saves the neural network. You can load a checkpoint neural network file by double-clicking it or using the load command at the command line. For example:

load net_checkpoint__351__2018_04_12__18_09_52.mat

You can then resume training by using the layers of the neural network as an input argument to trainnet or trainNetwork. For example:

trainNetwork(XTrain,TTrain,net.Layers,options)

You must manually specify the training options and the input data, because the checkpoint neural network does not contain this information. For an example, see Resume Training from Checkpoint Network.
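When a checkpoint folder holds several files, the iteration number embedded in each name can be used to pick the most recent one. The sketch below parses names that follow the pattern described above. The first and third file names are hypothetical, and in practice the list would come from a call such as dir on the checkpoint folder.

```matlab
% Example checkpoint file names (only the second appears in the text above).
files = {
    'net_checkpoint__117__2018_04_12__17_58_10.mat'
    'net_checkpoint__351__2018_04_12__18_09_52.mat'
    'net_checkpoint__234__2018_04_12__18_03_31.mat'};

% Extract the iteration number that follows the net_checkpoint__ prefix.
iterations = zeros(numel(files),1);
for i = 1:numel(files)
    tokens = regexp(files{i},'^net_checkpoint__(\d+)__','tokens','once');
    iterations(i) = str2double(tokens{1});
end

% The checkpoint with the largest iteration number is the latest one.
[latestIteration,idx] = max(iterations);
latestFile = files{idx};
fprintf('Latest checkpoint: %s (iteration %d)\n',latestFile,latestIteration);
```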

Floating-Point Arithmetic

When you train a neural network using the trainnet or trainNetwork functions, or when you use prediction or validation functions with DAGNetwork and SeriesNetwork objects, the software performs these computations using single-precision, floating-point arithmetic. Functions for prediction and validation include predict, classify, and activations. The software uses single-precision arithmetic when you train neural networks using both CPUs and GPUs.

Reproducibility

To provide the best performance, deep learning using a GPU in MATLAB® is not guaranteed to be deterministic. Depending on your network architecture, under some conditions you might get different results when using a GPU to train two identical networks or make two predictions using the same network and data.

Extended Capabilities

Version History

Introduced in R2016a
