How to create a custom weighted loss function for regression using Deep Learning Toolbox?

I want to implement a custom weighted loss function for a regression neural network and want to achieve the following:
% non-vectorized form is used for clarity
loss_elem(i) = sum((Y(:,i) - T(:,i)).^2) * W(i);
loss = sum(loss_elem) / N;
where W(i) is the weight of the i-th input sample.
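For reference, an equivalent vectorized form, assuming Y and T are numResponses-by-N matrices and W is a 1-by-N row vector:
% equivalent vectorized form of the loss above
loss = sum(sum((Y - T).^2, 1) .* W) / N;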
I found a similar example for creating a weighted classification output layer and tried to adapt it to a custom regression output layer.
The weighted classification output layer uses one weight per class label, meaning the same fixed weights are used across all training iterations. For a weighted regression layer, however, there should be a different weight vector for each training batch.
I am not sure how to pass the weights as input arguments when creating the network, or how to keep track of the weight indices for each training batch.

Accepted Answer

MathWorks Support Team on 9 Sep 2020
Unfortunately, as of MATLAB R2020a, the requested workflow is not available as a built-in solution.
However, below are two possible workarounds:
1) "trainNetwork" approach with a custom layer  
This approach is based on storing the weights in a layer property, e.g. "RegressionWeights". The workflow is similar to the weighted classification output layer example referenced in the question, but a custom regression output layer needs to be implemented instead.
However, this approach can be quite cumbersome, since it requires implementing the backward function and also keeping track of the indexing of the batches. This approach is not recommended; a sketch is given below for illustration only.
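For illustration, a minimal sketch of what such a layer might look like. The class name is hypothetical; it assumes responses-by-observations (C-by-N) data, and it leaves open the hardest part, namely keeping "BatchIndices" in sync with the shuffling that "trainNetwork" performs internally:
classdef weightedRegressionLayer < nnet.layer.RegressionLayer
    % Hypothetical custom regression output layer with per-sample weights
    properties
        RegressionWeights % 1-by-numObservations weights for the full training set
        BatchIndices      % indices of the current mini-batch (must be updated manually)
    end
    methods
        function layer = weightedRegressionLayer(name, weights)
            layer.Name = name;
            layer.RegressionWeights = weights;
        end
        function loss = forwardLoss(layer, Y, T)
            % Weighted mean squared error over the current mini-batch
            W = layer.RegressionWeights(layer.BatchIndices);
            N = size(Y, 2);
            loss = sum(sum((Y - T).^2, 1) .* W) / N;
        end
        function dLdY = backwardLoss(layer, Y, T)
            % Derivative of the weighted MSE with respect to the predictions
            W = layer.RegressionWeights(layer.BatchIndices);
            N = size(Y, 2);
            dLdY = 2 * (Y - T) .* W / N;
        end
    end
end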
2) "dlnetwork" approach with custom training loop
This approach involves implementing a custom training loop using "dlnetwork". In this case, no output layer is required; instead, you specify the loss function yourself (as explained in the "Model Predictions Function" section of the corresponding documentation example).
The above example uses the "crossentropy" loss. However, you will need to implement your own version of the weighted mean squared error (WMSE) loss function and call it instead of the "crossentropy" function. Note that if you want to apply the weights to the mini-batches, you will need to extract the indices of the data you are using (in the above example, this is done by indexing with the variable "idx" in every epoch) and use them to select the weights passed to the WMSE function.
Please also note that this approach does NOT require implementing a backward method, since the gradients are computed using automatic differentiation.
The current mean squared error (MSE) function is implemented as: 
function X = mse(X, T, observationDim)
% Half mean squared error
% X : predictions, T : targets
% observationDim : dimension along which the observations run
N = size(X, observationDim);        % number of observations in the mini-batch
X = sum((X-T).^2, 'all') / (2*N);   % summed squared error, averaged and halved
end
Please modify it for WMSE. It is straightforward to add the input W (indexed to have the same size as the mini-batch) and perform the desired calculation.
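For example, a minimal WMSE sketch, assuming Y and T are numResponses-by-N arrays with observations along the second dimension, and W is a 1-by-N weight vector already indexed to the current mini-batch (the function name and signature are only suggestions):
function loss = wmse(Y, T, W)
% Weighted half mean squared error
% Y, T : numResponses-by-N predictions and targets for the mini-batch
% W    : 1-by-N per-observation weights, indexed to match the mini-batch
N = size(Y, 2);
sqErr = sum((Y - T).^2, 1);      % squared error per observation (1-by-N)
loss = sum(sqErr .* W) / (2*N);  % weighted and halved, matching the built-in convention
end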
The only drawback of this approach is that you will not be able to use the "trainNetwork" function. Instead, you will need to implement a custom training loop, which might be slightly cumbersome.
The "custom training loop" feature was introduced recently as a tool to provide users with more flexibility and you should be able to achieve the desired workflow using the above approach.

More Answers (0)
