Classify error: requires 3 arguments
Hello all, I have trained an LSTM model to classify EMG signals (one-dimensional time series), producing one class prediction per signal. When I test the trained LSTM on a test signal, classify fails with an error saying it requires 3 arguments. No matter how I change the shape of the test signal, nothing helps, and predict also returns errors. Could you please help?
% Training code:
% LSTM-1D classification using raw EMG signal
%
% Data path
dataPath = '/home/ubuntu/Desktop/EMG data analysis/EMG signal Matlab';
% Parameters
numHiddenUnits = 120;
numClasses = 8;
numChannels = 1;
% Now prepare the training/label dataset for LSTM training
% Assuming sorted_emg_data is your sorted array with data and labels
% Extract the data for training
XTrain = cellfun(@(c) c.signal, sorted_emg_data(:, 1), 'UniformOutput', false);
% Extract the labels for training
TTrain = sorted_emg_data(:, 2);
% Convert the labels to a categorical array
TTrain = categorical(TTrain);
% Now XTrain contains all the EMG signals and TTrain contains the corresponding labels
% Now train the LSTM model
% Now define your layers with the correct number of output classes
layers = [ ...
sequenceInputLayer(numChannels)
bilstmLayer(numHiddenUnits, 'OutputMode', 'last')
fullyConnectedLayer(numClasses) % Make sure this matches the number of unique classes in TTrain
softmaxLayer
classificationLayer];
% Define your training options (make sure MiniBatchSize is appropriate for your dataset)
options = trainingOptions('adam', ...
'MaxEpochs', 50, ...
'MiniBatchSize', 3, ... % Adjust based on your hardware capabilities
'InitialLearnRate', 0.01, ...
'GradientThreshold', 1, ...
'Verbose', 0, ...
'Plots', 'training-progress');
% Train the network
net = trainNetwork(XTrain, TTrain, layers, options);
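The comment on the fully connected layer and the question about signal shape both come down to two properties of the training data that can be checked in code: `sequenceInputLayer(numChannels)` expects each sequence as a numChannels-by-T array (a 1-by-T row for single-channel EMG), and `numClasses` must equal the number of categories in `TTrain`. This is a minimal sketch using a hypothetical helper, `checkEmgInputs`, with toy inputs standing in for `XTrain`/`TTrain`; in the real script the call would go just before `trainNetwork`. It only needs base MATLAB (R2016b or later for local functions in scripts).

```matlab
% Toy stand-ins for XTrain/TTrain (hypothetical data, not the real EMG set)
XTrain = {randn(1, 100); randn(120, 1)};   % second sequence is a column vector
TTrain = categorical(["rest"; "fist"]);    % 2 categories
XTrain = checkEmgInputs(XTrain, TTrain, 1, 2);
disp(cellfun(@(x) size(x, 1), XTrain).')   % both sequences now have 1 row

function X = checkEmgInputs(X, T, numChannels, numClasses)
    % Hypothetical helper: single-channel case, force every sequence
    % into a 1-by-T row vector regardless of how it was stored
    if numChannels == 1
        X = cellfun(@(x) x(:).', X, 'UniformOutput', false);
    end
    % sequenceInputLayer(numChannels) expects numChannels rows per sequence
    assert(all(cellfun(@(x) size(x, 1), X) == numChannels), ...
        'Every sequence must have numChannels rows.');
    % fullyConnectedLayer(numClasses) must match the number of categories
    assert(numel(categories(T)) == numClasses, ...
        'numClasses does not match the categories in TTrain.');
end
```

The same row orientation applies at test time: a test signal stored as a T-by-1 column would be read as T channels of length 1, which does not match the input layer.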
% Testing code
%
% Loading the network
net = load ("lstm_trained_model.mat")
% Loading data
test = load ('emg_signal_3.mat')
net.layers
length (test)
length (signal) % signal directly loaded
% Classification
pred = classify(net, test);
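A likely cause, worth checking before changing the signal shape again: `load` called with an output argument returns a struct whose fields are the saved variables, not the network object itself (and the same holds for `test`). With a struct as the first input, MATLAB does not use the Deep Learning Toolbox `classify` method for networks and most likely falls back to the Statistics and Machine Learning Toolbox function `classify(sample, training, group)`, which needs three inputs; that would explain the message, and `predict` would fail for the same underlying reason. This is a self-contained sketch of the round trip on synthetic data: train a small network, save it, load it, pull the network out of the struct, and classify one 1-by-T test signal. The synthetic data, the file name `demo_lstm_model.mat`, and the field name `net` (which depends on the variable name used when saving) are assumptions; it uses the same `trainNetwork`/`classify` workflow as the question and needs Deep Learning Toolbox.

```matlab
% Synthetic data and file name are hypothetical; needs Deep Learning Toolbox
rng(0);
numClasses  = 8;
numChannels = 1;
numObs      = 40;

% Each observation is a numChannels-by-T row vector (T varies per signal)
XTrain = cell(numObs, 1);
labels = strings(numObs, 1);
for i = 1:numObs
    k = mod(i - 1, numClasses) + 1;                  % class index 1..8
    T = randi([80 120]);                             % sequence length
    XTrain{i} = k * ones(1, T) + 0.1 * randn(1, T);  % crude class-dependent level
    labels(i) = "class" + k;
end
TTrain = categorical(labels);

layers = [ ...
    sequenceInputLayer(numChannels)
    bilstmLayer(20, 'OutputMode', 'last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', 0);
net = trainNetwork(XTrain, TTrain, layers, options);

% Save only the network, under an explicit variable name
save("demo_lstm_model.mat", "net");

% load with an output argument returns a struct, not the network
S = load("demo_lstm_model.mat");
disp(class(S))                 % struct
loadedNet = S.net;             % extract the network from the struct
disp(class(loadedNet))         % SeriesNetwork

% One test signal; x(:).' guarantees a 1-by-T row like the training data
xTest = 3 * ones(100, 1) + 0.1 * randn(100, 1);   % deliberately a column
xTest = xTest(:).';
pred   = classify(loadedNet, {xTest});   % predicted label (categorical)
scores = predict(loadedNet, {xTest});    % 1-by-numClasses class scores
```

Applied to the original test script, the corresponding change would be along the lines of `S = load("lstm_trained_model.mat"); net = S.net;` and `D = load('emg_signal_3.mat'); x = D.signal;` followed by `classify(net, {x(:).'})`, assuming the variables were saved as `net` and `signal`; `fieldnames(S)` and `fieldnames(D)` show the actual names. With a network object in hand, the layers are listed by `net.Layers` (capital L).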
Answers (1)
Cris LaPierre
on 2 Jan 2024
I cannot duplicate your error. I used this example to create a sample data set. I then trained that data using your code, and then tested it using the code in the pdfs. My conclusion is there is nothing wrong with your code. Without more details, I don't know what more we can do to help.
Here are the results I obtained when running the model on test data using the code from your pdfs.
