How can I obtain the Shapley values from a Long Short-Term Memory (LSTM) network?

I have created an LSTM neural network and used it for regression. I want to calculate its Shapley values by executing this code:
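layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    fullyConnectedLayer(numResponses)
    regressionLayer];
% options: training options, e.g. from trainingOptions (not shown)
net = trainNetwork(Xtrain,Ytrain,layers,options);
blackbox = @(x)predict(net,Xtrain);   % note: ignores its input x and always predicts on Xtrain
explainer = shapley(blackbox,Xtrain);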
I get the following error:
validateattributes(out,{'double','single'}, {'column','nonempty'},
mfilename,getString(message('stats:shapley:FunctionHandleOutput')));
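The failing line is shapley's check on the output of the function handle: it must return a nonempty double or single column vector. In the code above, the handle ignores its argument x and predicts on the whole of Xtrain, and with the default lstmLayer output mode the network likely returns one sequence per observation rather than one number per row. shapley also expects X to be a numeric matrix or table with one row per observation, not a cell array of sequences. The following is a minimal sketch of one possible workaround, not a documented MATLAB method. It assumes equal-length sequences, a single response, and a sequence-to-one network (OutputMode 'last'), flattens each sequence into a row, and rebuilds the sequence inside the handle. The toy data, the sizes, and the names Xflat and toSeqs are hypothetical.

```matlab
% Minimal sketch, assuming Deep Learning Toolbox and Statistics and
% Machine Learning Toolbox. Sizes and toy data are hypothetical.
rng(0)
numFeatures = 3; numTimeSteps = 5; numObs = 40; numHiddenUnits = 8;

% Toy sequences: each observation is a numFeatures-by-numTimeSteps matrix
Xtrain = arrayfun(@(k) randn(numFeatures, numTimeSteps), (1:numObs)', ...
    'UniformOutput', false);
Ytrain = cellfun(@(s) sum(s(:, end)), Xtrain);   % toy scalar target per sequence

% Sequence-to-one regression: OutputMode 'last' gives one prediction per sequence
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits, 'OutputMode', 'last')
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam', 'MaxEpochs', 20, 'Verbose', false);
net = trainNetwork(Xtrain, Ytrain, layers, options);

% shapley needs one row per observation, so flatten each sequence into a row
Xflat = cell2mat(cellfun(@(s) s(:)', Xtrain, 'UniformOutput', false));

% Rebuild sequences from the rows of x (numeric x assumed) and return a column
toSeqs = @(x) arrayfun(@(i) reshape(x(i, :), numFeatures, numTimeSteps), ...
    (1:size(x, 1))', 'UniformOutput', false);
blackbox = @(x) reshape(double(predict(net, toSeqs(x))), [], 1);

explainer = shapley(blackbox, Xflat, 'QueryPoint', Xflat(1, :));
disp(explainer.ShapleyValues)
```

Each of the numFeatures*numTimeSteps Shapley values then refers to one feature at one time step: column k of Xflat holds feature mod(k-1,numFeatures)+1 at time step ceil(k/numFeatures). Treating every time step as a separate predictor is a simplification when consecutive steps are strongly correlated.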

Answers (1)

Ahmadreza on 27 Jan 2023
I think the LSTM network is not compatible with the shapley function. Currently, only the following model objects are supported:
Regression Model Object: Ensemble of regression models, Gaussian kernel regression model using random feature expansion, Gaussian process regression, Generalized additive model, Linear regression for high-dimensional data, Neural
Classification Model Object: Discriminant analysis classifier, Multiclass model for support vector machines or other classifiers, Ensemble of learners for classification, Gaussian kernel classification model using random feature expansion, Generalized additive model, k-nearest neighbor classifier, Linear classification model, Multiclass naive Bayes model, Neural network classifier, Support vector machine classifier for one-class and binary classification, Binary decision tree for multiclass classification.
https://www.mathworks.com/help/stats/shapley.html#mw_c2327b12-104d-48ef-8a71-1f0e8769549b
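The list above covers model objects. The same documentation page also describes passing a function handle as blackbox, in which case the predictor data X must be supplied: the handle receives rows of predictor data and must return a column vector with one prediction per row. A minimal sketch with a hypothetical linear function (w, b, and q are made up for illustration) shows that contract. For a linear function with all predictor subsets enumerated, interventional Shapley values should reduce to w_j*(q_j - mean(X_j)), which gives a simple sanity check.

```matlab
% Minimal sketch, assuming Statistics and Machine Learning Toolbox.
% The linear "model" is hypothetical and only illustrates the handle contract.
rng(1)
X = randn(100, 3);                 % background data: rows are observations
w = [2 -1 0.5]; b = 3;             % hypothetical coefficients
blackbox = @(x) x * w' + b;        % one prediction per row, as a column vector

q = [1 0.5 -2];                    % query point to explain
explainer = shapley(blackbox, X, 'QueryPoint', q);
disp(explainer.ShapleyValues)

% For a linear function, interventional Shapley values reduce to w_j*(q_j - mean(X_j))
expected = (w .* (q - mean(X)))';
disp(max(abs(explainer.ShapleyValues.ShapleyValue - expected)))
```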
