How to retrieve the cell/hidden state of an LSTM layer during training
Hi everyone,
As the title says, I'm trying to extract the cell and hidden states from an LSTM layer after training. Unfortunately, I haven't found a way to do that yet.
Does anyone know how that works, or whether it is even possible?
Thanks for any advice!
Answers (5)
Da-Ting Lin on 11 Feb 2020
I also have this question. Hopefully it will be supported in an upcoming release.
Haoyuan Ma on 16 Mar 2020
I have this question too.
I tried many times before finding this page.
Giuseppe Dell'Aversana on 16 Apr 2020
I also have this question. Maybe someone has the answer by now?
Yildirim Kocoglu on 10 Jan 2021
It's a little late, but I had the same question and came across this example: https://www.mathworks.com/help/ident/ug/use-lstm-for-linear-system-identification.html
I haven't tried it yet, but please read it carefully, as it may help.
Read the section "Set Network Initial State".
It says: As the network performs estimation using a step input from 0 to 1, the states of the LSTM network (cell and hidden states of the LSTM layers) drift toward the correct initial condition. To visualize this, extract the cell and hidden state of the network at every time step using the predictAndUpdateState function.
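For reference, these are the quantities the quoted text calls the cell and hidden states. The following is the standard LSTM update, written (as I understand MATLAB's lstmLayer conventions) with the gate order input, forget, cell candidate, output, and with the default sigmoid gate activation and tanh state activation. At time step t, with input x_t:

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + R_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + R_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + R_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Here c_t corresponds to the layer's CellState and h_t to its HiddenState, while W, R and b correspond to InputWeights, RecurrentWeights and Bias. Because h_{t-1} and c_{t-1} feed into step t, both states change at every time step, which is why they have to be read inside the prediction loop rather than once after it.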
Here is the code from the documentation, which you can try to modify to suit your needs:
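% time, stepSignal and fourthOrderNet come from the linked documentation example
stepMarker = time <= 2;                       % time steps to simulate
numSteps = sum(stepMarker);
numHiddenUnits = fourthOrderNet.Layers(2).NumHiddenUnits;  % 200 LSTM units in the example
yhat = zeros(numSteps,1);                     % preallocate the predicted outputs
% zeros() only preallocates storage for the recorded states; it does not set the network state
hiddenState = zeros(numSteps,numHiddenUnits);
cellState = zeros(numSteps,numHiddenUnits);
for ntime = 1:numSteps
    % Predict one time step and keep the returned network, which carries the updated state
    [fourthOrderNet,yhat(ntime)] = predictAndUpdateState(fourthOrderNet,stepSignal(ntime)');
    % Layer 2 is the LSTM layer in this example; read its states after the update
    hiddenState(ntime,:) = fourthOrderNet.Layers(2).HiddenState;
    cellState(ntime,:) = fourthOrderNet.Layers(2).CellState;
end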
If you have multiple sequences (batches), you can loop over them and run prediction on your trained network one sequence at a time (for example, for i = 1:numSequences). If you saved your trained network as 'net', calling net = resetState(net) at the start of each iteration resets the states to their initial values (usually zeros, unless you specified them beforehand). These are the same initial states used during training, so with the code above you should be able to see the hidden and cell states at every time step of each sequence.
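The per-sequence reset and per-step recording described above can be illustrated without the Deep Learning Toolbox. Below is a minimal plain-MATLAB sketch, assuming the standard LSTM update with MATLAB's gate order (input, forget, cell candidate, output). The weights are made-up placeholders, not a trained layer's InputWeights, RecurrentWeights and Bias, and the two input sequences are hypothetical. It resets h and c to zeros at the start of each sequence (what resetState does with default initial states) and stores one row of states per time step.

```matlab
% Toolbox-free sketch of the LSTM recurrence (gate order: input, forget,
% cell candidate, output). The weights are made-up placeholders, NOT the
% InputWeights/RecurrentWeights/Bias of a trained network.
numHiddenUnits = 4;
H = numHiddenUnits;
W = linspace(-1, 1, 4*H)';               % input weights, 4H-by-1 (one input feature)
R = 0.1 * repmat(eye(H), 4, 1);          % recurrent weights, 4H-by-H
b = 0.1 * ones(4*H, 1);                  % bias, 4H-by-1
sigmoid = @(z) 1 ./ (1 + exp(-z));

% Two hypothetical scalar input sequences of different lengths
sequences = {sin(0:0.5:3), linspace(0, 1, 5)};
hiddenStates = cell(size(sequences));    % one numSteps-by-H array per sequence
cellStates   = cell(size(sequences));

for s = 1:numel(sequences)
    x = sequences{s};
    h = zeros(H, 1);                     % reset to the initial (zero) state,
    c = zeros(H, 1);                     % analogous to net = resetState(net)
    hiddenStates{s} = zeros(numel(x), H);
    cellStates{s}   = zeros(numel(x), H);
    for t = 1:numel(x)
        z = W * x(t) + R * h + b;        % pre-activations of all four gates
        inGate     = sigmoid(z(1:H));
        forgetGate = sigmoid(z(H+1:2*H));
        candidate  = tanh(z(2*H+1:3*H));
        outGate    = sigmoid(z(3*H+1:4*H));
        c = forgetGate .* c + inGate .* candidate;  % new cell state
        h = outGate .* tanh(c);                     % new hidden state
        hiddenStates{s}(t, :) = h';      % record the states at this time step
        cellStates{s}(t, :)   = c';
    end
end

% The last row of each array is the final state of that sequence
finalHidden = hiddenStates{end}(end, :);
disp(finalHidden);
```

The last row of each array is the final state, which is what you would carry forward when forecasting. With a real trained network you would read these values from the layer, as in the documentation loop, instead of recomputing them.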
I personally needed to extract the final states to continue the prediction because I'm working on a forecasting problem.
Sathyseelan Mayilvahanam on 19 Sep 2022
When I run the code above, the resulting matrices contain only zeros. Could someone provide a solution, or a complete example with data?
Yildirim Kocoglu on 19 Sep 2022
At which stage (time step) are you trying to extract the hidden/cell state, and what is your purpose? What kind of problem are you working on (classification, forecasting, or something else)? Have you tried printing the hidden/cell states inside the for loop?

By the way, the code I provided is not complete; as far as I remember, I borrowed it from the MATLAB documentation (see the link above for details). I can't provide a complete example, because I have since moved to a different programming language for another project.

Note that the zeros(...) calls in the snippet only preallocate the arrays that record the states; they do not set the network's state. If you call resetState(net) at the start of each sequence, it resets the hidden/cell states to their initial values (zeros by default, unless you specified them yourself). The hidden/cell states are updated at each time step of a sequence, so you should be able to print them inside the for loop.