kfoldfun
Cross-validate function for classification
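Description
vals = kfoldfun(CVMdl,fun) cross-validates the function fun by applying fun to the data stored in the cross-validated model CVMdl. You must pass fun as a function handle.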
Examples
Estimate Classification Loss Using Custom Loss Function
Train a classification tree, and then cross-validate it using a custom k-fold loss function.
Load Fisher’s iris data set.
load fisheriris
Train a classification tree.
Mdl = fitctree(meas,species);
Mdl is a ClassificationTree model.
Cross-validate Mdl using the default 10-fold cross-validation. Compute the classification error (proportion of misclassified observations) for the validation-fold observations.
rng(1); % For reproducibility
CVMdl = crossval(Mdl);
L = kfoldLoss(CVMdl,'LossFun','classiferror')
L = 0.0467
Examine the result when the cost of misclassifying a flower as versicolor is 10 and the cost of any other misclassification is 1. To do so, create the custom function noversicolor (shown at the end of this example), which assigns a cost of 10 to each flower misclassified as versicolor and a cost of 1 to any other misclassification.
Compute the mean misclassification cost over the folds by using noversicolor.
mean(kfoldfun(CVMdl,@noversicolor))
ans = 0.2267
This code creates the function noversicolor.
function averageCost = noversicolor(CMP,~,~,~,Xtest,Ytest,~)
% noversicolor Example custom cross-validation function
% Attributes a cost of 10 for misclassifying an iris as versicolor, and 1
% for any other misclassification. This example function requires the
% fisheriris data set.
Ypredict = predict(CMP,Xtest);
misclassified = not(strcmp(Ypredict,Ytest)); % Wrong predictions
classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % Predicted as versicolor
cost = sum(misclassified) + ...
    9*sum(misclassified & classifiedAsVersicolor); % 1 per error, plus 9 per error labeled versicolor
averageCost = cost/numel(Ytest); % Average cost per test observation
end
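The noversicolor function ignores the observation weights. As a hedged sketch, the same cost can be weighted by the Wtest input that kfoldfun passes to fun. It is written with anonymous functions so that it runs as a plain script. The names obsCost and weightedCost are hypothetical and used only for illustration. When all test weights are equal, the weighted average reduces to the same per-fold value that noversicolor computes.

```matlab
% Sketch: observation-weighted version of the noversicolor cost.
% obsCost and weightedCost are illustrative names, not toolbox functions.
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation

% Per-observation cost: 0 if correct, 10 if wrongly labeled versicolor, 1 otherwise.
obsCost = @(Ypred,Ytrue) ~strcmp(Ypred,Ytrue) .* (1 + 9*strcmp(Ypred,'versicolor'));

% kfoldfun calls fun as fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest).
weightedCost = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    sum(Wtest .* obsCost(predict(CMP,Xtest),Ytest)) / sum(Wtest);

foldCosts = kfoldfun(CVMdl,weightedCost);  % One scalar per fold, stacked vertically
meanCost = mean(foldCosts)
```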
Input Arguments
CVMdl
— Cross-validated model
ClassificationPartitionedModel object | ClassificationPartitionedEnsemble object | ClassificationPartitionedGAM object
Cross-validated model, specified as a ClassificationPartitionedModel object, ClassificationPartitionedEnsemble object, or ClassificationPartitionedGAM object.
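As a sketch of passing a ClassificationPartitionedEnsemble object, the code below cross-validates an ensemble directly with the 'KFold' name-value argument of fitcensemble and then applies a simple per-fold misclassification rate. The variable names CVEns and foldError are illustrative assumptions, and the choice of 5 folds is arbitrary.

```matlab
% Sketch: kfoldfun with a cross-validated ensemble (ClassificationPartitionedEnsemble).
load fisheriris
rng(1);                                        % For reproducibility
CVEns = fitcensemble(meas,species,'KFold',5);  % 5-fold cross-validated ensemble

% Proportion of misclassified observations in one validation fold.
foldError = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));

vals = kfoldfun(CVEns,foldError)               % 5-by-1: one row per fold
```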
fun
— Cross-validated function
function handle
Cross-validated function, specified as a function handle. fun has the syntax:
testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
- CMP is a compact model stored in one element of the CVMdl.Trained property.
- Xtrain is the training matrix of predictor values.
- Ytrain is the training array of response values.
- Wtrain is the vector of training observation weights.
- Xtest and Ytest are the test data, with associated weights Wtest.
- The returned value testvals must have the same size across all folds.
Data Types: function_handle
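To show the role of each input, here is a hedged sketch of a hypothetical inspection function, foldInfo, that reports how many observations each fold trains on and validates on. It returns a 1-by-6 row vector in every fold, which satisfies the same-size requirement on testvals. With the 150 iris observations and 10 folds, each row should show roughly 135 training and 15 test observations.

```matlab
% Sketch: a fun that touches every input kfoldfun passes (foldInfo is illustrative).
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation

% Columns: training rows, training labels, training weights,
%          test rows, test labels, test weights. CMP is not used here.
foldInfo = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    [size(Xtrain,1), numel(Ytrain), numel(Wtrain), ...
     size(Xtest,1),  numel(Ytest),  numel(Wtest)];

info = kfoldfun(CVMdl,foldInfo)            % 10-by-6: one row per fold
```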
Output Arguments
vals
— Cross-validation results
numeric matrix
Cross-validation results, returned as a numeric matrix. vals contains the testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a numeric row vector of length N, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.
Data Types: double
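As a sketch of the KFold-by-N shape, the hypothetical function perClassErr below returns a 1-by-3 row vector per fold: the misclassification rate within each true iris class. Because kfoldfun stacks these rows vertically, the result has one row per fold and one column per class. The helper names classes, classErr, and perClassErr are illustrative; a class that is absent from a fold would produce NaN in that entry.

```matlab
% Sketch: a row-vector testvals produces a KFold-by-N result.
load fisheriris
rng(1);                                    % For reproducibility
CVMdl = crossval(fitctree(meas,species));  % Default 10-fold cross-validation
classes = unique(species)';                % 1-by-3 cell array of class names

% Error rate among test observations whose true class is classes{k}.
classErr = @(Ypred,Ytrue) arrayfun(@(k) ...
    mean(~strcmp(Ypred(strcmp(Ytrue,classes{k})),classes{k})), 1:numel(classes));

perClassErr = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    classErr(predict(CMP,Xtest),Ytest);

vals = kfoldfun(CVMdl,perClassErr)         % 10-by-3: rows are folds, columns are classes
```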
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
- fun must accept gpuArray inputs.
- This function fully supports GPU arrays for the following cross-validated model objects:
  - Ensemble classifier trained with fitcensemble
  - k-nearest neighbor classifier trained with fitcknn
  - Support vector machine classifier trained with fitcsvm
  - Binary decision tree for multiclass classification trained with fitctree
  - Neural network for classification trained with fitcnet
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
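The following is a hedged sketch of the GPU workflow for a neural network model, assuming Parallel Computing Toolbox, a supported GPU, and R2024b or later (the release that added GPU array support for fitcnet models). The function handle foldError is an illustrative name. Only predict operates on the GPU-resident Xtest here; gather returns the results to host memory and leaves them unchanged if they are already there.

```matlab
% Sketch only: requires Parallel Computing Toolbox, a supported GPU, and R2024b or later.
load fisheriris
rng(1);                        % For reproducibility
X = gpuArray(meas);            % Predictor data on the GPU
Mdl = fitcnet(X,species);      % Neural network classifier
CVMdl = crossval(Mdl);         % Default 10-fold cross-validation

% fun must accept gpuArray inputs; predict handles the gpuArray Xtest.
foldError = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));

errs = gather(kfoldfun(CVMdl,foldError))   % 10-by-1 per-fold error rates in host memory
```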
Version History
Introduced in R2011a
R2024b: Specify GPU arrays for neural network models (requires Parallel Computing Toolbox)
kfoldfun fully supports GPU arrays for ClassificationPartitionedModel models trained using fitcnet.
See Also
ClassificationPartitionedModel | kfoldPredict | kfoldEdge | kfoldMargin | kfoldLoss