getwb/setwb for compact neural network storage?
I have a project where I'm using a fitnet with 8 nodes in the hidden layer to fit a surface, given by a set of points near that surface (so we're looking for a surface that minimizes the distance to the points). The algorithm happens to generate and train a very large number of fitnets. These are then saved out to disk for analysis by other stages.
The problem that has arisen is that the resulting output file is huge (on the order of 100 GB), and takes up most of my quota on the cluster that this runs on. I set out to find a more compact way of saving out the networks.
I found getwb() and setwb() in the documentation, which get and set the weights and biases of a network, respectively. Since the network architecture is known, I figured I could save out just the weights and biases of my trained networks, and then in later stages create an array of blank networks and restore them with setwb. I tried it on one of the MATLAB examples (fitting a 1-D function from a list of points on the curve), and the procedure seemed to work.
So then I tried adapting it to my network and inputs. I have one stage (simple_setwb_save.m, attached) that trains a fitnet and then saves out the results of getwb for that network. Then the second stage (simple_setwb_load.m, attached) reloads the saved weights and biases from the first stage, and uses setwb to set the weights and biases of a newly created fitnet. As evidenced by the attached plots, the output of the re-created network appears to match that of the original, and the output of perform() matches for both. Success - right?
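In outline, the two stages look something like this (toy data and simplified names here, not the attached scripts verbatim; note that I have to configure the blank net against the data so that setwb gets a vector of the right length):

```matlab
% Stage 1: train and save only the weight/bias vector.
x = rand(2, 200);                      % toy inputs: 2 x Q
t = sin(x(1,:)) + cos(x(2,:));         % toy targets: 1 x Q
net = fitnet(8);
net.trainParam.showWindow = false;     % no GUI for batch runs
net = train(net, x, t);
wb = getwb(net);                       % all weights and biases, one column vector
save('wb.mat', 'wb');                  % far smaller than saving the net object

% Stage 2: rebuild a network from the saved vector.
S = load('wb.mat');
net2 = fitnet(8);
net2 = configure(net2, x, t);          % size the blank net so setwb fits
net2 = setwb(net2, S.wb);
fprintf('mse = %g vs %g\n', perform(net, t, net(x)), perform(net2, t, net2(x)));
```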
Now, I adjusted my project to use the new protocol: instead of saving out half a million or so fitnets at a few MB a pop, I save out a half-million-by-33 array of floats. It didn't work, but it failed in a very peculiar way: the shape of the curves on a plot similar to those attached is still more or less reasonable, but they are wildly shifted upward, away from the input data points. I've attached an example of this.
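Concretely, the new protocol stores one row of 33 numbers per network instead of one network object, something like this (illustrative, with toy sizes and stand-in data; for fitnet(8) with 2 inputs and 1 output, numel(getwb(net)) is 8*2 + 8 + 8 + 1 = 33):

```matlab
nNets = 10;                 % the real run uses roughly half a million
WB = zeros(nNets, 33);
for k = 1:nNets
    x = rand(2, 200); t = sin(x(1,:)) + cos(x(2,:));  % stand-in data
    net = fitnet(8);
    net.trainParam.showWindow = false;
    net = train(net, x, t);
    WB(k, :) = getwb(net)';   % one row of weights/biases per network
end
save('all_wb.mat', 'WB');     % 500000 x 33 doubles is only ~130 MB
```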
I then made a modification so it would save out both the array of fitnets and the array of weights and biases, plus a compare script that goes through the array one network at a time, comparing the output of each saved fitnet to the recreated network that used setwb with the weights from the other file. Each line of the compare output first reports whether the result of getwb() is identical for both; it then shows the mean-squared error of each network's predictions against the original dataset. The output is quite long, but I've copied the last few lines of it here:
weights and biases match; mse1 = 0.003370; mse2 = 4.005985
weights and biases match; mse1 = 0.000408; mse2 = 4.048744
weights and biases match; mse1 = 0.000323; mse2 = 4.020332
weights and biases match; mse1 = 0.001506; mse2 = 4.011596
weights and biases match; mse1 = 0.000444; mse2 = 4.022094
weights and biases match; mse1 = 0.000364; mse2 = 4.012197
weights and biases match; mse1 = 0.004508; mse2 = 4.081187
weights and biases match; mse1 = 0.002639; mse2 = 4.005669
weights and biases match; mse1 = 0.002300; mse2 = 4.068722
weights and biases match; mse1 = 0.001478; mse2 = 4.007328
weights and biases match; mse1 = 0.009652; mse2 = 4.001592
weights and biases match; mse1 = 0.002022; mse2 = 4.017363
weights and biases match; mse1 = 0.004103; mse2 = 4.006953
weights and biases match; mse1 = 0.000640; mse2 = 4.012252
weights and biases match; mse1 = 0.003421; mse2 = 3.995628
weights and biases match; mse1 = 0.026180; mse2 = 4.001113
weights and biases match; mse1 = 0.005215; mse2 = 3.992562
weights and biases match; mse1 = 0.001084; mse2 = 4.015727
0 pass; 1000 fail
0 instances of mismatch in weights and biases.
There is nothing remarkable in the preceding output; the takeaway is that the loaded fitnets (mse1) have reasonable MSEs, while the networks recreated via setwb (mse2) have MSEs that are all approximately 4. I've attached the compare script, but not the data files needed to run it, because the saved fitnet file is large.
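For reference, the compare script boils down to something like this (variable names simplified: `nets` is the loaded cell array of fitnets, `WB` the weight matrix, and `x`, `t` the inputs and targets each network was trained against):

```matlab
nPass = 0; nFail = 0; nMismatch = 0;
for k = 1:numel(nets)
    net2 = configure(fitnet(8), x, t);     % blank net, sized to the data
    net2 = setwb(net2, WB(k, :)');         % restore saved weights/biases
    match = isequal(getwb(nets{k}), getwb(net2));
    nMismatch = nMismatch + ~match;
    mse1 = perform(nets{k}, t, nets{k}(x));  % original network
    mse2 = perform(net2, t, net2(x));        % recreated network
    if match, s = 'match'; else, s = 'MISMATCH'; end
    fprintf('weights and biases %s; mse1 = %f; mse2 = %f\n', s, mse1, mse2);
    if abs(mse1 - mse2) < 1e-6, nPass = nPass + 1; else, nFail = nFail + 1; end
end
fprintf('%d pass; %d fail\n', nPass, nFail);
fprintf('%d instances of mismatch in weights and biases.\n', nMismatch);
```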
I feel there must be some sort of hidden state in the network object that keeps this procedure from working at scale. Is there some extra step I'm missing that's needed to make the networks match up?
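One guess at where such hidden state could live (a guess, not a diagnosis): fitnet preprocesses its inputs and outputs by default (removeconstantrows and mapminmax), and those settings are stored on the network object, outside the vector that getwb returns. Comparing them between an original and a rebuilt network (here called `orig` and `rebuilt`) would look like:

```matlab
orig.inputs{1}.processFcns        % e.g. {'removeconstantrows','mapminmax'}
isequal(orig.inputs{1}.processSettings,    rebuilt.inputs{1}.processSettings)
isequal(orig.outputs{end}.processSettings, rebuilt.outputs{end}.processSettings)
```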