
BiLSTM Empty Backward Cell states and Hidden states

I noticed that for the BiLSTM layer, the second half of CellState and HiddenState (which corresponds to the backward states) is always 0, even though I have called classifyAndUpdateState().
Note that the NumHiddenUnits of the layer is set to 15. The first 15 elements in the CellState and HiddenState array are not 0, while the subsequent 15 elements are 0.
I have also tried to replicate the forward pass of MATLAB's BiLSTM layer using the code below:
%% Testing using BiLSTM layer
% Selecting one input data
n = 1;
X = Xdata{n};
Y = Ydata(n);
n_seq = size(X, 2); % number of time steps (columns) to be fed in
layer_num = 3; % First biLSTM layer
inputSize = size(Xdata{1}, 1);
numHU = net_best.Layers(layer_num).NumHiddenUnits;
% Activation using matlab functions
feature = activations(net_best, X(:, 1:n_seq), layer_num);
[net_best_seq, Ypred] = classifyAndUpdateState(net_best, X(:, 1:n_seq));
hidden_state_actual = feature{1};
cell_state_actual = net_best_seq.Layers(layer_num).CellState;
% Recreation of activation
% Getting weights and biases
W.Wall = net_best.Layers(layer_num).InputWeights;
b.ball = net_best.Layers(layer_num).Bias;
R.Rall = net_best.Layers(layer_num).RecurrentWeights;
W.Wi_f = W.Wall(0*numHU+1:1*numHU, :);
W.Wf_f = W.Wall(1*numHU+1:2*numHU, :);
W.Wg_f = W.Wall(2*numHU+1:3*numHU, :);
W.Wo_f = W.Wall(3*numHU+1:4*numHU, :);
W.Wi_b = W.Wall(4*numHU+1:5*numHU, :);
W.Wf_b = W.Wall(5*numHU+1:6*numHU, :);
W.Wg_b = W.Wall(6*numHU+1:7*numHU, :);
W.Wo_b = W.Wall(7*numHU+1:8*numHU, :);
b.bi_f = b.ball(0*numHU+1:1*numHU, :);
b.bf_f = b.ball(1*numHU+1:2*numHU, :);
b.bg_f = b.ball(2*numHU+1:3*numHU, :);
b.bo_f = b.ball(3*numHU+1:4*numHU, :);
b.bi_b = b.ball(4*numHU+1:5*numHU, :);
b.bf_b = b.ball(5*numHU+1:6*numHU, :);
b.bg_b = b.ball(6*numHU+1:7*numHU, :);
b.bo_b = b.ball(7*numHU+1:8*numHU, :);
R.Ri_f = R.Rall(0*numHU+1:1*numHU, :);
R.Rf_f = R.Rall(1*numHU+1:2*numHU, :);
R.Rg_f = R.Rall(2*numHU+1:3*numHU, :);
R.Ro_f = R.Rall(3*numHU+1:4*numHU, :);
R.Ri_b = R.Rall(4*numHU+1:5*numHU, :);
R.Rf_b = R.Rall(5*numHU+1:6*numHU, :);
R.Rg_b = R.Rall(6*numHU+1:7*numHU, :);
R.Ro_b = R.Rall(7*numHU+1:8*numHU, :);
% Initial states are zero; each direction holds numHU state elements
c_prev = zeros(2*numHU, 1);
c_prev_f = c_prev(0*numHU+1:1*numHU, :);
c_prev_b = c_prev(1*numHU+1:2*numHU, :);
h_prev = zeros(2*numHU, 1);
h_prev_f = h_prev(0*numHU+1:1*numHU, :);
h_prev_b = h_prev(1*numHU+1:2*numHU, :);
% Initialising
c_f = c_prev_f;
h_f = h_prev_f;
c_b = c_prev_b;
h_b = h_prev_b;
for a = 1:n_seq
    c_prev_f = c_f;
    h_prev_f = h_f;
    c_prev_b = zeros(numHU, 1); %c_b; % !! resetting backward cell state instead of inheriting it from previous sequence
    h_prev_b = zeros(numHU, 1); %h_b; % !! resetting backward hidden state instead of inheriting it from previous sequence
    Xin = activations(net_best, X(:, a), layer_num-1); % get the input after BN
    Xin = Xin{1};
    % forward pass
    i_f = sigmoid(dlarray(W.Wi_f*Xin+R.Ri_f*h_prev_f+b.bi_f));
    f_f = sigmoid(dlarray(W.Wf_f*Xin+R.Rf_f*h_prev_f+b.bf_f));
    g_f = tanh(dlarray(W.Wg_f*Xin+R.Rg_f*h_prev_f+b.bg_f));
    o_f = sigmoid(dlarray(W.Wo_f*Xin+R.Ro_f*h_prev_f+b.bo_f));
    c_f = f_f.*c_prev_f+i_f.*g_f;
    h_f = o_f.*tanh(c_f);
    % backward pass
    i_b = sigmoid(dlarray(W.Wi_b*Xin+R.Ri_b*h_prev_b+b.bi_b));
    f_b = sigmoid(dlarray(W.Wf_b*Xin+R.Rf_b*h_prev_b+b.bf_b));
    g_b = tanh(dlarray(W.Wg_b*Xin+R.Rg_b*h_prev_b+b.bg_b));
    o_b = sigmoid(dlarray(W.Wo_b*Xin+R.Ro_b*h_prev_b+b.bo_b));
    c_b = f_b.*c_prev_b+i_b.*g_b;
    h_b = o_b.*tanh(c_b);
    % stack forward and backward states in the same order as the layer
    c = [c_f; c_b];
    h = [h_f; h_b];
end
hidden_state_actual = hidden_state_actual(:, end);
% comparing custom implementation and implementation by matlab
fprintf('Cell state: %f\n', sum(abs(single(c)-cell_state_actual)));
fprintf('Hidden state: %f\n', sum(abs(single(h)-hidden_state_actual)));
I notice that the output of MATLAB's forward pass can be replicated only if the backward cell states and hidden states are reset after every sequence.
Is this a bug?
Thanks in advance.

Answers (1)

Aditya on 26 Feb 2024
This may not be a bug. It may instead be a consequence of how MATLAB handles state updates in bidirectional LSTM (BiLSTM) layers when classifyAndUpdateState is used. A BiLSTM layer processes the data in both the forward and backward directions, and the final states should capture information from the full sequence.
During forward propagation, the backward direction of a BiLSTM typically starts at the end of the sequence and moves toward the beginning. When a sequence is processed one time step at a time, the backward direction has no future context yet, so its state is assumed to be zero-initialized for each new sequence.
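The two recurrences described above can be written out as follows. This is the standard textbook formulation, given here only as a sketch; the symbols (the arrows, LSTM_f, LSTM_b, T) are notation assumed for illustration and are not taken from the MATLAB documentation.

```latex
\begin{aligned}
(\overrightarrow{h}_t, \overrightarrow{c}_t) &= \mathrm{LSTM}_{f}\bigl(x_t,\ \overrightarrow{h}_{t-1},\ \overrightarrow{c}_{t-1}\bigr),
  \qquad t = 1, \dots, T, \qquad \overrightarrow{h}_0 = \overrightarrow{c}_0 = 0 \\
(\overleftarrow{h}_t, \overleftarrow{c}_t) &= \mathrm{LSTM}_{b}\bigl(x_t,\ \overleftarrow{h}_{t+1},\ \overleftarrow{c}_{t+1}\bigr),
  \qquad t = T, \dots, 1, \qquad \overleftarrow{h}_{T+1} = \overleftarrow{c}_{T+1} = 0 \\
y_t &= \begin{bmatrix} \overrightarrow{h}_t \\ \overleftarrow{h}_t \end{bmatrix}
\end{aligned}
```

With T = 1, the backward recurrence reduces to a single LSTM step from a zero state, which is the situation when one time step is fed in at a time.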
However, when you process the entire sequence at once, MATLAB's classifyAndUpdateState function should update both the forward and backward states across the whole sequence. If the backward states remain zero after processing the entire sequence, it could indicate an issue with how the states are being handled or updated.
In your code snippet, it appears that you are manually resetting the backward cell states and hidden states to zero after every sequence (as indicated by the comments in the code). This would not be the correct approach when processing the entire sequence at once, because the backward states should carry information from the end of the sequence back to the beginning.
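To make the direction of the backward recurrence concrete, here is a minimal, toolbox-free sketch of BiLSTM forward propagation over a whole sequence. It uses random weights and hypothetical names (W, R, b, sigm, lastStepGap); it assumes the gates are stacked as [input; forget; candidate; output] per direction and is not MATLAB's internal implementation.

```matlab
% Toy BiLSTM forward propagation with random weights (illustrative only).
rng(0);                              % reproducible random numbers
inputSize = 4; numHU = 3; T = 6;     % assumed toy dimensions
X = randn(inputSize, T);             % one sequence, one column per time step
sigm = @(z) 1 ./ (1 + exp(-z));      % logistic sigmoid

% One weight set per direction: index 1 = forward, index 2 = backward
W = {randn(4*numHU, inputSize), randn(4*numHU, inputSize)};
R = {randn(4*numHU, numHU),     randn(4*numHU, numHU)};
b = {randn(4*numHU, 1),         randn(4*numHU, 1)};
order = {1:T, T:-1:1};               % backward direction visits t = T, ..., 1

H = zeros(2*numHU, T);               % layer output, [forward; backward] per column
for d = 1:2
    h = zeros(numHU, 1);             % each direction starts from a zero state
    c = zeros(numHU, 1);
    for t = order{d}
        z  = W{d}*X(:, t) + R{d}*h + b{d};
        ig = sigm(z(1:numHU));               % input gate
        fg = sigm(z(numHU+1:2*numHU));       % forget gate
        gg = tanh(z(2*numHU+1:3*numHU));     % cell candidate
        og = sigm(z(3*numHU+1:4*numHU));     % output gate
        c  = fg .* c + ig .* gg;
        h  = og .* tanh(c);
        H((d-1)*numHU + (1:numHU), t) = h;
    end
end

% One backward step from a zero state on the last column alone
z  = W{2}*X(:, T) + b{2};                                % R*h vanishes because h = 0
cT = sigm(z(1:numHU)) .* tanh(z(2*numHU+1:3*numHU));     % forget term vanishes because c = 0
hT = sigm(z(3*numHU+1:4*numHU)) .* tanh(cT);
lastStepGap = max(abs(hT - H(numHU+1:end, T)));
fprintf('Backward output gap at the last time step: %g\n', lastStepGap);
```

The printed gap should be zero up to floating-point rounding: in this formulation the backward half of the output at the last time step depends only on the last input and a zero state. That may be why resetting the backward states reproduces the last column of the activations, although whether classifyAndUpdateState stores backward states the same way is not confirmed here.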



