5-fold cross-validation code for a dataset

I want to split my dataset into training and test sets using 5-fold cross-validation.


Hi Subhasmita,
In MATLAB, you can perform k-fold cross-validation to split your dataset into training and test sets. In k-fold cross-validation, the dataset is divided into k subsets (folds) of roughly equal size. The model is trained k times, each time using a different fold as the test set and the remaining k-1 folds as the training set, so every observation is used for testing exactly once.
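To make the fold mechanics concrete, here is a minimal hand-rolled sketch in base MATLAB (no toolbox) that shuffles n observations and deals them into k folds. It assumes n = 150 to match fisheriris; the names perm and foldId are illustrative. It is only an illustration of the idea, not a replacement for cvpartition, which can also stratify the folds by class.

```matlab
% Manual k-fold split in base MATLAB, for illustration only.
% Assumes n observations; perm and foldId are hypothetical names.
rng(0)                          % Reproducible shuffle
n = 150;                        % Number of observations (e.g. size(X, 1))
k = 5;                          % Number of folds

% Shuffle the observation indices, then deal them into folds 1..k
% round-robin so that fold sizes differ by at most one.
perm = randperm(n);
foldId = zeros(n, 1);
foldId(perm) = mod(0:n-1, k) + 1;

for i = 1:k
    testIdx  = (foldId == i);   % Observations held out in fold i
    trainIdx = ~testIdx;        % All remaining observations train the model
    fprintf('Fold %d: %d training, %d test\n', i, nnz(trainIdx), nnz(testIdx));
end
```

With 150 observations and 5 folds, every fold prints 120 training and 30 test observations, and each observation belongs to exactly one test fold.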
Here's how you can perform 5-fold cross-validation in MATLAB:
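```matlab
% Load your dataset
load fisheriris            % Example dataset: 150 observations, 3 classes
X = meas;                  % Features (150-by-4 numeric matrix)
y = species;               % Labels (150-by-1 cell array of character vectors)

% Define the number of folds
k = 5;

% Create a cross-validation partition. Passing the labels y makes
% cvpartition stratify the folds, so each fold keeps roughly the
% same class proportions as the full dataset.
rng('default')             % Make the random fold assignment reproducible
cv = cvpartition(y, 'KFold', k);

% Initialize a variable to store the accuracy of each fold
accuracy = zeros(k, 1);

for i = 1:k
    % Get the logical training and test indices for the current fold
    trainIdx = training(cv, i);
    testIdx = test(cv, i);

    % Split the data into training and test sets for this fold
    XTrain = X(trainIdx, :);
    yTrain = y(trainIdx);
    XTest = X(testIdx, :);
    yTest = y(testIdx);

    % Train the model on the training set. fisheriris has three classes,
    % and fitcsvm supports only one-class and two-class problems, so use
    % fitcecoc, which combines binary SVM learners into a multiclass model.
    model = fitcecoc(XTrain, yTrain);

    % Test the model on the test set
    predictions = predict(model, XTest);

    % Calculate accuracy for the current fold. The labels are cell arrays
    % of character vectors, so compare them with strcmp rather than ==.
    accuracy(i) = sum(strcmp(predictions, yTest)) / numel(yTest);

    % Display accuracy for the current fold
    fprintf('Fold %d Accuracy: %.2f%%\n', i, accuracy(i) * 100);
end

% Calculate and display the average accuracy across all folds
averageAccuracy = mean(accuracy);
fprintf('Average Accuracy: %.2f%%\n', averageAccuracy * 100);
```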
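Explanation
  1. Data Loading: We use the fisheriris dataset for demonstration, where X contains the features and y contains the labels.
  2. Cross-Validation Partition: We create a 5-fold partition using cvpartition with the option 'KFold', k. Because the labels y are passed in, the folds are stratified: each fold keeps roughly the same class proportions as the full dataset.
  3. Loop Through Folds: For each fold, we:
     • Extract the training and test indices.
     • Split the data into training and test sets.
     • Train a support vector machine (SVM) model. Note that fitcsvm supports only one-class and two-class problems; for the three-class fisheriris data, use fitcecoc, which combines binary SVM learners into a multiclass model.
     • Predict on the test set and calculate accuracy.
  4. Accuracy Calculation: We calculate and print the accuracy for each fold and the average accuracy across all folds.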
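When cvpartition is given the labels, as in step 2, it stratifies the folds. A quick way to check this, and to confirm that every observation is tested exactly once, is to count the test observations per class in each fold. This sketch rebuilds the fisheriris partition so it runs on its own (it needs the Statistics and Machine Learning Toolbox); counts and timesTested are illustrative names.

```matlab
% Inspect a stratified 5-fold partition of fisheriris (sketch).
load fisheriris                   % species: 150 labels, 50 per class
y = categorical(species);
rng('default')                    % Reproducible fold assignment
cv = cvpartition(species, 'KFold', 5);

classNames = categories(y);
counts = zeros(cv.NumTestSets, numel(classNames));
timesTested = zeros(numel(y), 1);
for i = 1:cv.NumTestSets
    testIdx = test(cv, i);
    % countcats returns the number of test observations in each class
    counts(i, :) = countcats(y(testIdx))';
    % Track how many times each observation is held out
    timesTested = timesTested + testIdx;
end

disp(array2table(counts, 'VariableNames', classNames'))
fprintf('Every observation tested exactly once: %d\n', all(timesTested == 1));
```

Each row of the table is one fold; with stratification, the per-class counts in each row should be close to one fifth of that class's 50 observations.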
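Additional Notes
  • You can replace the classifier with any other one that suits your needs, for example fitcknn or fitctree.
  • Check class balance before cross-validating: each class needs enough observations to appear in every test fold, and passing the labels to cvpartition (as in the code) keeps class proportions similar across folds.
  • You might also want to explore MATLAB's crossval function, or the 'KFold' and 'CVPartition' options of fitting functions such as fitcecoc together with kfoldLoss; both automate this loop.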
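As the notes mention, the toolbox can run the loop for you. The sketch below shows two routes on the same fisheriris partition: fitcecoc with the 'CVPartition' option followed by kfoldLoss, and crossval with a custom prediction function. Using fitcknn in the second route is only an example of swapping in another classifier, and accuracy is taken as 1 minus the misclassification rate.

```matlab
% Let the toolbox run 5-fold cross-validation (sketch, fisheriris data).
load fisheriris
X = meas;
y = species;
rng('default')                    % Reproducible fold assignment
cv = cvpartition(y, 'KFold', 5);

% Route 1: cross-validate an ECOC (multiclass SVM) model directly.
% fitcecoc with 'CVPartition' returns a partitioned model, and kfoldLoss
% returns its misclassification rate over the folds.
cvModel = fitcecoc(X, y, 'CVPartition', cv);
accEcoc = 1 - kfoldLoss(cvModel);

% Route 2: crossval with a prediction function, which makes it easy to
% swap classifiers. fitcknn is only an example choice here.
predfun = @(XTrain, yTrain, XTest) predict(fitcknn(XTrain, yTrain), XTest);
mcr = crossval('mcr', X, y, 'Predfun', predfun, 'Partition', cv);
accKnn = 1 - mcr;

fprintf('ECOC accuracy: %.2f%%\n', accEcoc * 100);
fprintf('k-NN accuracy: %.2f%%\n', accKnn * 100);
```

Reusing the same cv object in both routes means the two classifiers are evaluated on identical folds, which keeps the comparison fair.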

Asked: 24 May 2018

Answered: 4 Sep 2024
