At a high level, the way to build this plot is:
(1) Load the data and build the model
(2) Calculate the Shapley values using the shapley function (introduced in R2021a)
(3) Create the Shapley summary plot
For R2024a and higher:
(As of this writing in January 2024, R2024a can be accessed by using the prerelease. The general release of R2024a is planned for March 2024.)
Given a trained machine learning model, you can now use the shapley and fit functions to compute Shapley values for multiple query points. After using the QueryPoints name-value argument to specify multiple query points, you can visualize the results in the fitted shapley object by using the following object functions:
- plot — Plot mean absolute Shapley values using bar graphs.
- boxchart — Visualize Shapley values using box charts (box plots).
- swarmchart — Visualize Shapley values using swarm scatter charts.
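As a minimal sketch, assuming a shapley object named explainer that was fitted with multiple query points (as in the examples below), each visualization is a single call:

```matlab
% "explainer" is assumed to be a fitted shapley object with multiple query points
plot(explainer)        % bar graph of mean(|Shapley value|) per predictor
boxchart(explainer)    % distribution of Shapley values per predictor
swarmchart(explainer)  % swarm scatter chart, colored by feature value
```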
Here is a regression example on the widely available wine quality data set, which was pictured in the question.
alldata=readtable("winequality-red.csv");
c = cvpartition(size(alldata,1),"Holdout",0.20);
x_train=alldata(training(c),:);
x_test=alldata(test(c),:);
response='quality';   % the response variable in the wine quality data set
mdl=fitrensemble(x_train,response,'Method','Bag', ...
    'Learners',templateTree('MaxNumSplits',126),'NumLearningCycles',10);
explainer = shapley(mdl,x_train,'QueryPoints',alldata)
swarmchart(explainer,NumImportantPredictors=11,ColorMap='bluered')
Here is the resulting figure when using the default figure window size. Notice that a single line of code calculates the Shapley values across all of the query points, and a single line of code creates the plot.
Eleven predictors were squeezed onto that plot, so there is some undesirable overlap of data points from different predictors. That can easily be fixed in a variety of ways.
(1) If the figure window is simply resized to be sufficiently large, then the undesirable overlap of data points from different predictors goes away. Next, if desired, the font size of the labels can be increased. The image below is an example result. There are many variations depending on the figure window size and aspect ratio, and on the label font size.
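The resizing can also be scripted; this is a sketch, and the window size and font size below are illustrative choices, not recommendations:

```matlab
% Enlarge the current figure window and the axes font (example values)
fig = gcf;
fig.Position = [100 100 1200 800];   % [left bottom width height] in pixels
set(gca,'FontSize',14);              % larger tick and label font
```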
(2) Alternatively, one could get the handle to the underlying scatter object for each predictor and reduce the data marker size, that is, the 'SizeData' property:
ReductionFactorMarkerSize = 0.2;
h = get(gca,'Children');   % graphics objects in the swarm chart axes
for i = 1:numel(h)
    if isa(h(i),"matlab.graphics.chart.primitive.Scatter")
        InitialMarkerSize = get(h(i),'SizeData');
        set(h(i),'SizeData',InitialMarkerSize*ReductionFactorMarkerSize);
    end
end
This also achieves the goal of avoiding overlap between data points for different predictors. Here is the resulting figure, at the default figure window size. This plot can easily be resized to adjust the look as desired.
For R2024a or higher, here is a classification example using the fisheriris dataset:
alldata=readtable("fisheriris.csv");
response='Species';   % the response variable in the fisheriris data set
c = cvpartition(alldata.(response),"Holdout",0.20);
x_train=alldata(training(c),:);
x_test=alldata(test(c),:);
mdl=fitcensemble(x_train,response,'Method','Bag','NumLearningCycles',7);
explainer = shapley(mdl, x_train, 'QueryPoints',alldata)
figure(1); clf; tiledlayout(2,2);
nexttile(1); plot(explainer)   % Shapley importance plot (mean |Shapley value|)
for i = 2:4                    % one swarmchart per output class
    nexttile(i);
    swarmchart(explainer,ClassName=mdl.ClassNames{i-1},ColorMap='bluered')
end
Here is the resulting tiled plot, which includes a Shapley importance plot in the upper left and a Shapley summary plot (swarmchart) for each output class.
For R2023b and earlier:
An example is shown below using the carsmall dataset, without using the new shapley functionality in R2024a. The two predictors Horsepower and Weight are used to predict the MPG (Miles Per Gallon) of a car.
(1) Load the data and build the model
load carsmall   % ships with Statistics and Machine Learning Toolbox
x_train=table(Horsepower,Weight,MPG);
response='MPG';
Mdl=fitrensemble(x_train,response,'Method','LSBoost','NumLearningCycles',100);
(2) Calculate Shapley values one query point at a time
tic;
numQueryPoints=height(x_train);
shap=zeros(numQueryPoints,2);   % one row per query point, one column per predictor
for i=1:numQueryPoints
    explainer=shapley(Mdl,'QueryPoint',x_train(i,:),'UseParallel',false);
    shap(i,:)=explainer.ShapleyValues{:,2};
end
toc;
Elapsed time is 8.878466 seconds.
(3) Create the Shapley summary plot using multiple calls to "scatter"
[sortedMeanAbsShapValues,sortedPredictorIndices]=sort(mean(abs(shap)));
figure; hold on;
numQueryPoints=size(shap,1);
for p=1:numel(sortedPredictorIndices)
    % y position is the predictor's importance rank plus vertical jitter so
    % individual query points stay visible; marker color is the feature
    % value scaled into the colormap index range [1 256]
    scatter(shap(:,sortedPredictorIndices(p)), ...
        p+0.3*(rand(numQueryPoints,1)-0.5), 12, ...
        normalize(table2array(x_train(:,sortedPredictorIndices(p))),'range',[1 256]), ...
        'filled');
end
hold off;
title('Shapley Summary plot');
xlabel('Shapley Value (impact on model output)')
yticks(1:numel(sortedPredictorIndices));
yticklabels(x_train.Properties.VariableNames(sortedPredictorIndices));
colormap(CoolBlueToWarmRedColormap);
cb = colorbar('Ticks', [1 256], 'TickLabels', {'Low', 'High'});
cb.Label.String = "Scaled Feature Value";
xline(0, 'LineWidth', 1);
A few notes:
- For Weight and Horsepower, there are many query points where high values of those features have negative Shapley values. This is as expected, since high values for those predictors will generally tend to reduce the MPG of a car.
- The Shapley summary plot colorbar can be extended to categorical features by mapping the categories to integers using the "unique" function, e.g., [~, ~, integerReplacement]=unique(originalCategoricalArray).
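A minimal sketch of that mapping, using a hypothetical categorical array:

```matlab
% Hypothetical categorical feature
originalCategoricalArray = categorical({'oak';'steel';'oak';'amphora'});
[~, ~, integerReplacement] = unique(originalCategoricalArray);
% integerReplacement holds integers 1..numCategories, which can then be
% scaled into the colorbar range, e.g. normalize(double(integerReplacement),'range',[1 256])
```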
- For classification problems, a Shapley summary plot can be created for each output class. In that case, the shap variable could be a tensor ("3-D matrix") with indices as: (query-point-index, predictor-index, output-class-index)
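A sketch of filling that 3-D array in the one-query-point-at-a-time loop, assuming a classification model Mdl whose ShapleyValues table has one column of values per class after the Predictor column; numPredictors is whatever the model uses:

```matlab
numClasses = numel(Mdl.ClassNames);
shap = zeros(numQueryPoints, numPredictors, numClasses);
for i = 1:numQueryPoints
    explainer = shapley(Mdl,'QueryPoint',x_train(i,:));
    % Columns 2:end of ShapleyValues hold one column of Shapley values per class
    shap(i,:,:) = explainer.ShapleyValues{:,2:end};
end
```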
Function to create CoolBlueToWarmRedColormap
function colormap = CoolBlueToWarmRedColormap()
% Build a perceptually smooth colormap in LCh (Lightness-Chroma-hue) space,
% interpolating from cool blue to warm red through a lighter midpoint.
nsteps = 256;   % number of colormap entries (matches colorbar ticks [1 256])
l_mid = 90;     % lightness at the midpoint (assumed value; adjust to taste)
blue_lch = [54 70 4.6588];       % [L C h], hue angle in radians
red_lch  = [54 90 6.6378909];
lch=[[linspace(blue_lch(1), l_mid, nsteps/2), linspace(l_mid, red_lch(1), nsteps/2)]', ...
    [linspace(blue_lch(2), red_lch(2), nsteps)]', ...
    [linspace(blue_lch(3), red_lch(3), nsteps)]'];
% Convert LCh to Lab (a = C*cos(h), b = C*sin(h)), then Lab to RGB
lab=[lch(:,1) lch(:,2).*cos(lch(:,3)) lch(:,2).*sin(lch(:,3))];
colormap=lab2rgb(lab,'OutputType','uint8');
end