Implement Ridge Regression Equation for a Neural Network MATLAB
Jonathan Frutschy on 9 Jun 2024 (edited 23 Jun 2024)
I am trying to replicate the following equation in MATLAB to find the optimal output weight matrix of a neural network trained with ridge regression.
Output Weight Matrix of a Neural Network after Training using Ridge Regression:
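Written out in the notation of the cited guide (a transcription of its ridge regression objective; β is the ridge regression coefficient, called reg in the code below, T is the number of time steps, N_y the number of outputs, and w_i^out the i-th row of W^out), the equation reads:

```latex
W^{out} = \operatorname*{arg\,min}_{W^{out}} \; \frac{1}{N_y} \sum_{i=1}^{N_y}
\left( \sum_{n=1}^{T} \bigl( y_i(n) - y_i^{target}(n) \bigr)^2 + \beta \, \lVert w_i^{out} \rVert^2 \right)
```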
This equation comes from the echo state network guide provided by Mantas Lukosevicius and can be found at: https://www.researchgate.net/publication/319770153_A_practical_guide_to_applying_echo_state_networks
My attempt is below. I believe that the outer parentheses (in red) make this a non-traditional double summation, meaning the method presented by @Voss (see https://www.mathworks.com/matlabcentral/answers/1694960-nested-loops-for-double-summation) cannot be followed. Note that y_i is a T-by-1 vector and y_i_target is also a T-by-1 vector. Wout_i is an N-by-1 vector, where N is the number of nodes in the neural network. I generate a Wout_i, y_i, and y_i_target for each i^th target training signal. The final output for Wout is an N-by-1 vector, where each element is the optimal weight for one node in the network.
close all;
clear all;
clc;
N = 100; % number of nodes in neural network
Ny = 200; % number of training signals
T = 50; % time length of each training signal
X = rand(N,T); % neural network state matrix
reg = 10^-4; % ridge regression coefficient
outer_sum = zeros(Ny,1);
Wouts = cell(Ny,1); % preallocated cell array of per-signal weight vectors
for i = 1:Ny
    y_i_target = rand(T,1); % training signal
    Wout_i = (X*X' + reg*eye(N)) \ (X*y_i_target);
    Wouts{i} = Wout_i; % cell array collecting each Wout_i for the i^th target training signal
    y_i = Wout_i'*X; % predicted signal (1 by T)
    inner_sum = sum((y_i'-y_i_target).^2) + reg*norm(Wout_i)^2; % squared error summed over time, plus the ridge penalty once
    outer_sum(i) = inner_sum;
end
outer_sum = outer_sum.*(1/Ny);
[minval, minidx] = min(outer_sum);
Wout = cell2mat(Wouts(minidx));
My final answer for Wout is an N-by-1 vector, as it should be, but I am not certain it is correct. I am particularly unsure whether I have implemented the double summation and the arg min with respect to Wout correctly. Is there any way to validate my answer?
Accepted Answer
Garmit Pant on 19 Jun 2024
Hello Jonathan Frutschy,
From what I understand, you are implementing a ridge regression equation to find the optimal output weight matrix of a neural network during training.
I have read the paper you referenced and found some issues in your implementation.
The equation you attached is the objective that ridge regression minimizes to compute the output weights of the network.
The goal is to find the weights ‘W_out’ that minimize the entire sum in that equation.
- The ‘w_i_out’ parameter represents the ith row of the weight matrix ‘W_out’.
- ‘y_i’ denotes the ith output of the model.
- ‘y_i_target’ denotes the corresponding target output.
- Training minimizes this sum, which yields the optimal weights.
- Contrary to the initial assumption, ‘W_out’ is not an N x 1 vector. Instead, it is a matrix of size Ny x (1 + Nu + Nx), where Nu is the number of inputs and Nx is the number of reservoir units (N in your code). The extra 1 + Nu columns multiply the bias and the input, which the guide concatenates with the reservoir state.
- This structure of Wout can be confirmed by referring to Equation 27.4 on page 661 of the referenced paper and its explanation.
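One step worth making explicit: the outer sum does not call for searching over i. Each bracketed term depends only on its own row w_i^out, so the arg min splits into N_y independent ridge problems, and setting each gradient to zero gives a closed form. A short derivation, writing X for the matrix whose columns are the vectors the readout multiplies (the extended states [1; u(n); x(n)], or just the reservoir states if no bias or input is used), x(n) for its n-th column, and y_i^target for the 1-by-T row of targets:

```latex
\begin{aligned}
y_i(n) &= w_i^{out}\,x(n), \qquad n = 1,\dots,T \\
J_i(w_i^{out}) &= \sum_{n=1}^{T}\bigl(y_i(n) - y_i^{target}(n)\bigr)^2 + \beta\,\lVert w_i^{out}\rVert^2
  = \lVert w_i^{out}X - y_i^{target}\rVert^2 + \beta\,\lVert w_i^{out}\rVert^2 \\
\nabla J_i &= 2\,\bigl(w_i^{out}X - y_i^{target}\bigr)X^{\top} + 2\beta\,w_i^{out} = 0
  \;\Longrightarrow\; w_i^{out} = y_i^{target}X^{\top}\bigl(XX^{\top} + \beta I\bigr)^{-1} \\
W^{out} &= \begin{bmatrix} w_1^{out} \\ \vdots \\ w_{N_y}^{out} \end{bmatrix}
  = Y^{target}X^{\top}\bigl(XX^{\top} + \beta I\bigr)^{-1}
\end{aligned}
```

Transposed, the row formula is (XX^T + βI)^{-1} X (y_i^target)^T, which is what the question's line `Wout_i = ((X*X' + reg*eye(N)) \ (X*y_i_target));` computes. Keeping every such row, rather than only the one with the smallest objective, is what yields the arg min.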
You can modify your implementation accordingly so that training computes the full weight matrix that minimizes the above-mentioned equation.
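A minimal MATLAB sketch of such a modification, assuming random stand-in data in place of a real reservoir run; the sizes Nu, Nx, Ny, T, the value of beta, and all variable names are hypothetical. It builds the extended states [1; u(n); x(n)], solves for all Ny rows of W_out at once, and then checks the result in two ways that also address the validation question: the gradient of the objective vanishes at the solution, and small random perturbations of W_out only increase the objective.

```matlab
% Hypothetical sizes, for illustration only
Nu = 3;       % number of inputs
Nx = 100;     % number of reservoir units (N in the question's code)
Ny = 5;       % number of outputs (rows of Wout)
T = 500;      % number of time steps
beta = 1e-4;  % ridge regression coefficient (reg in the question's code)

% Random stand-ins for a real reservoir run
U = rand(Nu, T);          % inputs u(n), one column per time step
Xr = rand(Nx, T);         % reservoir states x(n), one column per time step
Ytarget = rand(Ny, T);    % target outputs, one row per output i

% Extended states [1; u(n); x(n)], so that y(n) = Wout*[1; u(n); x(n)]
X = [ones(1, T); U; Xr];             % (1+Nu+Nx)-by-T
A = X*X' + beta*eye(size(X, 1));     % regularized Gram matrix

% Closed-form ridge solution for all rows at once: Wout = Ytarget*X'/A
Wout = (Ytarget*X') / A;             % Ny-by-(1+Nu+Nx)

% Same result row by row (the structure of the question's loop, keeping every row)
Wrows = zeros(size(Wout));
for i = 1:Ny
    Wrows(i, :) = (A \ (X*Ytarget(i, :)'))';
end

% Objective: mean over outputs of (squared error over time + beta*||w_i||^2)
objective = @(W) mean(sum((W*X - Ytarget).^2, 2) + beta*sum(W.^2, 2));

% Check 1: the gradient (up to a factor 2/Ny) should be numerically zero
G = (Wout*X - Ytarget)*X' + beta*Wout;
relGrad = norm(G, 'fro') / norm(Ytarget*X', 'fro');

% Check 2: small random perturbations should never lower the objective
f0 = objective(Wout);
worse = true;
for k = 1:20
    worse = worse && (objective(Wout + 1e-3*randn(size(Wout))) > f0);
end

fprintf('size(Wout) = %d x %d\n', size(Wout, 1), size(Wout, 2));
fprintf('relative gradient norm = %.2e\n', relGrad);
fprintf('row-by-row matches: %d, perturbations all worse: %d\n', ...
    norm(Wout - Wrows, 'fro') / norm(Wout, 'fro') < 1e-8, worse);
```

The sketch solves with `/` and `\` rather than forming inv(A), which is the usual, more numerically stable choice in MATLAB. With real inputs, states, and targets in place of U, Xr, and Ytarget, the same two checks can be applied to a W_out obtained by any other method.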
MATLAB's Statistics and Machine Learning Toolbox also provides a ‘ridge’ function that computes coefficient estimates for ridge regression models from predictor and response data. Please refer to the following documentation link to understand how to use the ‘ridge’ function:
I hope you find the above explanation and suggestions useful!