
Assess Regression Neural Network Performance

Create a feedforward regression neural network model with fully connected layers using fitrnet. Use validation data for early stopping of the training process to prevent overfitting the model. Then, use the object functions of the model to assess its performance on test data.

Load Sample Data

Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.

load carbig

Convert the Origin variable to a categorical variable. Then create a table containing the predictor variables Acceleration, Displacement, and so on, as well as the response variable MPG. Each row contains the measurements for a single car.

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);
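Before partitioning, it can help to check how many values are missing in each variable and how many cars come from each country, because both matter later in this example. This sketch assumes that Tbl from the previous step is in the workspace. The name missingPerVariable is illustrative.

```matlab
% Count missing entries per variable. ismissing flags NaN in numeric
% variables and <undefined> elements in categorical variables.
missingPerVariable = array2table(sum(ismissing(Tbl),1), ...
    "VariableNames",Tbl.Properties.VariableNames)

% Count the cars from each country of origin
summary(Tbl.Origin)
```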

Partition Data

Split the data into training, validation, and test sets. First, reserve approximately one third of the observations for the test set. Then, split the remaining data in half to create the training and validation sets.

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);
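Here is a quick sanity check on the partitions, assuming the variables above are in the workspace. Together, the three sets should contain every row of Tbl, and each set should hold roughly one third of the observations. TestSize is a property of the cvpartition objects. The names numObs and heldOut are illustrative.

```matlab
% Number of observations in the training, validation, and test sets
numObs = [height(trainTbl) height(validationTbl) height(testTbl)]

% Test-set sizes recorded by the two cvpartition objects
heldOut = [cvp1.TestSize cvp2.TestSize]
```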

Train Neural Network

Train a regression neural network model by using the training set. Specify the MPG column of trainTbl as the response variable, and standardize the numeric predictors. Evaluate the model on the validation set at each iteration. Display the training information at each iteration by using the Verbose name-value argument. By default, the training process ends early if the validation loss is greater than or equal to the minimum validation loss computed so far, six times in a row. To change the number of times the validation loss is allowed to be greater than or equal to the minimum, specify the ValidationPatience name-value argument.

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|   71.063537|   22.623354|    6.466959|    0.001272|   72.648960|           0|
|           2|   48.608700|   22.384995|    1.022929|    0.001808|   43.435698|           0|
|           3|   30.584887|   13.433471|    0.537190|    0.000903|   29.134447|           0|
|           4|   17.781636|   11.159801|    1.401355|    0.000461|   16.542207|           0|
|           5|   13.075804|    4.605991|    0.419875|    0.000387|   12.946670|           0|
|           6|   11.697936|    3.197944|    0.226945|    0.000543|   12.025502|           0|
|           7|    9.494801|    2.269831|    0.751711|    0.000452|   12.596499|           1|
|           8|    8.390979|    1.970589|    0.337301|    0.000398|   11.490990|           0|
|           9|    6.853097|    1.029078|    0.866974|    0.000378|    9.449945|           0|
|          10|    6.531678|    0.924820|    0.306913|    0.000429|    9.350721|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    6.152995|    1.872684|    0.457744|    0.000403|    9.223829|           0|
|          12|    5.924852|    0.718386|    0.447879|    0.000402|    9.656166|           1|
|          13|    5.792836|    0.500170|    0.216351|    0.000387|    9.733226|           2|
|          14|    5.613473|    1.151197|    0.316828|    0.000531|    9.788646|           3|
|          15|    5.415889|    1.513493|    0.327937|    0.000485|    9.607953|           4|
|          16|    5.008195|    1.398069|    1.085660|    0.000430|    9.251971|           5|
|          17|    5.004176|    2.070041|    0.890201|    0.000383|    8.719334|           0|
|          18|    4.738386|    0.483667|    0.338897|    0.000374|    8.523728|           0|
|          19|    4.680213|    0.437918|    0.107667|    0.000371|    8.369271|           0|
|          20|    4.587350|    0.510639|    0.146276|    0.000385|    8.100236|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    4.479929|    0.565635|    0.228198|    0.000381|    8.062927|           0|
|          22|    4.380618|    0.892717|    0.377776|    0.000554|    7.843234|           0|
|          23|    4.189344|    0.403227|    0.362307|    0.000434|    7.834582|           0|
|          24|    4.182775|    1.150234|    1.908768|    0.000408|    9.436226|           1|
|          25|    3.985939|    0.908479|    0.518217|    0.000570|    8.973756|           2|
|          26|    3.873835|    0.826655|    0.477740|    0.000505|    8.863599|           3|
|          27|    3.830830|    0.331936|    0.220000|    0.000539|    8.574682|           4|
|          28|    3.796605|    0.232756|    0.075643|    0.000492|    8.591758|           5|
|          29|    3.706326|    0.470116|    0.249292|    0.000396|    8.517317|           6|
|==========================================================================================|
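ValidationPatience controls how many consecutive validation checks without a new minimum loss are tolerated before training stops. The following sketch trains a second model that tolerates 10 such checks. The value 10 and the name MdlPatient are illustrative choices, not recommendations. fitrnet initializes the network weights randomly, and the random number stream has advanced since the first call. As a result, MdlPatient does not necessarily follow the same training path as Mdl.

```matlab
% Hypothetical variant: tolerate 10 consecutive validation checks instead
% of the default 6 (10 is an arbitrary illustrative value)
MdlPatient = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "ValidationPatience",10);

% Number of iterations recorded in the training history of each model
numIterations = [height(Mdl.TrainingHistory) height(MdlPatient.TrainingHistory)]
```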

Use the information inside the TrainingHistory property of the object Mdl to check the iteration that corresponds to the minimum validation mean squared error (MSE). The final returned model Mdl is the model trained at this iteration.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 23
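The Validation Checks column in the verbose output follows directly from the early-stopping rule. The counter resets to 0 whenever the validation loss reaches a new minimum and otherwise increases by 1, and training stops when the counter reaches the patience value. The following sketch recomputes that counter from valLosses, which is defined in the previous code. numChecks is an illustrative name. For the run shown above, the recomputed counter should match the displayed column and end at 6.

```matlab
% Recompute the "Validation Checks" counter from the validation losses
numChecks = zeros(size(valLosses));
bestSoFar = Inf;   % smallest validation loss seen so far
count = 0;         % consecutive iterations without a new minimum
for k = 1:numel(valLosses)
    if valLosses(k) < bestSoFar
        bestSoFar = valLosses(k);   % new minimum: reset the counter
        count = 0;
    else
        count = count + 1;          % loss >= minimum so far
    end
    numChecks(k) = count;
end
table(iteration,valLosses,numChecks)
```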

Evaluate Test Set Performance

Evaluate the performance of the trained model Mdl on the test set testTbl by using the loss and predict object functions.

Compute the test set mean squared error (MSE). Smaller MSE values indicate better performance.

mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145

Compare the predicted test set response values to the true response values. Plot the predicted miles per gallon (MPG) along the vertical axis and the true MPG along the horizontal axis. Points on the reference line indicate correct predictions. A good model produces predictions that are scattered near the line.

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")
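The MSE reported by loss is the average squared difference between the true and predicted responses. As a cross-check, the following sketch computes this value directly from predictedY and ignores test rows whose MPG value is missing. It also reports the root mean squared error (RMSE), which is in MPG units. If loss uses its default equal observation weights, mseManual should be close to the value from loss. A noticeable difference would suggest that the two computations include different observations.

```matlab
% Prediction errors on the test set (NaN where MPG is missing)
testErr = testTbl.MPG - predictedY;

% Mean squared error and root mean squared error, ignoring NaN errors
mseManual = mean(testErr.^2,"omitnan")
rmse = sqrt(mseManual)
```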

Use box plots to compare the distribution of predicted and true MPG values by country of origin. Create the box plots by using the boxchart function. Each box plot displays the median, the lower and upper quartiles, any outliers (computed using the interquartile range), and the minimum and maximum values that are not outliers. In particular, the line inside each box is the sample median, and the circular markers indicate outliers.

For each country of origin, compare the red box plot (showing the distribution of predicted MPG values) to the blue box plot (showing the distribution of true MPG values). Similar distributions for the predicted and true MPG values indicate good predictions.

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")
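To put numbers on the box plot comparison, the following sketch uses groupsummary to tabulate the median true and predicted MPG for each country of origin. It first drops rows with missing values. The names cmpTbl and medianByOrigin are illustrative.

```matlab
% Pair each true MPG value with its prediction and country of origin
cmpTbl = table(testTbl.Origin,testTbl.MPG,predictedY, ...
    "VariableNames",{'Origin','TrueMPG','PredictedMPG'});
cmpTbl = rmmissing(cmpTbl);   % drop rows with any missing value

% Median true and predicted MPG for each country, with group counts
medianByOrigin = groupsummary(cmpTbl,"Origin","median", ...
    {'TrueMPG','PredictedMPG'})
```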

For most countries, the predicted and true MPG values have similar distributions. However, the neural network model tends to underestimate the MPG values for cars made in France. This discrepancy is possibly due to the small number of French cars in the training and test sets.

Compare the range of MPG values for French cars in the training and test sets.

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
trainSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         16.2         27  
    Germany    Germany        11           20       44.3  
    Italy      Italy           1         37.3       37.3  
    Japan      Japan          24           20       40.8  
    Sweden     Sweden          3           19       21.6  
    USA        USA            94            9         39  

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
testSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         28.1       40.9  
    Germany    Germany        12         21.5         44  
    Italy      Italy           3           28         30  
    Japan      Japan          32           18       46.6  
    Sweden     Sweden          3           17         24  
    USA        USA            82           10       36.1  

In the training set, the MPG values for cars made in France range from 16.2 to 27. However, in the test set, the MPG values for cars made in France range from 28.1 to 40.9.

Plot the test set residuals. A good model usually has residuals scattered roughly symmetrically around 0. Clear patterns in the residuals are a sign that you can improve your model.

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")
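As a rough numeric companion to the residual plot, the following sketch reports the mean, median, and standard deviation of the test residuals, along with the fraction of residuals that are positive. For roughly symmetric scatter around 0, the mean and median should be near 0 and the fraction should be near 0.5. This is an informal check, not a statistical test. The names r, residualSummary, and fractionPositive are illustrative.

```matlab
% Drop NaN residuals (test rows with a missing MPG value)
r = rmmissing(residuals);

% Center and spread of the residuals
residualSummary = [mean(r) median(r) std(r)]

% Fraction of residuals above 0 (near 0.5 for symmetric scatter)
fractionPositive = mean(r > 0)
```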

The plot suggests that some residual values are outliers. Find out more information about the observation with the greatest residual.

[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.3             85            NaN            80        France     1835     40.9

The observation corresponds to a car whose Horsepower value is missing and whose country of origin is France, a category with few observations.
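To see whether missing predictor values also accompany other large errors, the following sketch lists the five test observations with the largest absolute residuals and flags those with at least one missing predictor value. It assumes that residuals and testTbl from the previous steps are in the workspace. The variable names are illustrative.

```matlab
% Flag test rows with at least one missing predictor value
predictorNames = ["Acceleration","Displacement","Horsepower", ...
    "Model_Year","Origin","Weight"];
hasMissingPredictor = any(ismissing(testTbl(:,predictorNames)),2);

% Order the test rows by absolute residual, largest first, NaN last
[~,order] = sort(abs(residuals),"descend","MissingPlacement","last");
top5 = order(1:5);

largestResiduals = table(top5,residuals(top5),testTbl.Origin(top5), ...
    hasMissingPredictor(top5),"VariableNames", ...
    {'TestRow','Residual','Origin','MissingPredictor'})
```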
