% Train a DQN agent to classify images from pre-extracted CNN features.
%
% Pipeline:
%   1. Load labelled images (one subfolder per class) and split 80/20.
%   2. Extract feature vectors from layer `featureLayer` of network `net`.
%   3. Wrap the training features in a custom RL environment in which each
%      step presents one feature vector and the action is the predicted class.
%   4. Train a DQN agent whose critic maps a feature vector to one Q-value
%      per class.
%
% Prerequisites (must already exist in the workspace):
%   net          - pretrained network accepted by `activations`
%   featureLayer - name of the layer in `net` to extract features from
if ~exist('net', 'var') || ~exist('featureLayer', 'var')
    error('Variables ''net'' and ''featureLayer'' must be defined before running this script.');
end

% ----- Data ---------------------------------------------------------------
imds = imageDatastore('D:\2023\thesis\waveletImages\SubsetImages', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
DataSetInfo = countEachLabel(imds);
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.80, 'randomized');

% Resize images on the fly to the input size expected by the network.
inputSize = net.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize, imdsValidation);

% ----- Feature extraction -------------------------------------------------
% One row per image. Cast to double so observations match the default
% 'double' DataType of rlNumericSpec (activations typically returns single).
trainFeatures = double(activations(net, augimdsTrain, featureLayer, 'OutputAs', 'rows'));
validationFeatures = double(activations(net, augimdsValidation, featureLayer, 'OutputAs', 'rows'));

% Integer class indices 1..K derived from the categorical folder labels.
trainLabels = grp2idx(imdsTrain.Labels);
validationLabels = grp2idx(imdsValidation.Labels);

% ----- Environment --------------------------------------------------------
numFeatures = size(trainFeatures, 2);
ObservationInfo = rlNumericSpec([numFeatures, 1], ...
    'LowerLimit', -inf, 'UpperLimit', inf);
ObservationInfo.Description = 'Feature vector extracted from CNN';
ActionInfo = rlFiniteSetSpec(1:max(trainLabels));
ActionInfo.Name = 'Class Labels';

% The step/reset handles capture the training data so the environment
% functions keep the two-/zero-argument signatures rlFunctionEnv expects.
stepFunction = @(Action, LoggedSignals) stepFunctionRL(Action, trainFeatures, trainLabels, LoggedSignals);
resetFunction = @() resetFunctionRL(trainFeatures);
env = rlFunctionEnv(ObservationInfo, ActionInfo, stepFunction, resetFunction);

% ----- Critic network -----------------------------------------------------
% Observation in, one Q-value per discrete action out.
statePath = [
    featureInputLayer(numFeatures, 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(64, 'Name', 'fc2')
    reluLayer('Name', 'relu2')
    fullyConnectedLayer(numel(ActionInfo.Elements), 'Name', 'fcOutput')];
criticNet = dlnetwork(layerGraph(statePath));

% The network has no action input and emits a vector of Q-values, so it must
% be wrapped as a vector Q-value function (rlQValueFunction expects a
% single-output network that takes both observation and action).
critic = rlVectorQValueFunction(criticNet, ObservationInfo, ActionInfo, ...
    'ObservationInputNames', 'state');

% ----- Agent --------------------------------------------------------------
agentOpts = rlDQNAgentOptions( ...
    'ExperienceBufferLength', 1e5, ...
    'DiscountFactor', 0.99, ...
    'TargetSmoothFactor', 1e-3, ...
    'TargetUpdateFrequency', 4);
agent = rlDQNAgent(critic, agentOpts);

% ----- Training -----------------------------------------------------------
% One episode is a single pass over every training sample.
trainOpts = rlTrainingOptions( ...
    'MaxStepsPerEpisode', size(trainFeatures, 1), ...
    'Plots', 'training-progress', ...
    'StopTrainingCriteria', 'EpisodeCount', ...
    'StopTrainingValue', 500);

% Pass trainOpts explicitly; without it train() falls back to defaults and
% the options configured above are silently ignored.
trainingStats = train(agent, env, trainOpts);

function [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
%RESETFUNCTIONRL Reset the classification environment to its first sample.
%
%   [InitialObservation, LoggedSignals] = resetFunctionRL(Features)
%
%   Inputs:
%     Features - N-by-numFeatures matrix, one row per sample.
%
%   Outputs:
%     InitialObservation - numFeatures-by-1 column vector (first row of
%                          Features, transposed).
%     LoggedSignals      - struct with fields:
%                            CurrentIndex     - row index of the sample
%                                               currently presented (1).
%                            EpisodeStartTime - datetime the episode began.
%
%   Errors if Features is empty or the observation is not a column vector.

if isempty(Features)
    error('Features must contain at least one sample.');
end

LoggedSignals = struct();
LoggedSignals.CurrentIndex = 1;

% Environment observations are column vectors; features are stored as rows.
InitialObservation = Features(LoggedSignals.CurrentIndex, :)';
if size(InitialObservation, 2) ~= 1
    error('InitialObservation must be a column vector of size [numFeatures, 1]');
end

LoggedSignals.EpisodeStartTime = datetime('now');
end

function [NextObservation, Reward, IsDone, LoggedSignals] = stepFunctionRL(Action, Features, Labels, LoggedSignals)
%STEPFUNCTIONRL Advance the classification environment by one sample.
%
%   [NextObservation, Reward, IsDone, LoggedSignals] = ...
%       stepFunctionRL(Action, Features, Labels, LoggedSignals)
%
%   Inputs:
%     Action        - class index chosen by the agent for the current sample.
%     Features      - N-by-numFeatures matrix, one row per sample.
%     Labels        - N-element vector of true class indices.
%     LoggedSignals - struct carrying CurrentIndex, the row of the sample
%                     the agent has just classified.
%
%   Outputs:
%     NextObservation - numFeatures-by-1 column vector for the next sample;
%                       on the terminal step the last sample is repeated so
%                       the observation keeps a valid size.
%     Reward          - +1 for a correct prediction, -1 otherwise.
%                       NOTE(review): the reward magnitudes are an
%                       assumption; the original values were not visible.
%     IsDone          - true once every sample has been presented.
%     LoggedSignals   - input struct with CurrentIndex advanced by one.

idx = LoggedSignals.CurrentIndex;
correctLabel = Labels(idx);

if Action == correctLabel
    Reward = 1;
else
    Reward = -1;
end

LoggedSignals.CurrentIndex = idx + 1;
IsDone = LoggedSignals.CurrentIndex > size(Features, 1);

if ~IsDone
    NextObservation = Features(LoggedSignals.CurrentIndex, :)';
else
    % No further sample exists; re-emit the current one rather than
    % indexing past the end of Features.
    NextObservation = Features(idx, :)';
end
end