Efficient Computation of Matrix Gradient

Hi,
I am trying to compute the gradient of a matrix-valued function. I have computed the element-wise gradient and have verified numerically that it is correct (for my purposes of gradient descent).
My MATLAB implementation of the above gradient is:
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes. I expect to need several hundred iterations for convergence.
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
  2 Comments
Bruno Luong on 9 Apr 2024
Edited: Bruno Luong on 9 Apr 2024
What is a typical size of w (or AXB)?
By the way, the first obvious optimization is to precompute w .* AXB once, outside the loops.
Shreyas Bharadwaj on 9 Apr 2024
Edited: Shreyas Bharadwaj on 9 Apr 2024
Thank you, I will do that. All matrices, including w, are of size N x N, i.e., 750 x 750.


Accepted Answer

Bruno Luong on 9 Apr 2024
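The best method: collapse the double loop into two matrix products (Method 3 below), timed here for N = 200.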
N = 200; % 750
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
t1 = 15.1666
% Method 3
tic
C = w .* AXB;
gradX = A' * C * B';
t2=toc
t2 = 0.0049
err = norm(gradX(:)-gradX_1(:),'inf') / norm(gradX_1(:))
err = 2.4063e-17
fprintf('New code version 3 is %g faster\n', t1/t2)
New code version 3 is 3088.92 faster
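
Why the two matrix products reproduce the double loop can be checked element by element. The notation here is an assumption for illustration, not taken from the thread: G stands for gradX, C for w .* AXB, an overbar for complex conjugation, and a superscript H for the conjugate transpose (the MATLAB ' operator).

$$
G_{pq}
= \sum_{i=1}^{N}\sum_{j=1}^{N} \overline{A_{ip}}\; C_{ij}\; \overline{B_{qj}}
= \sum_{i=1}^{N}\sum_{j=1}^{N} (A^{H})_{pi}\; C_{ij}\; (B^{H})_{jq}
= \left(A^{H} C\, B^{H}\right)_{pq},
\qquad C_{ij} = w_{ij}\,(AXB)_{ij}.
$$

The first equality is the loop body written out, since the outer product conj(A(:,p)) * conj(B(q,:)) has entries conj(A(i,p)) * conj(B(q,j)). The loop performs N^2 iterations of O(N^2) work each, so O(N^4) in total, whereas the two matrix products cost O(N^3), which is consistent with the large speedup measured above.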

More Answers (1)

Bruno Luong on 9 Apr 2024
I propose the following, with timings for N = 200:
N = 200; % 750
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
t1 = 6.6905
gradX = zeros(N,N);
tic
C = w .* AXB;
C = reshape(C,1,[]);
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:);
        AB = reshape(AB,1,[]);
        gradX(p,q) = C * AB';
    end
end
t2=toc
t2 = 1.0750
fprintf('New code version 1 is %g faster\n', t1/t2)
New code version 1 is 6.22383 faster
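
A possible intermediate step between this version and the `A' * C * B'` form in the accepted answer is to remove only the inner loop over q, computing one row of the gradient per iteration. The sketch below is illustrative rather than part of the original answers: the small size N = 50, the fixed seed, and the names gradLoop, gradRow, gradMat, relErrRow and relErrMat are assumptions chosen so the check runs quickly.

```matlab
% Illustrative sizes and seed (assumptions); the thread uses N = 200 and N = 750.
N = 50;
rng(0);
w   = rand(N);
AXB = rand(N) + 1i*rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);

C = w .* AXB;                      % hoisted out of every loop

% Reference: the original double loop, with C precomputed.
gradLoop = zeros(N);
for p = 1:N
    for q = 1:N
        gradLoop(p,q) = sum(C .* (conj(A(:,p)) * conj(B(q,:))), 'all');
    end
end

% Row-at-a-time: A(:,p)' is the conjugated column as a 1-by-N row,
% so (A(:,p)' * C) * B' yields gradX(p,:) for all q at once.
gradRow = zeros(N);
for p = 1:N
    gradRow(p,:) = (A(:,p)' * C) * B';
end

% Fully vectorized form from the accepted answer.
gradMat = A' * C * B';

% Relative errors against the loop reference (Frobenius norm).
relErrRow = norm(gradRow - gradLoop, 'fro') / norm(gradLoop, 'fro');
relErrMat = norm(gradMat - gradLoop, 'fro') / norm(gradLoop, 'fro');
fprintf('row-wise rel. error: %g, matrix-form rel. error: %g\n', relErrRow, relErrMat);
```

Both relative errors should be at the level of floating-point round-off. The row-wise loop is mainly useful for seeing how the matrix form arises; for actual gradient descent the single expression `A' * C * B'` is the one to use.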

Release

R2023b