
Train Neural Network with Tabular Data

Since R2023b

This example shows how to train a neural network with tabular data.

If you have a data set of numeric and categorical features (for example, tabular data without spatial or time dimensions), then you can train a deep neural network using a feature input layer. This example trains a neural network that predicts the gear tooth condition given a table of numeric and categorical sensor readings.

Load Training Data

Read the transmission casing data from the CSV file "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="string");

Convert the labels for prediction to categorical using the convertvars function.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");

To train a network using categorical features, you must first convert the categorical features to numeric. Start by converting the categorical predictors to the categorical data type using the convertvars function, specifying a string array that contains the names of all the categorical input variables. This data set has two categorical features, "SensorCondition" and "ShaftCondition".

categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalPredictorNames,"categorical");

Loop over the categorical input variables. For each variable, convert the categorical values to one-hot encoded vectors using the onehotencode function.

for i = 1:numel(categoricalPredictorNames)
    name = categoricalPredictorNames(i);
    tbl.(name) = onehotencode(tbl.(name),2);
end
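To see what the conversion and encoding steps produce, here is a minimal sketch on a hypothetical three-row table. The variable values ("Drift", "No Drift") are made up for illustration and are not taken from the transmission casing data. The sketch assumes Deep Learning Toolbox, which provides onehotencode.

```matlab
% Hypothetical toy table: one numeric feature and one text feature.
toy = table([0.5;1.5;2.5],["No Drift";"Drift";"No Drift"], ...
    VariableNames=["SigMean" "SensorCondition"]);

% Convert the text feature to categorical. Text categories are sorted
% alphabetically, so the categories are "Drift", then "No Drift".
toy = convertvars(toy,"SensorCondition","categorical");
categories(toy.SensorCondition)

% Replace the categorical column with a 3-by-2 one-hot matrix.
% Dimension 2 holds the categories, one column per category.
toy.SensorCondition = onehotencode(toy.SensorCondition,2);
toy.SensorCondition
```

The table still has two variables, but SensorCondition now holds a two-column numeric matrix, with the first column marking "Drift" and the second marking "No Drift".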

View the first few rows of the table. Notice that each categorical predictor now contains multiple columns, one for each category.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13             0    1             1    0          No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12             0    1             1    0          No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13             0    1             0    1          No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13             0    1             0    1          No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39             0    1             0    1          No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39             0    1             0    1          No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39             0    1             0    1          No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39             0    1             0    1          No Tooth Fault  

View the class names of the data set.

classNames = categories(tbl{:,labelName})
classNames = 2×1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Set aside data for testing. Partition the data into a training set containing 80% of the data, a validation set containing 10% of the data, and a test set containing the remaining 10% of the data. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script.

numObservations = size(tbl,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.80 0.1 0.1]);

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);
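The trainingPartitions supporting file is not reproduced here. As a rough idea of what such a helper does, the following sketch randomly splits observation indices into three disjoint sets using the 80/10/10 proportions above. It is a hypothetical stand-in, not the supporting file: the variable names and the rounding rule (floor for the training and validation sets, remainder to the test set) are assumptions.

```matlab
% Hypothetical stand-in for trainingPartitions: random 80/10/10 split.
n = 100;                         % illustrative number of observations
splits = [0.80 0.1 0.1];         % training, validation, test proportions

idx = randperm(n);               % random permutation of 1:n

nTrain = floor(splits(1)*n);
nVal = floor(splits(2)*n);

iTrain = idx(1:nTrain);
iVal = idx(nTrain+1:nTrain+nVal);
iTest = idx(nTrain+nVal+1:end);  % all remaining observations
```

Because every index appears in exactly one of the three sets, no observation is used for both training and testing.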

Convert the data to a format that the trainnet function supports. Convert the predictors and targets to numeric and categorical arrays, respectively. For feature input, the network expects data with rows that correspond to observations and columns that correspond to the features. If your data has a different layout, then you can preprocess your data to have this layout or you can provide layout information using data formats. For more information, see Deep Learning Data Formats.

predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
    "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
    "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
    "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];
XTrain = table2array(tblTrain(:,predictorNames));
TTrain = tblTrain.(labelName);
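Listing the predictor names by hand is explicit, but you can also derive them from the table by removing the label name. The sketch below uses a small hypothetical table to show this. It also shows that table2array expands a one-hot-encoded variable into one column per category, which is why XTrain has 22 columns even though predictorNames lists 20 variables.

```matlab
% Hypothetical toy table: two numeric predictors, one one-hot-encoded
% predictor stored as a two-column variable, and a categorical label.
toy = table([0.1;0.2],[1;2],[0 1;1 0],categorical(["A";"B"]), ...
    VariableNames=["F1" "F2" "Cat" "Label"]);

% Every variable except the label is a predictor; "stable" keeps table order.
toyPredictorNames = setdiff(string(toy.Properties.VariableNames),"Label","stable");

% table2array concatenates the variables horizontally, so the two-column
% "Cat" variable contributes two columns to the numeric array.
toyX = table2array(toy(:,toyPredictorNames));
size(toyX)   % 2 rows, 4 columns
```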

View the sizes of the training predictors and targets.

size(XTrain)
ans = 1×2

   166    22

size(TTrain)
ans = 1×2

   166     1

Extract the validation and test predictors and targets using the same steps.

XValidation = table2array(tblValidation(:,predictorNames));
TValidation = tblValidation.(labelName);

XTest = table2array(tblTest(:,predictorNames));
TTest = tblTest.(labelName);

Define Neural Network Architecture

Define the neural network architecture.

  • For feature input, specify a feature input layer with the number of features. Normalize the input using Z-score normalization.

  • Specify a fully connected layer with a size of 16, followed by a layer normalization layer and a ReLU layer.

  • For classification output, specify a fully connected layer with a size that matches the number of classes, followed by a softmax layer.

numFeatures = size(XTrain,2);
hiddenSize = 16;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,Normalization="zscore")
    fullyConnectedLayer(hiddenSize)
    layerNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
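As a sanity check on the network size, the following worked count tallies the learnable parameters. It assumes 22 input features (the number of columns in XTrain) and the two classes listed earlier. It also assumes that each fully connected layer learns a weight matrix plus one bias per output and that the layer normalization layer learns one scale and one offset per channel. The Z-score statistics of the input layer are computed from the data rather than learned, so they are not counted.

```matlab
% Worked count of learnable parameters for the architecture above.
nIn = 22;       % input features (columns of XTrain)
nHidden = 16;   % hidden fully connected layer size
nOut = 2;       % number of classes

fc1 = nIn*nHidden + nHidden;    % weights + biases = 368
ln  = 2*nHidden;                % scale + offset   = 32
fc2 = nHidden*nOut + nOut;      % weights + biases = 34
numLearnables = fc1 + ln + fc2  % 434
```

To compare against the software's own report, you can inspect the layer array with the analyzeNetwork function.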

Specify Training Options

Specify the training options:

  • Train using the L-BFGS solver. This solver is well suited to small networks and to data sets that fit in memory.

  • Train using the CPU. Because the network and data are small, the CPU is better suited to this task than a GPU.

  • Validate the network every 5 iterations using the validation data.

  • Return the network with the lowest validation loss.

  • Display the training progress in a plot and monitor the accuracy metric.

  • Suppress the verbose output.

options = trainingOptions("lbfgs", ...
    ExecutionEnvironment="cpu", ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=5, ...
    OutputNetwork="best-validation-loss", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Train Neural Network

Train the neural network using the trainnet function. For classification, use cross-entropy loss.

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

The plot shows the training and validation accuracy and loss. When training completes, the plot shows the stopping reason. When you use the L-BFGS solver, the stopping reason can indicate that the line search failed, meaning that the software was unable to find a suitable learning rate. This can happen when the solver quickly reaches a minimum of the loss, or when the step and gradient norms are close to zero.
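If the solver stops earlier or later than you want, you can adjust the L-BFGS stopping criteria through trainingOptions. The sketch below assumes the MaxIterations, GradientTolerance, and StepTolerance options of trainingOptions for the "lbfgs" solver (check the trainingOptions reference for your release). The values are illustrative rather than tuned for this data set, and the result is stored in a separate variable so that it does not overwrite the options used for training.

```matlab
% Illustrative L-BFGS stopping criteria (values are not tuned for this data).
optionsLBFGS = trainingOptions("lbfgs", ...
    ExecutionEnvironment="cpu", ...
    MaxIterations=500, ...        % stop after at most 500 iterations
    GradientTolerance=1e-6, ...   % stop when the gradient measure falls below this value
    StepTolerance=1e-6, ...       % stop when the step size falls below this value
    Verbose=false);
```

Smaller tolerances let the solver keep refining the loss longer, at the cost of more iterations.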

Test Neural Network

Predict the labels of the test data using the trained network. First, predict the classification scores using the predict function, then convert the scores to labels using the onehotdecode function.

scoresTest = predict(net,XTest);
YTest = onehotdecode(scoresTest,classNames,2);
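To illustrate the decoding step, the sketch below applies onehotdecode to a hypothetical score matrix whose values are invented for illustration. Each row holds the scores of one observation, and onehotdecode selects, for each row, the class with the highest score. This is equivalent to taking the index of the row maximum.

```matlab
% Hypothetical scores: one row per observation, one column per class.
scores = [0.9 0.1; 0.2 0.8; 0.6 0.4];
toyClassNames = {'No Tooth Fault'; 'Tooth Fault'};

% Decode along dimension 2 (the class dimension) to get categorical labels.
toyY = onehotdecode(scores,toyClassNames,2)

% Equivalent argmax formulation for comparison.
[~,idxMax] = max(scores,[],2);
toyClassNames(idxMax)
```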

Visualize the predictions in a confusion chart.

confusionchart(TTest,YTest)


Calculate the classification accuracy. The accuracy is the proportion of the labels that the network predicts correctly.

accuracy = mean(YTest == TTest)
accuracy = 0.8182
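A single accuracy value can hide differences between classes. The following sketch uses hypothetical labels to compute the overall accuracy the same way as above, and then the accuracy within each true class (the per-class recall). Per-class recall corresponds to the diagonal of the confusion chart when each row, which represents a true class, is normalized to sum to 1.

```matlab
% Hypothetical true and predicted labels for five observations.
toyT = categorical(["A";"A";"B";"B";"B"]);
toyYPred = categorical(["A";"B";"B";"B";"A"]);

% Overall accuracy: fraction of matching labels (3 of 5 here).
toyAccuracy = mean(toyYPred == toyT)

% Accuracy within each true class (per-class recall).
toyClasses = categories(toyT);
perClass = zeros(numel(toyClasses),1);
for k = 1:numel(toyClasses)
    inClass = toyT == toyClasses{k};
    perClass(k) = mean(toyYPred(inClass) == toyT(inClass));
end
perClass   % 0.5 for class A, 2/3 for class B
```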
