LSTM not outputting sequence
I am attempting to do sequence-to-sequence classification.
I have N time series, each consisting of a sequence of observations, and each observation collects p features.
I build a cell array XTrain and set XTrain{i} to the i-th time series in my database.
I have two classes. I build a cell array YTrain, where YTrain{i} is a categorical vector indicating which class is present at each time step.
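To make the data layout concrete, here is a minimal sketch that builds synthetic XTrain and YTrain. It assumes each XTrain{i} is stored as a p-by-1-by-1-by-T array (the h-by-w-by-c-by-T layout MATLAB uses for image sequences with input size [p 1 1]) and each YTrain{i} is a 1-by-T categorical row vector. N, p, T and the class names are hypothetical values, not taken from the post.

```matlab
% Hypothetical synthetic data (not from the original post), only to show the
% shapes assumed for sequence-to-sequence classification with inputSize [p 1 1].
rng(0);                              % reproducible random numbers
N = 4;                               % number of time series (hypothetical)
p = 10;                              % features per observation (hypothetical)
classNames = {'class1','class2'};    % hypothetical class names

XTrain = cell(N,1);
YTrain = cell(N,1);
for i = 1:N
    T = randi([20 30]);                          % sequence lengths may differ
    X = randn(p,T);                              % p features at each of T time steps
    XTrain{i} = reshape(X,[p 1 1 T]);            % h-by-w-by-c-by-T layout
    YTrain{i} = categorical(randi(2,1,T),[1 2],classNames);  % 1-by-T labels
end
```

Keeping the time dimension last (dimension 4) is what lets the folding layer treat every time step as a separate p-by-1 image.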
Now I build the following network:
inputSize = [p, 1, 1];      % each time step is treated as a p-by-1-by-1 image
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;
layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    sequenceFoldingLayer('Name','fold')        % fold time steps into the batch for the conv layers
    convolution2dLayer(filterSize,numFilters,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(filterSize,numFilters,'Name','conv2')
    reluLayer('Name','relu2')
    flattenLayer('Name','flatten')
    sequenceUnfoldingLayer('Name','unfold')    % restore the sequence structure
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];
lgraph = layerGraph(layers);
% The unfolding layer needs the mini-batch size recorded by the folding layer
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
maxEpochs = 1;
miniBatchSize = 2;
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,lgraph,options);
However, if I then run:
YScores = predict(net,XTrain,'MiniBatchSize',1);
the output is a cell array whose i-th entry is a single vector of class probabilities.
This is INCORRECT. With 'OutputMode','sequence', it should be a sequence of class-probability vectors, one per time step.
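One way to pin down the expected shape: as I understand the predict documentation, for sequence-to-sequence classification each entry of YScores should have one row per class and one column per time step. The hypothetical check below assumes the time dimension of each input is dimension 4 (the p-by-1-by-1-by-T layout); the demo arrays are made up and only illustrate the expected case and the reported symptom.

```matlab
% Hypothetical check (not from the original post): each scores matrix should be
% numClasses-by-T, where T is the length of the matching input sequence,
% assumed to be stored along dimension 4.
numClasses = 2;
isSequenceOutput = @(scores, X) all(cellfun( ...
    @(s, x) isequal(size(s), [numClasses size(x,4)]), scores, X));

% Made-up arrays standing in for XTrain and YScores
Xdemo = {rand(10,1,1,25); rand(10,1,1,30)};
goodScores = {rand(2,25); rand(2,30)};   % one probability column per time step
badScores  = {rand(2,1);  rand(2,1)};    % one vector per sequence (the reported symptom)
disp(isSequenceOutput(goodScores, Xdemo))  % displays 1
disp(isSequenceOutput(badScores, Xdemo))   % displays 0
```

On real data, the same call would be isSequenceOutput(YScores, XTrain) after running predict.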
4 Comments
John Malik
on 18 Dec 2019
Ridwan Alam
on 25 Dec 2019
Edited: Ridwan Alam
on 25 Dec 2019
Hey John, did you get a solution? Please share. Thanks!
Mohammad Sami
on 26 Dec 2019
Can you try putting the sequenceUnfoldingLayer before the flattenLayer?
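A minimal sketch of that suggestion, untested against the original data: the layer list is the same as in the question, except that sequenceUnfoldingLayer now follows relu2 directly and flattenLayer comes after it, so the flattening happens on the unfolded sequence just before the LSTM. The value of p is a placeholder; with two unpadded [2 1] convolutions it must be at least 3.

```matlab
% Sketch of the suggested reordering: unfold first, then flatten.
p = 10;                    % placeholder feature count (must be >= 3 here)
inputSize = [p 1 1];
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;

layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    sequenceFoldingLayer('Name','fold')
    convolution2dLayer(filterSize,numFilters,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(filterSize,numFilters,'Name','conv2')
    reluLayer('Name','relu2')
    sequenceUnfoldingLayer('Name','unfold')   % moved before the flatten layer
    flattenLayer('Name','flatten')            % now flattens each time step of the sequence
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];

lgraph = layerGraph(layers);
% The unfolding layer still needs the mini-batch size from the folding layer
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
```

The training options and the trainNetwork call from the question can be reused unchanged with this lgraph.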
John Malik
on 26 Dec 2019
Edited: John Malik
on 26 Dec 2019
Answers (0)