Gradients not recorded for a dlnetwork VAE

I have created a VAE using dlnetwork, consisting of an encoder, a decoder, and a classifier. The gradients of the loss are not recorded when I call dlgradient for the update. Can you have a look and work out the cause?
% === Setup ===
numEpochs = 5;
miniBatchSize = 32;
learnRate = 1e-3;
latentDim = 32;
hiddenUnits = 64;
inputSize = size(XTrain{1}, 1); % Number of features
sequenceLength = size(XTrain{1}, 2); % Time steps
% === Encoder ===
encoderLayers = [
sequenceInputLayer(inputSize, 'Name', 'input')
lstmLayer(64, 'OutputMode', 'last', 'Name', 'lstm_enc')
fullyConnectedLayer(latentDim, 'Name', 'fc_mu')
fullyConnectedLayer(latentDim, 'Name', 'fc_logvar')
];
encoderNet = dlnetwork(layerGraph(encoderLayers));
% === Decoder ===
decoderLayers = [
sequenceInputLayer(latentDim, 'Name', 'latent_input')
fullyConnectedLayer(64, 'Name', 'fc_latent')
lstmLayer(64, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
fullyConnectedLayer(inputSize, 'Name', 'fc_recon')
];
decoderNet = dlnetwork(layerGraph(decoderLayers));
% === Classifier ===
classifierLayers = [
featureInputLayer(latentDim, 'Name', 'class_input')
fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
reluLayer('Name', 'relu_class')
fullyConnectedLayer(numClasses, 'Name', 'fc_class')
softmaxLayer('Name', 'softmax_out')
];
classifierNet = dlnetwork(layerGraph(classifierLayers));
function loss = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
% All operations inside this function are traced
[mu, logvar] = encodeLatents(dlX, encoderNet);
z = sampleLatents(mu, logvar); % Sampling with reparameterization
loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
end
function [mu, logvar] = encodeLatents(dlX, encoderNet)
mu = forward(encoderNet, dlX, 'Outputs', 'fc_mu');
logvar = forward(encoderNet, dlX, 'Outputs', 'fc_logvar');
end
function z = sampleLatents(mu, logvar)
eps = dlarray(randn(size(mu), 'like', mu)); % Traced random noise
z = mu + exp(0.5 * logvar) .* eps;
end
function loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength)
zRepeated = repmat(z, 1, 1, sequenceLength);
dlZSeq = dlarray(zRepeated, 'CBT');
reconOut = forward(decoderNet, dlZSeq, 'Outputs', 'fc_recon');
classOut = forward(classifierNet, dlarray(z, 'CB'), 'Outputs', 'softmax_out');
reconLoss = mse(reconOut, dlX);
classLoss = crossentropy(classOut, dlY);
klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
loss = reconLoss + classLoss + klLoss;
end
for epoch = 1:numEpochs
idx = randperm(numel(XTrain));
totalEpochLoss = 0;
numBatches = 0;
for i = 1:miniBatchSize:numel(XTrain)
batchIdx = idx(i:min(i+miniBatchSize-1, numel(XTrain)));
XBatch = XTrain(batchIdx);
YBatch = YTrain(batchIdx, :);
% === Format Inputs ===
XCat = cat(3, XBatch{:});
XCat = permute(XCat, [1, 3, 2]);
dlX = dlarray(XCat, 'CBT');
dlY = dlarray(YBatch', 'CB');
totalLoss = dlfeval(@computeTotalLoss, dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength);
gradients = dlgradient(totalLoss, encoderNet.Learnables);
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.

Accepted Answer

Torsten on 22 Aug 2025
Moved: Torsten on 22 Aug 2025
Did you read the documentation on how "dlgradient" is to be applied?
You have to call "dlgradient" inside the function "computeTotalLoss", not outside it.
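The rule can be seen in isolation with a minimal sketch that is independent of the VAE in the question. The function name sumSquares is hypothetical and only serves as an illustration: the value to differentiate and the call to dlgradient both live inside the function that dlfeval evaluates, and the gradient is returned as an extra output.

```matlab
% Minimal sketch: dlgradient runs inside the function that dlfeval calls.
x = dlarray([1 2 3]);

% dlfeval enables tracing while sumSquares runs, so dlgradient can work there.
[y, dydx] = dlfeval(@sumSquares, x);
disp(extractdata(dydx))   % gradient of sum(x.^2) with respect to x is 2*x

function [y, dydx] = sumSquares(x)
    % y is a traced scalar here because we are inside dlfeval.
    y = sum(x.^2, 'all');
    % Differentiate before returning; outside this function y is no longer traced.
    dydx = dlgradient(y, x);
end
```

Calling dlgradient(y, x) after dlfeval has returned raises the "Value to differentiate is not traced" error from the question, because the trace only exists while the function passed to dlfeval is running.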
  2 Comments
SIMON on 23 Aug 2025
OK, thanks, this is news to me. I will do that. Thanks so much.
Catalytic on 23 Aug 2025
@SIMON, you should click accept on Torsten's answer if it solves your problem.


More Answers (1)

Matt J on 22 Aug 2025
Edited: Matt J on 22 Aug 2025
As the error message says, the call to dlgradient must occur within the function (in this case computeTotalLoss) called by dlfeval, but that is not what you've done.
More confusingly, it appears that computeTotalLoss and computeLoss call each other in a recursive loop. That will create problems as well, I imagine. You are not meant to be calling dlfeval inside the loss function. It is meant to be called externally, in your training loop.
  2 Comments
SIMON on 23 Aug 2025
The two functions are not recursive: computeLoss is called by computeTotalLoss, which is evaluated inside dlfeval. However, the gradients are calculated outside of dlfeval, so I may try moving that call inside. Thanks.
Torsten on 23 Aug 2025
You have to put "dlgradient" inside "computeTotalLoss". Then you can call "dlfeval" from somewhere outside "computeTotalLoss" and "computeLoss".
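For a model split across several dlnetwork objects, as in the question, the same pattern might look like the sketch below. The two toy networks (encNet, clsNet), the function name modelLoss, and the random data are assumptions that stand in for the poster's encoder, classifier, and training batch; the point is only that one dlgradient call inside the traced function can return gradients for more than one set of learnables.

```matlab
% Toy stand-ins (hypothetical) for the encoder and classifier in the question.
encNet = dlnetwork(layerGraph([
    featureInputLayer(4, 'Name', 'in')
    fullyConnectedLayer(3, 'Name', 'fc_enc')]));
clsNet = dlnetwork(layerGraph([
    featureInputLayer(3, 'Name', 'z_in')
    fullyConnectedLayer(2, 'Name', 'fc_cls')
    softmaxLayer('Name', 'sm')]));

% Random inputs and one-hot targets for a batch of 4 observations.
X = dlarray(rand(4, 4, 'single'), 'CB');
T = dlarray(single([1 0 1 0; 0 1 0 1]), 'CB');

% dlfeval is called from the training loop; gradients come back as outputs.
[loss, gEnc, gCls] = dlfeval(@modelLoss, encNet, clsNet, X, T);

% One Adam step per network, each with its own optimizer state.
learnRate = 1e-3; iteration = 1;
avgEnc = []; sqEnc = []; avgCls = []; sqCls = [];
[encNet, avgEnc, sqEnc] = adamupdate(encNet, gEnc, avgEnc, sqEnc, iteration, learnRate);
[clsNet, avgCls, sqCls] = adamupdate(clsNet, gCls, avgCls, sqCls, iteration, learnRate);

function [loss, gEnc, gCls] = modelLoss(encNet, clsNet, X, T)
    z = forward(encNet, X);        % latent code, still traced
    Y = forward(clsNet, z);        % class probabilities
    loss = crossentropy(Y, T);     % traced scalar loss
    % One dlgradient call, one output per set of learnables.
    [gEnc, gCls] = dlgradient(loss, encNet.Learnables, clsNet.Learnables);
end
```

In the original code this would mean giving computeTotalLoss extra outputs for the encoder, decoder, and classifier gradients and removing the dlgradient call from the training loop.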

