
RegressionPartitionedNeuralNetwork

Cross-validated regression neural network model

Since R2023b

    Description

    RegressionPartitionedNeuralNetwork is a set of regression neural network models trained on cross-validated folds. Estimate the quality of the cross-validated regression by using one or more kfold functions: kfoldPredict, kfoldLoss, and kfoldfun.

    Every kfold object function uses models trained on training-fold (in-fold) observations to predict the response for validation-fold (out-of-fold) observations. For example, suppose you cross-validate using five folds. The software randomly assigns each observation to one of five groups of roughly equal size. The training fold contains four of the groups (roughly 4/5 of the data), and the validation fold contains the other group (roughly 1/5 of the data). In this case, cross-validation proceeds as follows:

    1. The software trains the first model (stored in CVMdl.Trained{1}) by using the observations in the last four groups, and reserves the observations in the first group for validation.

    2. The software trains the second model (stored in CVMdl.Trained{2}) by using the observations in the first group and the last three groups. The software reserves the observations in the second group for validation.

    3. The software proceeds in a similar manner for the third, fourth, and fifth models.

    If you validate by using kfoldPredict, the software computes predictions for the observations in group i by using model i. In short, the software estimates a response for every observation by using the model trained without that observation.
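
    The fold-by-fold procedure described above can be reproduced by hand. The following is a minimal sketch, assuming the carbig sample data and a five-fold partition; the variable names (X, idx, yHatManual, maxDiff) are illustrative only and do not come from the original text.

```matlab
load carbig
X = [Acceleration Displacement Horsepower Weight];

rng("default") % For reproducibility
CVMdl = fitrnet(X,MPG,KFold=5,Standardize=true);

% Predict each validation fold with the model trained without that fold
yHatManual = NaN(CVMdl.NumObservations,1);
for i = 1:CVMdl.KFold
    idx = test(CVMdl.Partition,i);   % logical index of the validation-fold rows
    yHatManual(idx) = predict(CVMdl.Trained{i},CVMdl.X(idx,:));
end

% Compare the manual predictions with the kfoldPredict output
yHat = kfoldPredict(CVMdl);
maxDiff = max(abs(yHat - yHatManual),[],"omitnan")
```

    If the manual loop mirrors what kfoldPredict does, maxDiff should be zero or very close to it. Rows with missing predictor values can produce NaN predictions, which is why the comparison omits NaN values.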

    Creation

    You can create a RegressionPartitionedNeuralNetwork object in two ways:

    • Create a cross-validated model from a regression neural network model object RegressionNeuralNetwork by using the crossval object function.

    • Create a cross-validated model by using the fitrnet function and specifying one of the name-value arguments CrossVal, CVPartition, Holdout, KFold, or Leaveout.
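
    As a rough illustration of the two creation routes, the sketch below builds one cross-validated model each way. It assumes the carbig sample data and five folds; the names Mdl, CVMdl1, and CVMdl2 are placeholders chosen for this example.

```matlab
load carbig
X = [Acceleration Displacement Horsepower Weight];

rng("default") % For reproducibility

% Route 1: train a full model first, then cross-validate it with crossval
Mdl = fitrnet(X,MPG,Standardize=true);
CVMdl1 = crossval(Mdl,KFold=5);

% Route 2: request cross-validation directly in the call to fitrnet
CVMdl2 = fitrnet(X,MPG,Standardize=true,KFold=5);

class(CVMdl1)
class(CVMdl2)
```

    Both calls should return a RegressionPartitionedNeuralNetwork object. The first route is convenient when a trained model already exists; the second avoids training a full model that you might not need.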

    Properties


    Cross-Validation Properties

    This property is read-only.

    CrossValidatedModel: Cross-validated model name, specified as 'NeuralNetwork'.

    Data Types: char

    This property is read-only.

    KFold: Number of cross-validated folds, specified as a positive integer.

    Data Types: double

    This property is read-only.

    ModelParameters: Cross-validation parameter values, specified as an EnsembleParams object. The parameter values correspond to the values of the name-value arguments used to cross-validate the neural network model. ModelParameters does not contain estimated parameters.

    You can access the properties of ModelParameters using dot notation.

    This property is read-only.

    Partition: Data partition indicating how the software splits the data into cross-validation folds, specified as a cvpartition object.

    This property is read-only.

    Trained: Compact models trained on cross-validation folds, specified as a cell array of CompactRegressionNeuralNetwork model objects. Trained has k cells, where k is the number of folds.

    Data Types: cell
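
    The cross-validation properties can be inspected with dot notation. The following sketch assumes the carbig sample data and five folds; the variable names on the left-hand side are illustrative. TrainSize and TestSize are properties of the cvpartition object stored in Partition.

```matlab
load carbig
X = [Acceleration Displacement Horsepower Weight];

rng("default") % For reproducibility
CVMdl = fitrnet(X,MPG,KFold=5,Standardize=true);

numFolds   = CVMdl.KFold                 % number of folds
firstModel = CVMdl.Trained{1}            % compact model trained without fold 1
trainSizes = CVMdl.Partition.TrainSize   % training-set size for each fold
testSizes  = CVMdl.Partition.TestSize    % validation-set size for each fold
```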

    Other Regression Properties

    This property is read-only.

    CategoricalPredictors: Categorical predictor indices, specified as a vector of positive integers. CategoricalPredictors contains index values indicating that the corresponding predictors are categorical. The index values are between 1 and p, where p is the number of predictors used to train the model. If none of the predictors are categorical, then this property is empty ([]).

    Data Types: double

    This property is read-only.

    NumObservations: Number of observations in the training data stored in X and Y, specified as a numeric scalar.

    Data Types: double

    This property is read-only.

    PredictorNames: Predictor variable names, specified as a cell array of character vectors. The order of the elements in PredictorNames corresponds to the order in which the predictor names appear in the training data.

    Data Types: cell

    This property is read-only.

    ResponseName: Response variable name, specified as a character vector.

    Data Types: char

    ResponseTransform: Response transformation function, specified as 'none' or a function handle. ResponseTransform describes how the software transforms raw response values.

    For a MATLAB® function or a function that you define, enter its function handle. For example, you can enter CVMdl.ResponseTransform = @function, where function accepts a numeric vector of the original responses and returns a numeric vector of the same size containing the transformed responses.

    Data Types: char | function_handle
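
    One possible use of a response transformation is to map predictions back to the original scale after training on a transformed response. The sketch below is an assumption-laden illustration, not a documented recipe: it trains on log(MPG) from the carbig sample data (the variable name logMPG is hypothetical) and assumes that kfoldPredict applies ResponseTransform to the raw predictions, as the property description suggests.

```matlab
load carbig
Tbl = table(Acceleration,Displacement,Horsepower,Weight);
Tbl.logMPG = log(MPG);   % hypothetical transformed response for this sketch

rng("default") % For reproducibility
CVMdl = fitrnet(Tbl,"logMPG",KFold=5,Standardize=true);

% Ask the model to exponentiate raw predictions (assumed to affect kfoldPredict)
CVMdl.ResponseTransform = @exp;

yHat = kfoldPredict(CVMdl);
yHat(1:5)
```

    Because the stored responses in Y remain on the log scale in this sketch, compare any loss value computed after setting the transformation against the scale you actually intend.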

    This property is read-only.

    W: Observation weights, specified as an n-by-1 numeric vector, where n is the number of observations (NumObservations). The software normalizes the observation weights so that the elements of W sum to 1.

    Data Types: double

    This property is read-only.

    X: Unstandardized predictors used to cross-validate the model, specified as a numeric matrix or table. X retains its original orientation, with observations in rows or columns depending on the value of the ObservationsIn name-value argument in the call to fitrnet.

    Data Types: single | double | table

    This property is read-only.

    Y: Response values used to cross-validate the model, specified as a numeric vector. Each row of Y represents the response value of the corresponding observation in X.

    Data Types: single | double

    Object Functions

    gather          Gather properties of Statistics and Machine Learning Toolbox object from GPU
    kfoldLoss       Loss for cross-validated partitioned regression model
    kfoldPredict    Predict responses for observations in cross-validated regression model
    kfoldfun        Cross-validate function for regression
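
    As an illustration of kfoldfun, the sketch below computes a custom per-fold metric (mean absolute error) that kfoldLoss does not return by default. It assumes the carbig sample data and five folds; maeFun and foldMAE are names made up for this example. The function handle follows the seven-argument form that kfoldfun expects: the compact model, then the training and test predictors, responses, and weights.

```matlab
load carbig
X = [Acceleration Displacement Horsepower Weight];

rng("default") % For reproducibility
CVMdl = fitrnet(X,MPG,KFold=5,Standardize=true);

% Mean absolute error on the validation fold; NaN predictions are ignored
maeFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(abs(Ytest - predict(CMP,Xtest)),"omitnan");

% One value per fold
foldMAE = kfoldfun(CVMdl,maeFun)
```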

    Examples


    Train a cross-validated regression neural network with 10 folds, which is the default cross-validation option, by using fitrnet. Then, use kfoldPredict to predict responses for validation-fold observations using a model trained on training-fold observations.

    Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.

    load carbig

    Create a table that contains the predictor variables (Acceleration, Displacement, Horsepower, and Weight) and the response variable (MPG).

    Tbl = table(Acceleration,Displacement,Horsepower,Weight,MPG);

    Create a cross-validated regression neural network by using the default cross-validation option. Specify the CrossVal value as "on". For a better model fit, standardize the numeric predictors.

    rng("default") % For reproducibility
    CVMdl = fitrnet(Tbl,"MPG", ...
        CrossVal="on",Standardize=true)
    CVMdl = 
      RegressionPartitionedNeuralNetwork
        CrossValidatedModel: 'NeuralNetwork'
             PredictorNames: {'Acceleration'  'Displacement'  'Horsepower'  'Weight'}
               ResponseName: 'MPG'
            NumObservations: 398
                      KFold: 10
                  Partition: [1x1 cvpartition]
          ResponseTransform: 'none'
    
    
    

    The fitrnet function creates a RegressionPartitionedNeuralNetwork model object CVMdl with 10 folds. During cross-validation, the software completes these steps:

    1. Randomly partition the data into 10 sets.

    2. For each set, reserve the set as validation data, and train the model using the other 9 sets.

    3. Store the 10 compact trained models in a 10-by-1 cell vector in the Trained property of the cross-validated model object.

    You can override the default cross-validation setting by using the CVPartition, Holdout, KFold, or Leaveout name-value argument.
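
    For instance, a custom partition or a holdout split could be requested as in the sketch below. It assumes the same carbig table as in this example, with rows containing missing values removed first so that the partition size matches the data; the names c, CVMdl5, and CVHoldout are illustrative.

```matlab
load carbig
Tbl = table(Acceleration,Displacement,Horsepower,Weight,MPG);
Tbl = rmmissing(Tbl);   % drop rows with missing values before partitioning

rng("default") % For reproducibility

% Five-fold cross-validation with an explicit cvpartition object
c = cvpartition(height(Tbl),"KFold",5);
CVMdl5 = fitrnet(Tbl,"MPG",CVPartition=c,Standardize=true);

% Holdout validation that reserves 20% of the observations
CVHoldout = fitrnet(Tbl,"MPG",Holdout=0.2,Standardize=true);

CVMdl5.KFold
```

    Passing an explicit cvpartition object is useful when several models must be compared on identical folds.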

    Predict responses for the observations in Tbl by using kfoldPredict. The function predicts responses for every observation using the model trained without that observation.

    yHat = kfoldPredict(CVMdl);

    yHat is a numeric vector. Display the first five predicted responses.

    yHat(1:5)
    ans = 5×1
    
       17.2245
       15.3107
       15.8676
       15.2423
       16.3009
    
    

    Compute the regression loss (mean squared error).

    L = kfoldLoss(CVMdl)
    L = 
    16.3743
    

    kfoldLoss returns the average mean squared error over 10 folds.
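
    To see how the loss varies across folds, you can request one value per fold instead of the average. The sketch below repeats the model setup from this example so that it runs on its own; foldMSE, avgMSE, and rmse are illustrative names. Taking the square root of the average loss gives a root mean squared error in the units of the response (miles per gallon here).

```matlab
load carbig
Tbl = table(Acceleration,Displacement,Horsepower,Weight,MPG);

rng("default") % For reproducibility
CVMdl = fitrnet(Tbl,"MPG",CrossVal="on",Standardize=true);

% Mean squared error for each of the 10 folds
foldMSE = kfoldLoss(CVMdl,Mode="individual")

% Average loss and its square root (RMSE)
avgMSE = kfoldLoss(CVMdl);
rmse = sqrt(avgMSE)
```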

    Extended Capabilities

    Version History

    Introduced in R2023b
