Applying SHAP to a Reinforcement Learning Algorithm

I am trying to apply SHAP to a reinforcement learning algorithm, and I am not sure whether MATLAB has the required SHAP functionality, such as shap.DeepExplainer(), which is part of the Python shap package.
If anyone has further information on how to apply SHAP to the neural network agent of a reinforcement learning model, please let me know.
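SHAP's explainers (shap.DeepExplainer included) estimate Shapley values, and MATLAB's shapley function estimates the same quantity for a black-box function. To make concrete what is being computed, here is a minimal base-MATLAB sketch (not the toolbox implementation) of exact interventional Shapley values by enumerating every feature subset: features outside a coalition take their values from the rows of a background data set, and the model output is averaged. The linear model, background rows, and query point are made-up illustration values, and the cost grows as 2^d, so this only suits a handful of features.

```matlab
% Hypothetical black-box model: any function handle mapping an n-by-d matrix
% to an n-by-1 vector works; a linear model keeps the result easy to check.
w = [2; -1; 0.5];
f = @(X) X*w + 3;
background = [0 0 0; 1 2 3; 2 4 0];   % m-by-d reference observations
x = [3 1 5];                          % 1-by-d query point to explain

d = numel(x);
m = size(background, 1);
nMasks = 2^d;

% v(k) is the value of the coalition encoded by the bits of k-1:
% features in the coalition take their query values, the others come from
% each background row, and the model outputs are averaged.
v = zeros(nMasks, 1);
for k = 1:nMasks
    inS = bitget(k-1, 1:d) == 1;
    Z = background;
    Z(:, inS) = repmat(x(inS), m, 1);
    v(k) = mean(f(Z));
end

% Shapley value of feature j: weighted sum of the marginal contributions
% v(S with j) - v(S) over all coalitions S that do not contain j.
phi = zeros(1, d);
for j = 1:d
    for k = 1:nMasks
        inS = bitget(k-1, 1:d) == 1;
        if inS(j)
            continue
        end
        s = nnz(inS);
        weight = factorial(s) * factorial(d-s-1) / factorial(d);
        kWithJ = k + 2^(j-1);         % same coalition with feature j added
        phi(j) = phi(j) + weight * (v(kWithJ) - v(k));
    end
end
disp(phi)
```

For this linear model each value should equal w(j)*(x(j) - mean(background(:,j))), that is [4 1 2], and the values sum to f(x) - mean(f(background)) = 7 (the efficiency property that SHAP summary plots rely on).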


I have been wondering the same, @Mahsa Raeisinezhad. Have you found any further information? The documentation for shapley is not very clear and never mentions SHAP.
% Shapley summary plot. Assumes shap (n-by-d Shapley values), tbl (table of the
% n explained observations, one variable per predictor), n and d already exist.
% Order predictors by mean absolute Shapley value (assumed ordering; ascending,
% so the most important predictor ends up in the top row).
[~, sortedPredictorIndicesMEAN] = sort(mean(abs(shap), 1));
figure;
for p = 1:d
    scatter(shap(:,sortedPredictorIndicesMEAN(p)), ... % x-value of each point is the Shapley value
        p*ones(n,1), ... % y-value is an integer identifying the predictor (jittered below)
        [], ... % marker size for each point: use the default
        normalize(table2array(tbl(1:n,sortedPredictorIndicesMEAN(p))),'range',[1 256]), ... % color by feature value
        'filled', ... % fill the circles representing data points
        'YJitter','density', ... % jitter vertically according to the point density in this row
        'YJitterWidth',0.8)
    if p == 1
        hold on;
    end
end
title('Shapley Summary Plot');
xlabel('Shapley Value (impact on model output)')
yticks(1:d);
yticklabels(tbl.Properties.VariableNames(sortedPredictorIndicesMEAN));
% Set the colormap as desired
colormap(CoolBlueToWarmRedColormap); % similar to the colormap used in many Shapley summary plots
% colormap(parula); % the default colormap
cb = colorbar('Ticks', [1 256], 'TickLabels', {'Low', 'High'});
cb.Label.String = "Scaled Feature Value";
cb.Label.FontSize = 12;
cb.Label.Rotation = 270;
set(gca, 'YGrid', 'on');
xline(0, 'LineWidth', 1);
hold off;
%%
function cmap = CoolBlueToWarmRedColormap()
% Define start point, middle luminance, and end point in L*ch colorspace
% https://www.mathworks.com/help/images/device-independent-color-spaces.html
% The three components of L*ch are luminance, chroma, and hue (hue in radians here).
blue_lch = [54 70 4.6588];       % starting blue point
l_mid = 40;                      % luminance of the midpoint
red_lch = [54 90 6.6378909];     % ending red point
nsteps = 256;
% Build an nsteps-by-3 matrix of L*ch colors.
% Luminance changes linearly from start to middle, and from middle to end.
% Chroma and hue change linearly from start to end.
lch = [[linspace(blue_lch(1), l_mid, nsteps/2), linspace(l_mid, red_lch(1), nsteps/2)]', ... % luminance column
       linspace(blue_lch(2), red_lch(2), nsteps)', ... % chroma column
       linspace(blue_lch(3), red_lch(3), nsteps)'];    % hue column
% Convert L*ch to L*a*b*, where a = c*cos(h) and b = c*sin(h)
lab = [lch(:,1), lch(:,2).*cos(lch(:,3)), lch(:,2).*sin(lch(:,3))];
% Convert L*a*b* to RGB (lab2rgb requires Image Processing Toolbox)
cmap = lab2rgb(lab, 'OutputType', 'uint8');
end
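% Assumes pretrainedAgent is the trained PPO agent and tbl is a table of
% observations collected from env (one row per decision, one variable per
% observation element).
myAct = @(X) predict_01(pretrainedAgent, X);  % black box: observations -> scores
n = height(tbl);   % number of decisions (query points) to explain
d = width(tbl);    % number of observation elements (predictors)
shap = zeros(n,d);
explainer = shapley(myAct, tbl);   % tbl also serves as the background data
explainers = cell(n,1);
for i = 1:n
    explainer = fit(explainer, tbl(i,:));   % explain the i-th decision
    shap(i,:) = explainer.ShapleyValues{:,2}';
    explainers{i} = explainer;
end
figure;
plot(explainer)   % bar chart of the Shapley values for the last query point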
function myAct_ = predict_01(agent, X)
% Return one score per observation row of X: the largest action probability
% output by the actor of the trained agent (assumes a discrete-action PPO
% agent with a single vector observation channel).
if istable(X)
    X = table2array(X);
end
actor = getActor(agent);           % use the trained agent's actor, not a new agent
actorNet = getModel(actor);        % dlnetwork (R2022a and later)
dlObservations = dlarray(single(X'), 'CB');   % channels-by-batch
probs = predict(actorNet, dlObservations);    % numActions-by-batch
myAct_ = double(extractdata(max(probs, [], 1)))';   % n-by-1 column vector
end
The code above is how I tried to use SHAP in MATLAB: I created a function handle for my neural network (agent) predictions on observations from the environment and computed Shapley values for each decision. However, I still highly recommend moving to Python and using the Python shap package.
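One requirement worth checking with this approach: when shapley is given a function handle, it calls it with many rows of predictor data at once and expects one output value per row. Below is a minimal base-MATLAB sketch of a black box with that shape. The linear-softmax "policy" (W and b) is a hypothetical stand-in for the actor network, with made-up values; the score being explained is the highest action probability, as in predict_01.

```matlab
% Hypothetical stand-in for a discrete-action actor network: one linear layer
% followed by a softmax over 3 actions. W and b are made-up values.
W = [0.5 -1.0; 1.5 0.2; -0.3 0.8];   % numActions-by-numObs
b = [0.1; 0; -0.2];                  % numActions-by-1

% Column-wise softmax (subtracting the column max for numerical stability).
softmaxCols = @(Z) exp(Z - max(Z, [], 1)) ./ sum(exp(Z - max(Z, [], 1)), 1);

% Black box in the shape shapley expects: n-by-numObs observations in,
% n-by-1 scores out (here, the highest action probability per observation).
policyScore = @(X) max(softmaxCols(W*X' + b), [], 1)';

X = [0 0; 1 2; -1 0.5; 3 -2];   % 4 observations, one per row
scores = policyScore(X);
disp(scores)

% With Statistics and Machine Learning Toolbox, a handle with this contract
% could then be passed as, e.g., explainer = shapley(policyScore, X);
```

Replacing policyScore with a wrapper around the real actor network keeps the same contract; with 3 actions, every score lies between 1/3 and 1.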


Answers (2)

I decided to move everything to Python and use the Python packages instead. I used ONNX and TensorFlow to transfer the model. If I have time in the future, I hope to write my own code to reproduce the same results in MATLAB.


Starting in R2024a, MATLAB has new functionality that makes it easier to create Shapley summary plots. This is described in the release notes: https://www.mathworks.com/help/releases/R2024a/stats/release-notes.html.
For an example on MATLAB Answers, see:


The documentation and references are clear enough once you know exactly what you want to explain. If you want to stay in MATLAB, follow this example.

Asked: 5 Jun 2023

Commented: 5 Mar 2025
