MATLAB Answers

Performing LSTM regression on matrices without reshaping the matrix elements into a sequence of vectors, or performing LSTM regression on 3-D data?

Anand Dandavate on 22 Nov 2020
Commented: Anand Dandavate on 28 Nov 2020
Hello everyone,
I am trying to perform LSTM regression directly on a sequence of matrices, without reshaping the matrix elements into a sequence of vectors. Put differently, I want to perform LSTM regression on 3-D data in which the first two dimensions hold the contents of a matrix and the third dimension is time, giving a time sequence of matrices.
My attempts so far give me errors.
I want to input a sequence of matrices to the LSTM regression and get the output in the form of matrices as well. Is this possible?
Is there any reference code for this?
Can this be achieved in programs other than MATLAB?

Answers (1)

Srivardhan Gadila on 28 Nov 2020
You can define your deep neural network and use the following trainNetwork syntax: net = trainNetwork(sequences,Y,layers,options). Refer to the explanation of sequences and Y in the trainNetwork documentation and arrange your data in that format. The following code may also give you an idea:
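% Each time step is a 13-by-11-by-1-by-5 array (height-by-width-by-depth-by-channels)
inputSize = [13 11 1 5];
nTrainSamples = 50;     % number of training sequences
numHiddenUnits = 200;
numResponses = 5;       % regression outputs per time step
layers = [ ...
sequenceInputLayer(inputSize,'Name','input')
flattenLayer('Name','flatten')   % collapse each time step into one feature vector
lstmLayer(numHiddenUnits,'Name','lstm','OutputMode','sequence')   % one output per time step
fullyConnectedLayer(numResponses, 'Name','fc')
regressionLayer('Name','regression')];
lgraph = layerGraph(layers);
analyzeNetwork(layers)   % inspect the activation size at every layer
%%
% Random placeholder data: each sequence is 13-by-11-by-1-by-5-by-S with S = 1 time step
trainData = arrayfun(@(x)rand([inputSize(:)' 1]),1:nTrainSamples,'UniformOutput',false)';
% Matching responses: numResponses-by-S per sequence (here 5-by-1)
trainLabels = arrayfun(@(x)rand(numResponses,1),1:nTrainSamples,'UniformOutput',false)';
size(trainData)
size(trainLabels)
%%
% 'piecewise' uses the default LearnRateDropFactor (0.1) and LearnRateDropPeriod (10 epochs)
% unless those options are set explicitly
options = trainingOptions('adam', ...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise',...
'MaxEpochs',300, ...
'MiniBatchSize',1024, ...
'Verbose',1, ...
'Plots','training-progress');
net = trainNetwork(trainData,trainLabels,lgraph,options);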
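To get matrices out rather than a 5-element vector, one option is to size the layers to the matrix: for H-by-W matrices, use sequenceInputLayer([H W 1]) and fullyConnectedLayer(H*W), with each response stored as an (H*W)-by-T matrix. The sketch below uses only base MATLAB to arrange 3-D H-by-W-by-T data in that layout and to reshape an (H*W)-by-T prediction back into an H-by-W-by-T array. The sizes H, W, T, numSeq, the variable names and the random data are illustrative assumptions, and no network is built or trained here.

```matlab
% Minimal sketch (base MATLAB only). H, W, T, numSeq and the random
% data are hypothetical placeholders, not values from the question.
H = 13; W = 11; T = 20; numSeq = 50;

% Raw data: each sequence is an H-by-W-by-T array (one matrix per time step)
rawInputs  = arrayfun(@(k) rand(H, W, T), 1:numSeq, 'UniformOutput', false)';
rawTargets = arrayfun(@(k) rand(H, W, T), 1:numSeq, 'UniformOutput', false)';

% Inputs for sequenceInputLayer([H W 1]): H-by-W-by-C-by-T with C = 1 channel
X = cellfun(@(s) reshape(s, H, W, 1, T), rawInputs, 'UniformOutput', false);

% Responses for fullyConnectedLayer(H*W): (H*W)-by-T, one column per time step
Y = cellfun(@(s) reshape(s, H*W, T), rawTargets, 'UniformOutput', false);

% Going back: an (H*W)-by-T output becomes an H-by-W-by-T array.
% Y{1} stands in for a network prediction for one sequence.
pred = Y{1};
predMatrices = reshape(pred, H, W, T);
```

Because the element order of the responses is defined entirely by how Y was built, applying the inverse reshape to a prediction recovers the matrices in the same layout.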

Release

R2020a
