How can I obtain the Shapley values from a Neural Network Object?

I have created a neural network for pattern recognition with the "patternnet" function and would like to calculate its Shapley values by executing this code:
[x,t] = iris_dataset; 
net = patternnet(7);  
net = train(net,x,t) 
queryPoint=x(:,1)'
explainer = shapley(net,x,'QueryPoint',queryPoint)
However, I receive the following error:
Error using shapley
Blackbox model must be a classification model, regression model, or function handle
Is there a way to obtain the Shapley values from a "network" object as the one above? 

Accepted Answer

MathWorks Support Team on 18 Jun 2024
It is possible to obtain Shapley values for a pattern recognition network by passing a function handle to the "shapley" function. The function handle must return the score for the class of interest. Note that "shapley" passes data to the function handle with one observation per row and expects one score per row in return, whereas the network expects one observation per column. Some transposes are therefore needed for the function to work as expected.
Below is an example of how to do this using the Fisher Iris data:
% Train a neural network on the iris data
[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
% Choose an observation to explain. We need its class as an index.
x1 = x(:,1);
t1 = find(t(:,1));
% Plot Shapley values. For Setosa (the first class) the petal length (x3)
% is usually the most informative feature.
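% The anonymous function receives X with one observation per row.
explainer = shapley( ...
    @(X)predictScoreForSpecifiedClass(net,X,t1), ...
    x', "QueryPoint", x1' );
plot(explainer)
% Helpers (local functions must appear at the end of a script file)
function score = predictScoreForSpecifiedClass(net, X, classIndex)
    % X is n-by-4; the network expects 4-by-n, so transpose on the way in.
    Y = net(X');
    % Y is numClasses-by-n; return an n-by-1 column of scores for one class.
    score = Y(classIndex,:)';
end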
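The same helper can be reused to explain every class, or the class the network actually predicts rather than the labelled class. The sketch below assumes the same toolboxes as the answer above (Deep Learning Toolbox for "patternnet" and Statistics and Machine Learning Toolbox for "shapley"). The variable names "predictedClass" and "explainers" are illustrative only, and the property "ShapleyValues" is read as a table without relying on its column names, which may differ between releases.

```matlab
% Train the network as in the accepted answer
[x,t] = iris_dataset;
net = patternnet(10);
net.trainParam.showWindow = false;   % suppress the training GUI
net = train(net,x,t);

% Query point and the class the network predicts for it (illustrative)
x1 = x(:,1);
[~,predictedClass] = max(net(x1));

% One explainer per class, reusing the same helper function
numClasses = size(t,1);
explainers = cell(numClasses,1);
for k = 1:numClasses
    explainers{k} = shapley( ...
        @(X)predictScoreForSpecifiedClass(net,X,k), ...
        x', "QueryPoint", x1');
end

% Show and plot the Shapley values for the predicted class
disp(explainers{predictedClass}.ShapleyValues)
plot(explainers{predictedClass})

function score = predictScoreForSpecifiedClass(net, X, classIndex)
    % X is n-by-4 (rows = observations); the network expects 4-by-n.
    Y = net(X');
    % Return an n-by-1 column of scores for the requested class.
    score = Y(classIndex,:)';
end
```

Explaining the predicted class is often more informative when the network misclassifies the query point, because the labelled class then says little about what drove the prediction.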

Release: R2021a
