Loading and training an existing network

I am trying to define a network, then train it across multiple sessions. The problem is that I can't get the load or read of the network to work in the second session. The code is:
layers = [ ...
sequenceInputLayer(270)
bilstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer]
options = trainingOptions("adam", ...
InitialLearnRate=0.002,...
MaxEpochs=15, ...
Shuffle="never", ...
GradientThreshold=1, ...
Verbose=false, ...
ExecutionEnvironment="gpu", ...
Plots="training-progress");
clabels=categorical(labels);
numLabels = numel(clabels);
load("savednet.mat","layers");
net = trainNetwork(data,clabels,layers,options);
save("savednet","net");
I have tried many variations of the load command, and it always gives a warning about the second argument:
Warning: Variable 'layers' not found.
Exactly what should that look like, and how should it then be used as input to the trainNetwork routine?

7 Comments

layers = [ ...
sequenceInputLayer(270)
bilstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer]
You define a variable named layers
load("savednet.mat","layers");
You try to overwrite the variable named layers with the contents of the variable layers stored in savednet.mat, but that variable does not exist in that .mat file.
save("savednet","net");
Notice that you do not save layers to savednet.mat.
Thank you for the quick reply.
Then two questions: what should the save call look like? (I had copied this from another's reply.) And exactly what should the load command look like? It is very unclear what is being saved in the savednet.mat file.
Remove
load("savednet.mat","layers");
Change
save("savednet","net");
to
save("savednet","net","layers");
Afterwards, to do additional work on the saved network,
load("savednet","net","layers");
I see says the blind man!! This seems very logical. None of the examples that I could find appeared so clean.
Still, if I may dig into this a bit more, the call:
net = trainNetwork(data,clabels,layers,options);
would appear to be generating the network from scratch each time. What tells it to use the read-in network as the starting point?
filename = "savednet.mat";
if isfile(filename)
clear layers net
try
load(filename, "net", "layers");
catch ME
% load failed; fall through and rebuild the network below
end
end
if ~exist("layers", "var") || ~exist("net", "var")
layers = [ ...
sequenceInputLayer(270)
bilstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer]
options = trainingOptions("adam", ...
InitialLearnRate=0.002,...
MaxEpochs=15, ...
Shuffle="never", ...
GradientThreshold=1, ...
Verbose=false, ...
ExecutionEnvironment="gpu", ...
Plots="training-progress");
clabels=categorical(labels);
numLabels = numel(clabels);
net = trainNetwork(data,clabels,layers,options);
save(filename, "net", "layers");
end
The above code loads layers and net from the file if possible, and if that fails then it creates and trains the network and saves it.
Perhaps I don't understand how one can train in stages, then. The idea is that the training will be continued in the second and subsequent sessions, a sort of continuing transfer learning, so that over time the network keeps improving. Perhaps I should be using trainnet instead of trainNetwork. Then it would appear the call is:
load(filename,"net1","layers");
net=trainnet(data,clabels,net1,"crossentropy",options);
Is this the correct direction?
Probably
net1 = trainnet(data,clabels,net1,"crossentropy",options);
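Putting the thread together, a session-resuming workflow might look like the sketch below. This is a minimal sketch, not a definitive implementation: data, clabels, numHiddenUnits, numClasses, and options are assumed to exist as defined earlier in the thread, and it relies on trainnet continuing from the weights of the network it is given.

```matlab
filename = "savednet.mat";
if isfile(filename)
    % Later sessions: resume from the previously trained network
    s = load(filename, "net");
    net = s.net;
else
    % First session: build a fresh (untrained) network.
    % Note: dlnetwork/trainnet take no classificationLayer;
    % the loss is supplied to trainnet instead.
    layers = [ ...
        sequenceInputLayer(270)
        bilstmLayer(numHiddenUnits,OutputMode="last")
        fullyConnectedLayer(numClasses)
        softmaxLayer];
    net = dlnetwork(layers);
end

% Continue training from the current weights, then save for next time
net = trainnet(data, clabels, net, "crossentropy", options);
save(filename, "net");
```

Run this same script in every session; each run loads the latest weights, trains further, and overwrites savednet.mat.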

 Accepted Answer

Load the saved variables into a struct (to avoid clobbering workspace variables), then pass the loaded network to trainNetwork:
previous = load("savednet","net","layers");
net = trainNetwork(data,clabels,previous.net,options);

1 Comment

Matt J on 7 Oct 2024 (edited 8 Oct 2024)
perhaps I should be using trainnet instead of trainnetwork.
It would be better, since trainnet is newer and more flexible. However, it won't make a difference as far as how to resume the training of a pre-existing network.
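If earlier sessions were trained with trainNetwork, the saved net is a SeriesNetwork rather than the dlnetwork that trainnet expects. One way to bridge the two is sketched below; it assumes dag2dlnetwork (introduced in recent releases, before the R2024b release this thread uses) and the data, clabels, and options variables from the question.

```matlab
% Load the network trained earlier with trainNetwork
s = load("savednet.mat", "net");

% Convert the SeriesNetwork/DAGNetwork to a dlnetwork for trainnet.
% The conversion drops the output (classification) layer, so the
% loss is passed to trainnet explicitly.
dlnet = dag2dlnetwork(s.net);

% Resume training from the converted network's weights
dlnet = trainnet(data, clabels, dlnet, "crossentropy", options);
save("savednet.mat", "dlnet");
```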

More Answers (0)

Category: Deep Learning Toolbox
Release: R2024b

Asked: 7 Oct 2024
Edited: 8 Oct 2024
