Gradients not recorded for a dlnetwork VAE
I have created a VAE from three dlnetwork objects: an encoder, a decoder and a classifier. The loss function does not record gradients for the update via dlgradient. Can you have a look and find the cause?
% === Setup ===
numEpochs = 5;
miniBatchSize = 32;
learnRate = 1e-3;
latentDim = 32;
hiddenUnits = 64;
inputSize = size(XTrain{1}, 1); % Number of features
sequenceLength = size(XTrain{1}, 2); % Time steps
% === Encoder ===
encoderLayers = [
sequenceInputLayer(inputSize, 'Name', 'input')
lstmLayer(64, 'OutputMode', 'last', 'Name', 'lstm_enc')
fullyConnectedLayer(latentDim, 'Name', 'fc_mu')
fullyConnectedLayer(latentDim, 'Name', 'fc_logvar')
];
encoderNet = dlnetwork(layerGraph(encoderLayers));
% === Decoder ===
decoderLayers = [
sequenceInputLayer(latentDim, 'Name', 'latent_input')
fullyConnectedLayer(64, 'Name', 'fc_latent')
lstmLayer(64, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
fullyConnectedLayer(inputSize, 'Name', 'fc_recon')
];
decoderNet = dlnetwork(layerGraph(decoderLayers));
% === Classifier ===
classifierLayers = [
featureInputLayer(latentDim, 'Name', 'class_input')
fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
reluLayer('Name', 'relu_class')
fullyConnectedLayer(numClasses, 'Name', 'fc_class')
softmaxLayer('Name', 'softmax_out')
];
classifierNet = dlnetwork(layerGraph(classifierLayers));
function loss = computeTotalLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
% All operations inside this function are traced
[mu, logvar] = encodeLatents(dlX, encoderNet);
z = sampleLatents(mu, logvar); % Sampling with reparameterization
loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength);
end
function [mu, logvar] = encodeLatents(dlX, encoderNet)
mu = forward(encoderNet, dlX, 'Outputs', 'fc_mu');
logvar = forward(encoderNet, dlX, 'Outputs', 'fc_logvar');
end
function z = sampleLatents(mu, logvar)
eps = dlarray(randn(size(mu), 'like', mu)); % Traced random noise
z = mu + exp(0.5 * logvar) .* eps;
end
function loss = computeLoss(dlX, dlY, z, mu, logvar, decoderNet, classifierNet, sequenceLength)
zRepeated = repmat(z, 1, 1, sequenceLength);
dlZSeq = dlarray(zRepeated, 'CBT');
reconOut = forward(decoderNet, dlZSeq, 'Outputs', 'fc_recon');
classOut = forward(classifierNet, dlarray(z, 'CB'), 'Outputs', 'softmax_out');
reconLoss = mse(reconOut, dlX);
classLoss = crossentropy(classOut, dlY);
klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
loss = reconLoss + classLoss + klLoss;
end
for epoch = 1:numEpochs
idx = randperm(numel(XTrain));
totalEpochLoss = 0;
numBatches = 0;
for i = 1:miniBatchSize:numel(XTrain)
batchIdx = idx(i:min(i+miniBatchSize-1, numel(XTrain)));
XBatch = XTrain(batchIdx);
YBatch = YTrain(batchIdx, :);
% === Format Inputs ===
XCat = cat(3, XBatch{:});
XCat = permute(XCat, [1, 3, 2]);
dlX = dlarray(XCat, 'CBT');
dlY = dlarray(YBatch', 'CB');
totalLoss = dlfeval(@computeTotalLoss, dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength);
gradients = dlgradient(totalLoss, encoderNet.Learnables);
The last line fails with:
Error using dlarray/dlgradient (line 105)
Value to differentiate is not traced. It must be a traced real dlarray scalar. Use dlgradient inside a function called by dlfeval to trace the variables.
Matt J on 22 Aug 2025 (edited 22 Aug 2025)
As the error message says, the call to dlgradient must occur within the function called by dlfeval (in this case computeTotalLoss), but your code calls dlgradient in the training loop, after dlfeval has already returned.
More confusingly, it appears that computeTotalLoss and computeLoss call each other in a recursive loop. That will create problems as well, I imagine. You are not meant to be calling dlfeval inside the loss function. It is meant to be called externally, in your training loop.
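The rule is easiest to see on a toy problem. This is a minimal sketch that assumes only the Deep Learning Toolbox; squaredSum and x0 are hypothetical names used for illustration. dlgradient is called inside the function that dlfeval evaluates, and dlfeval itself is called from the outer script:

```matlab
% dlfeval traces every dlarray operation inside the function it calls,
% so dlgradient is only valid inside that function.
x0 = dlarray([1 2 3]);
[y, dydx] = dlfeval(@squaredSum, x0);
disp(extractdata(dydx))   % the gradient of sum(x.^2) is 2*x

% Calling dlgradient(y, x0) out here instead would raise the same
% "Value to differentiate is not traced" error as in the question,
% because y is no longer being traced once dlfeval has returned.

function [y, dydx] = squaredSum(x)
    y = sum(x.^2, 'all');      % traced: x was passed in through dlfeval
    dydx = dlgradient(y, x);   % valid: still inside the dlfeval call
end
```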
Torsten on 23 Aug 2025
You have to put "dlgradient" inside "computeTotalLoss". Then you can call "dlfeval" from somewhere outside "computeTotalLoss" and "computeLoss".
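Putting both replies together, a corrected version of the question's code might look like the sketch below. It is a minimal sketch, not tested on the asker's data, and it rests on several assumptions: the data are small synthetic stand-ins for XTrain and YTrain; modelLoss is a hypothetical name that replaces computeTotalLoss/computeLoss; the update step uses adamupdate, which the question does not show; and the sampling noise is drawn as a plain constant, since only mu and logvar need to be traced. It also rebuilds the encoder as a layer graph with two parallel heads, because a plain layer array connects its layers in series, so in encoderLayers fc_logvar receives the output of fc_mu rather than of lstm_enc. The repeated latent sequence is built with stripdims and repmat before relabelling it as 'CBT'.

```matlab
% Sketch only: synthetic data and tiny sizes stand in for the real XTrain/YTrain.
rng(0);
numObs = 8; inputSize = 3; sequenceLength = 5; numClasses = 2;
latentDim = 4; hiddenUnits = 8;
numEpochs = 2; miniBatchSize = 4; learnRate = 1e-3;

% Hypothetical data: a cell array of C-by-T sequences and one-hot labels.
XTrain = cell(numObs, 1);
for k = 1:numObs
    XTrain{k} = randn(inputSize, sequenceLength, 'single');
end
labels = randi(numClasses, numObs, 1);
YTrain = single(labels == 1:numClasses);   % numObs-by-numClasses one-hot

% Encoder: fc_mu and fc_logvar both read from lstm_enc (parallel heads).
encLG = layerGraph([
    sequenceInputLayer(inputSize, 'Name', 'input')
    lstmLayer(hiddenUnits, 'OutputMode', 'last', 'Name', 'lstm_enc')
    fullyConnectedLayer(latentDim, 'Name', 'fc_mu')]);
encLG = addLayers(encLG, fullyConnectedLayer(latentDim, 'Name', 'fc_logvar'));
encLG = connectLayers(encLG, 'lstm_enc', 'fc_logvar');
encoderNet = dlnetwork(encLG);

decoderNet = dlnetwork(layerGraph([
    sequenceInputLayer(latentDim, 'Name', 'latent_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_latent')
    lstmLayer(hiddenUnits, 'OutputMode', 'sequence', 'Name', 'lstm_dec')
    fullyConnectedLayer(inputSize, 'Name', 'fc_recon')]));

classifierNet = dlnetwork(layerGraph([
    featureInputLayer(latentDim, 'Name', 'class_input')
    fullyConnectedLayer(hiddenUnits, 'Name', 'fc_class_hidden')
    reluLayer('Name', 'relu_class')
    fullyConnectedLayer(numClasses, 'Name', 'fc_class')
    softmaxLayer('Name', 'softmax_out')]));

% Adam state for each network (empty on the first call to adamupdate).
[avgE, sqE, avgD, sqD, avgC, sqC] = deal([]);
iteration = 0;
for epoch = 1:numEpochs
    idx = randperm(numObs);
    for i = 1:miniBatchSize:numObs
        batchIdx = idx(i:min(i + miniBatchSize - 1, numObs));
        XCat = cat(3, XTrain{batchIdx});                 % C x T x B
        dlX = dlarray(permute(XCat, [1 3 2]), 'CBT');
        dlY = dlarray(YTrain(batchIdx, :)', 'CB');

        % dlfeval is called from the training loop; gradients come back out.
        [loss, gEnc, gDec, gCls] = dlfeval(@modelLoss, dlX, dlY, ...
            encoderNet, decoderNet, classifierNet, sequenceLength);

        iteration = iteration + 1;
        [encoderNet, avgE, sqE] = adamupdate(encoderNet, gEnc, avgE, sqE, iteration, learnRate);
        [decoderNet, avgD, sqD] = adamupdate(decoderNet, gDec, avgD, sqD, iteration, learnRate);
        [classifierNet, avgC, sqC] = adamupdate(classifierNet, gCls, avgC, sqC, iteration, learnRate);
    end
    fprintf('Epoch %d, last mini-batch loss %.4f\n', epoch, extractdata(loss));
end

function [loss, gEnc, gDec, gCls] = modelLoss(dlX, dlY, encoderNet, decoderNet, classifierNet, sequenceLength)
    [mu, logvar] = forward(encoderNet, dlX, 'Outputs', {'fc_mu', 'fc_logvar'});
    % Reparameterization: the noise is an untraced constant.
    noise = dlarray(randn(size(mu), 'like', extractdata(mu)), 'CB');
    z = mu + exp(0.5 * logvar) .* noise;
    % Repeat z over time: drop the labels, tile along dim 3, relabel as CBT.
    zSeq = dlarray(repmat(stripdims(z), 1, 1, sequenceLength), 'CBT');
    reconOut = forward(decoderNet, zSeq);
    classOut = forward(classifierNet, z);
    reconLoss = mse(reconOut, dlX);
    classLoss = crossentropy(classOut, dlY);
    klLoss = -0.5 * sum(1 + logvar - mu.^2 - exp(logvar), 'all');
    loss = reconLoss + classLoss + klLoss;
    % dlgradient is called here, while dlfeval is still tracing.
    [gEnc, gDec, gCls] = dlgradient(loss, encoderNet.Learnables, ...
        decoderNet.Learnables, classifierNet.Learnables);
end
```

To use it with real data, drop the synthetic-data lines and keep your own XTrain, YTrain and sizes. The essential change is the one both replies describe: dlgradient sits inside the loss function, which returns gradients for all three networks, and dlfeval is called once per mini-batch from the training loop.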