Lasso Regularization
This example demonstrates the use of lasso for feature selection by identifying predictors of diabetes disease progression in a population. The dataset contains 10 predictors; the goal is to identify the important ones and discard the rest.
% Download the diabetes dataset
filename = 'diabetes.txt';
urlwrite('http://www.stanford.edu/~hastie/Papers/LARS/diabetes.data',filename);
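urlwrite still works but is an older interface; on releases from R2014b onward, websave performs the same download. A minimal equivalent sketch:

% websave downloads the file and returns its full path (R2014b+)
filename = websave('diabetes.txt','http://www.stanford.edu/~hastie/Papers/LARS/diabetes.data');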
Once the file is saved, you can import the data into MATLAB as a table using the Import Tool with the default options. Alternatively, you can use the following code, which can be auto-generated from the Import Tool:
formatSpec = '%f%f%f%f%f%f%f%f%f%f%f%[^\n\r]';
fileID = fopen(filename,'r');
dataArray = textscan(fileID, formatSpec, 'Delimiter', '\t', 'HeaderLines', 1, 'ReturnOnError', false);
fclose(fileID);
diabetes = table(dataArray{1:end-1}, 'VariableNames', {'AGE','SEX','BMI','BP','S1','S2','S3','S4','S5','S6','Y'});
clearvars filename formatSpec fileID dataArray
% Delete the downloaded file
delete diabetes.txt
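As an alternative to the textscan block above (run it in place of that block, while diabetes.txt still exists), readtable performs the whole import in one call. A sketch, assuming the file's tab-separated header row supplies the variable names:

% readtable infers the variable names from the header row (R2013b+)
diabetes = readtable(filename,'Delimiter','\t');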
predNames = diabetes.Properties.VariableNames(1:end-1);
X = diabetes{:,1:end-1};   % predictor matrix
y = diabetes{:,end};       % response
[beta, FitInfo] = lasso(X,y,'Standardize',true,'CV',10,'PredictorNames',predNames);
lassoPlot(beta,FitInfo,'PlotType','Lambda','XScale','log');
hlplot = get(gca,'Children');
% Generate a distinct color for each coefficient trace in the plot
colors = hsv(numel(hlplot));
for ii = 1:numel(hlplot)
    set(hlplot(ii),'Color',colors(ii,:));
end
set(hlplot,'LineWidth',2)
set(gcf,'Units','Normalized','Position',[0.2 0.4 0.5 0.35])
legend('Location','Best')
Larger values of lambda appear on the left side of the graph, corresponding to stronger regularization. As lambda increases, more coefficients are shrunk exactly to zero, so the number of nonzero predictors decreases.
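You can confirm this numerically: the DF field of FitInfo records the number of nonzero coefficients at each lambda. A quick check along the regularization path:

% FitInfo.DF counts nonzero coefficients at each lambda;
% FitInfo.Lambda is returned in ascending order.
nnzPath = table(FitInfo.Lambda', FitInfo.DF', ...
    'VariableNames', {'Lambda','NumNonzero'});
disp(nnzPath([1 round(end/2) end],:))   % dense at small lambda, sparse at large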
A common rule of thumb is the one-standard-error rule: choose the largest lambda whose cross-validated MSE is within one standard error of the minimum. This yields a smaller model whose fit is comparable to the best one.
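With 'CV' specified, lasso reports both candidate choices directly, and lassoPlot can display the cross-validation curve; a short sketch:

% Compare the minimum-MSE lambda with the 1-SE choice
fprintf('Lambda at min CV MSE: %g (index %d)\n', FitInfo.LambdaMinMSE, FitInfo.IndexMinMSE)
fprintf('Lambda at 1-SE rule:  %g (index %d)\n', FitInfo.Lambda1SE, FitInfo.Index1SE)
lassoPlot(beta,FitInfo,'PlotType','CV');   % CV curve with both lambdas marked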
lam = FitInfo.Index1SE;            % index of the 1-SE lambda
isImportant = beta(:,lam) ~= 0;    % predictors with nonzero coefficients
disp(predNames(isImportant))
'BMI' 'BP' 'S3' 'S5'
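A common follow-up, not part of the comparison below, is to refit ordinary least squares on just the selected predictors. A sketch; mdlReduced is a hypothetical name, and this fit includes an intercept by default:

% "Relaxed" fit on the four lasso-selected predictors (sketch)
mdlReduced = fitlm(X(:,isImportant), y);
disp(mdlReduced.Rsquared.Ordinary)   % in-sample R^2 of the reduced model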
For comparison, fit an ordinary linear model that uses all 10 predictors.

mdlFull = fitlm(X,y,'Intercept',false);
disp(mdlFull)
Linear regression model:
    y ~ x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10

Estimated Coefficients:
           Estimate       SE        tStat       pValue
    x1     0.022296    0.22256     0.10018       0.92025
    x2      -26.073     5.9561     -4.3775    1.5074e-05
    x3       5.3537    0.73462      7.2877    1.5112e-12
    x4       1.0178     0.2304      4.4175    1.2635e-05
    x5       1.2636    0.33044      3.8239    0.00015068
    x6      -1.2849     0.3468     -3.7051    0.00023877
    x7      -3.0683    0.37189     -8.2505    1.9259e-15
    x8       -5.508     5.5883    -0.98565       0.32486
    x9       5.5034     9.4293     0.58365       0.55976
    x10     0.12339     0.2788     0.44256        0.6583

Number of observations: 442, Error degrees of freedom: 432
Root Mean Squared Error: 55.6
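As a cross-check, you can compare lasso's selection with classical significance in the full model; this sketch assumes a 5% threshold:

% Predictors significant at the 5% level in the full model
sigIdx = mdlFull.Coefficients.pValue < 0.05;
disp(predNames(sigIdx))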
Compare the mean squared error (MSE) of the regularized and unregularized models.
disp(['Lasso MSE: ', num2str(FitInfo.MSE(lam))])
disp(['Full MSE: ', num2str(mdlFull.MSE)])
Lasso MSE: 3176.5163
Full MSE: 3092.896
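Note that the two numbers are not strictly like for like: FitInfo.MSE is estimated by 10-fold cross-validation, while mdlFull.MSE is computed in sample. For an in-sample figure from the lasso fit, you can form predictions directly; a sketch (the intercept is stored in FitInfo.Intercept, since beta contains only the slope coefficients):

% In-sample predictions from the lasso coefficients at the 1-SE lambda
yhat = X*beta(:,lam) + FitInfo.Intercept(lam);
fprintf('Lasso in-sample MSE: %g\n', mean((y - yhat).^2))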
The mean squared error (MSE) of the fit that uses only the important predictors identified by lasso is quite close to the error of the linear model that uses all the predictors. Lasso is often used in this way to prevent overfitting and to remove redundant predictors, producing a simpler model with comparable accuracy.