How to use KNN to classify data in MATLAB?

I'm having trouble understanding how K-NN classification works in MATLAB. Here's the problem: I have a large dataset (65 features for over 1500 subjects) and the corresponding class labels (0 or 1). As it was explained to me, I have to divide the data into training, test and validation subsets to perform supervised training on the data, and then classify it via K-NN. First of all, what's the best ratio for dividing the data into the 3 subsets (1/3 of the dataset each?)
I've looked into the ClassificationKNN class and the fitcknn function, as well as the crossval function (ideally to divide the data), but I'm really not sure how to use them.
To sum up, I want to:
- divide the data into 3 groups
- "train" the KNN classifier with the training subset (I know KNN doesn't really require training, but I mean the equivalent step)
- classify the test subset and get its classification error/performance
- understand the point of having a validation set
I hope you can help me. Thank you in advance!

1 Comment

Hi, could anyone please help me create an open-circuit fault in an inverter using Simulink? (A tutorial would be much appreciated.)


Answers (1)

You need a validation set if you want to tune certain parameters of the classifier. For example, if you were to use an SVM with an RBF kernel, you could choose the kernel parameters using the validation set; for KNN, the corresponding parameters are things like the number of neighbors and the distance metric. A model trained on the training data is then evaluated on the test data to see how it performs on unseen data. If these concepts are not clear, I recommend reading some literature on this topic before proceeding.
The approach you describe is called holdout validation. There are other ways to do cross-validation (for example k-fold or leave-one-out). For holdout, how you split the data is up to you and, of course, depends on the nature of the dataset.
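As a sketch of one alternative to a single holdout split, k-fold cross-validation can be run directly through fitcknn's 'KFold' option and scored with kfoldLoss. The random data, 5 folds and k = 5 neighbors are assumptions chosen only for illustration.

```matlab
% Hypothetical random data: 1000 observations (rows) x 60 features (columns),
% with binary 0/1 labels, roughly mirroring the sizes in the question.
rng(0);                                   % reproducible random numbers
X = randi(100, 1000, 60);
Y = randi([0 1], 1000, 1);

% Build a 5-fold cross-validated KNN model with 5 neighbors: each fold is
% held out once while a model trained on the other 4 folds predicts it.
cvmdl = fitcknn(X, Y, 'NumNeighbors', 5, 'KFold', 5);

% Misclassification rate over the held-out folds.
cvError = kfoldLoss(cvmdl, 'LossFun', 'classiferror');
fprintf('5-fold CV error: %.3f\n', cvError);
```

With random labels the error should hover around 0.5; on real data this number is an estimate of how the chosen k generalizes, without setting aside a fixed validation subset.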
Take a look at how you can use cvpartition to try different cross-validation techniques.
All of these functions have documentation examples you can run directly.
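One possible way to get the three subsets with cvpartition is two nested holdout splits. This is a sketch under assumptions: the 80/10/10 ratio and the random data are illustrative, and passing the label vector (rather than a count) makes cvpartition stratify each split by class.

```matlab
rng(0);
n = 1000;
X = randi(100, n, 60);          % hypothetical data, one observation per row
Y = randi([0 1], n, 1);         % hypothetical 0/1 labels

% First split: hold out about 10% of the observations as the final test set.
c1 = cvpartition(Y, 'HoldOut', 0.10);
testRows = find(test(c1));
restRows = find(training(c1));  % remaining ~90%

% Second split: hold out 1/9 of the remaining rows (~10% of the total) as
% the validation set; the rest (~80% of the total) is the training set.
c2 = cvpartition(Y(restRows), 'HoldOut', 1/9);
valRows   = restRows(test(c2));
trainRows = restRows(training(c2));

Xtrain = X(trainRows, :);  Ytrain = Y(trainRows);
Xval   = X(valRows, :);    Yval   = Y(valRows);
Xtest  = X(testRows, :);   Ytest  = Y(testRows);
```

Because the splits are nested, every observation ends up in exactly one of the three subsets.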

4 Comments

Thank you for your answer. Now I understand the purpose of the 3 sets and how to obtain them. I've read and tried the examples (I wouldn't bother anyone without trying first). I'm still not quite able to classify via KNN, nor am I sure how I can adapt any of its parameters based on the validation results. If you or anyone could help me with this (an example would be fantastic), I would be very grateful, but thank you already for your collaboration.
It is not clear what you need help with. Can you provide examples of what you have tried and what is not working?
I'm still not quite able to classify via KNN
Take a look at the example here on how to classify:
more examples:
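A minimal sketch of the basic classify-and-score workflow with fitcknn and predict follows. The data are random placeholders, and k = 5 and 'Standardize' (which rescales each feature so no single feature dominates the distance) are assumptions, not settings taken from the question.

```matlab
rng(0);
X = randi(100, 1000, 60);          % hypothetical data, one row per subject
Y = randi([0 1], 1000, 1);         % hypothetical 0/1 labels

c = cvpartition(Y, 'HoldOut', 0.2);          % stratified 80/20 split
Xtrain = X(training(c), :);  Ytrain = Y(training(c));
Xtest  = X(test(c), :);      Ytest  = Y(test(c));

% "Training" a KNN model essentially stores the training data and settings.
mdl = fitcknn(Xtrain, Ytrain, 'NumNeighbors', 5, 'Standardize', true);

% Predict each test observation's class from its 5 nearest training rows.
Ypred = predict(mdl, Xtest);

testError = mean(Ypred ~= Ytest);  % fraction of misclassified test rows
fprintf('Test error: %.3f\n', testError);
```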
nor am I sure how I can adapt any of its parameters based on the validation results
What parameters do you want to tune? If you want to tune the number of neighbors (for whatever reason), run a grid search over the number of neighbors: each time, test the model's performance on the validation set, then choose the number that gave you the best results (lowest error). A more sophisticated way of doing this is to run an optimization. Here is an example of just that for SVM; depending on what you want to do with fitcknn, you may apply similar techniques.
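The grid search described above might look like the following sketch. Assumptions: random placeholder data, an 80/10/10 stratified split, and a grid of odd k values (odd k avoids ties between two classes). After choosing k, this version refits on training plus validation data, which is one common choice rather than the only one.

```matlab
rng(0);
n = 1000;
X = randi(100, n, 60);          % hypothetical data
Y = randi([0 1], n, 1);

% Stratified 80/10/10 split via two nested holdouts.
c1 = cvpartition(Y, 'HoldOut', 0.10);
rest = find(training(c1));
testRows = find(test(c1));
c2 = cvpartition(Y(rest), 'HoldOut', 1/9);
trainRows = rest(training(c2));
valRows   = rest(test(c2));

kValues = 1:2:49;               % hypothetical grid of candidate k
valError = zeros(size(kValues));
for i = 1:numel(kValues)
    mdl = fitcknn(X(trainRows,:), Y(trainRows), 'NumNeighbors', kValues(i));
    valError(i) = loss(mdl, X(valRows,:), Y(valRows), 'LossFun', 'classiferror');
end

[~, best] = min(valError);      % first k with the lowest validation error
kBest = kValues(best);

% Refit with the chosen k and report the error once on the untouched test set.
fitRows = [trainRows(:); valRows(:)];
finalMdl = fitcknn(X(fitRows,:), Y(fitRows), 'NumNeighbors', kBest);
testError = loss(finalMdl, X(testRows,:), Y(testRows), 'LossFun', 'classiferror');
```

For the optimization route, fitcknn also accepts 'OptimizeHyperparameters','auto', which searches over NumNeighbors and Distance using cross-validation instead of a fixed validation set.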
I think I was able to do it, but, if that's not asking too much, could you check whether I missed something? Here is my code, run on random data:
nfeats = 60; ninds = 1000;
trainRatio = 0.8; valRatio = 0.1; testRatio = 0.1;
kmax = 100;                          % largest k to try, for instance
data = randi(100, nfeats, ninds);    % random features, one column per subject
labels = randi(2, 1, ninds);         % random classes 1/2 (avoid "class", a built-in function)
% dividerand (Deep Learning Toolbox) returns random indices for the 3 subsets
[trainInd, valInd, testInd] = dividerand(ninds, trainRatio, valRatio, testRatio);
% fitcknn and predict expect one observation per row, so transpose here
Xtrain = data(:, trainInd)';  Ytrain = labels(trainInd)';
Xval   = data(:, valInd)';    Yval   = labels(valInd)';
Xtest  = data(:, testInd)';   Ytest  = labels(testInd)';
bestAccuracy = 0;
koptimal = 0;
for k = 1:kmax
    % knnclassify has been removed from recent releases; fitcknn + predict replaces it
    mdl = fitcknn(Xtrain, Ytrain, 'NumNeighbors', k);
    label = predict(mdl, Xval);
    accuracy = mean(label == Yval);  % fraction of validation subjects classified correctly
    if accuracy > bestAccuracy
        bestAccuracy = accuracy;
        koptimal = k;
    end
end
% Refit with the selected k (not the last loop value) and evaluate on the test set
mdl_final = fitcknn(Xtrain, Ytrain, 'NumNeighbors', koptimal);
label_final = predict(mdl_final, Xtest);
accuracy_final = mean(label_final == Ytest);
Thank you very much for all your help, Shashank Prasanna.
I don't know if there's a better way of comparing the true test_label and the predicted label...
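Besides the fraction of matching labels, a confusion matrix shows which class the errors come from. A small sketch, assuming random placeholder data and k = 5: mean(Ypred == Ytest) gives accuracy, confusionmat counts true classes (rows) against predicted classes (columns), and per-class recall is the diagonal divided by the row totals.

```matlab
rng(0);
X = randi(100, 1000, 60);            % hypothetical data
Y = randi([0 1], 1000, 1);           % hypothetical 0/1 labels
c = cvpartition(Y, 'HoldOut', 0.1);
mdl = fitcknn(X(training(c),:), Y(training(c)), 'NumNeighbors', 5);

Ytest = Y(test(c));
Ypred = predict(mdl, X(test(c),:));

accuracy = mean(Ypred == Ytest);              % fraction correct
[C, order] = confusionmat(Ytest, Ypred);      % rows: true, columns: predicted
recallPerClass = diag(C) ./ sum(C, 2);        % correct / total for each true class
disp(array2table(C, 'RowNames', string(order), 'VariableNames', string(order)));
```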


Asked: 11 Jul 2014

Last commented: 2 Dec 2022
