need to vectorize efficiently calculating only certain values in the matrix multiplication A * B, using a logical array L the size of A * B
I have matrices A (m by v) and B (v by n). I also have a logical matrix L (m by n).
I am interested in, as efficiently as possible, calculating only the values in A * B that correspond to logical values in L (values of 1s). Essentially I am interested in the quantity ( A * B ) .* L .
For my problem, a typical L matrix has fewer than 0.1% of its values as 1s; the vast majority of the values are 0s. Thus, it makes no sense for me to compute ( A * B ) .* L directly. It would actually be faster to loop over each element of A * B that I want to compute, but even that is inefficient.
-------------------------------------------------------------------------------------------------------------------------------------------------------
Possible solution (need help vectorizing this code if possible)
My particular problem may have a nice solution given that the logical matrix L has a nice structure.
This L matrix is nice in that it can be represented as something like a permuted block matrix. This example L is composed of 9 "blocks" of 1s, where each block of 1s has its own set of row and column indices. For instance, the 1s in the highlighted area here form one particular submatrix of L.
My solution was to do this. I can get the row indices and column indices for each block's submatrix in L, organized in two cell lists "rowidxs_list" and "colidxs_list", each with as many cells as there are blocks. For instance, for subblock 1 in the block example I gave, I could calculate those particular values in A * B by simply doing A( rowidxs_list{1} , : ) * B( : , colidxs_list{1} ) .
That means that if I precomputed rowidxs_list and colidxs_list (ignore the cost of calculating these lists; it is negligible for my application), then my problem of calculating C = ( A * B ) .* L could effectively be done by:
C = sparse( m , n );
for i = 1:length( rowidxs_list )
    C( rowidxs_list{i} , colidxs_list{i} ) = A( rowidxs_list{i} , : ) * B( : , colidxs_list{i} );
end
This seems like it would be the most efficient way to solve this problem if I knew how to vectorize this for loop. Does anyone see a way to vectorize this?
There may be ways to vectorize if certain things hold, e.g. only if rowidxs_list and colidxs_list are matrices instead of cell lists of lists (where each row of a matrix is an index list, thus replacing use of rowidxs_list{i} with rowidxs_list(i,:) ). I'd prefer to use cell lists here if possible since different lists can have different numbers of elements.
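For comparison with the C-MEX discussion further down, here is a minimal sketch in plain C of what one iteration of the loop above computes, namely A( rows , : ) * B( : , cols ) for a single block. It assumes column-major storage (as MATLAB uses) and zero-based index lists. The name block_product and its argument list are hypothetical and not part of any library.

```c
#include <stddef.h>

/* Hypothetical helper: computes the block A(rows,:) * B(:,cols) for one
 * block of L. A is m-by-v and B is v-by-n, both stored column-major.
 * rows and cols hold zero-based indices. out is nr-by-nc, column-major. */
void block_product(const double *A, size_t m, const double *B, size_t v,
                   const size_t *rows, size_t nr,
                   const size_t *cols, size_t nc, double *out)
{
    for (size_t c = 0; c < nc; c++) {
        const double *bcol = B + v * cols[c];  /* start of column cols[c] of B */
        for (size_t r = 0; r < nr; r++) {
            const double *arow = A + rows[r];  /* first element of row rows[r] of A */
            double d = 0.0;
            for (size_t k = 0; k < v; k++) {
                d += arow[k * m] * bcol[k];    /* stride m walks along the row of A */
            }
            out[r + nr * c] = d;
        }
    }
}
```

Because the index lists are plain arrays with explicit lengths, blocks of different sizes need no special handling, which matches the preference for cell lists stated above.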
-------------------------------------------------------------------------------------------------------------------------------------------------------
Other suggested solution (creating a mex file?)
I first posted this question on the /r/matlab subreddit, see here for the reddit thread. The user "qtac" recommended writing a C-MEX function, i.e. a routine written in C that MATLAB can call:
My gut feeling is the only way to really optimize this is with a C-MEX solution; otherwise, you are going to get obliterated by overhead from subsref in these loops. With C you could loop over L until you find a nonzero element, and then do only the row-column dot product needed to populate that specific element. You will miss out on a lot of the BLAS optimizations but the computational savings may make up for it.
Honestly I bet an LLM could write 90%+ of that MEX function for you; it's a well-formulated problem.
I think this could be a good solution to pursue, but I'd like other opinions as well.
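To make the quoted suggestion concrete, here is a minimal sketch in plain C (no MEX interface) of the loop it describes: scan L, and wherever an element is nonzero do only the row-column dot product for that element. It assumes column-major storage and a dense mask L stored as bytes. The function name masked_product_dense and its signature are hypothetical.

```c
#include <stddef.h>

/* Hypothetical sketch: C = (A*B).*L where L is a dense m-by-n mask
 * (nonzero means "compute this entry"). A is m-by-v, B is v-by-n.
 * All arrays are column-major; C must have room for m*n doubles. */
void masked_product_dense(const double *A, const double *B,
                          const unsigned char *L, double *C,
                          size_t m, size_t v, size_t n)
{
    for (size_t j = 0; j < n; j++) {
        for (size_t i = 0; i < m; i++) {
            double d = 0.0;
            if (L[i + m * j]) {
                /* dot product of row i of A with column j of B */
                for (size_t k = 0; k < v; k++) {
                    d += A[i + m * k] * B[k + v * j];
                }
            }
            C[i + m * j] = d;  /* entries masked out by L stay 0 */
        }
    }
}
```

Note that this version still visits all m*n entries of L and writes a dense result. With fewer than 0.1% of the entries set, storing L (and the result) in sparse form avoids both costs.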
Accepted Answer
James Tursa
on 19 Feb 2025
Edited: James Tursa
on 22 Feb 2025
You can try this naive mex routine. Not optimized at all for cache hits or patterns in L. Might be faster to transpose A first, but I haven't looked into that yet. Also doesn't contain any code to clean 0's from result.
/* File: ABL.c
*
* C = ABL(A,B,L)
*
* Performs the function
*
* C = (A*B).*L
*
* Where
*
* A = Full M x P real double matrix
* B = Full P x N real double matrix
* L = Sparse M x N logical matrix
* C = Sparse M x N real double matrix
*
* Brute force algorithm, no attempt at multi-threading, etc.
*
* Building: mex ABL.c -R2018a
*
* Programmer: James Tursa
*/
#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize Am, An, Bm, Bn, Lm, Ln, i, j, k;
    mwIndex nrow;
    mwIndex *Ir, *Jc, *Cir, *Cjc;
    double *A, *B, *C, *a, *b;
    /* Check inputs */
    if( nrhs != 3 ||
        !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) || mxGetNumberOfDimensions(prhs[0]) != 2 ||
        !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) || mxGetNumberOfDimensions(prhs[1]) != 2 ||
        !mxIsLogical(prhs[2]) || !mxIsSparse(prhs[2]) ) {
        mexErrMsgTxt("Invalid Inputs");
    }
    Am = mxGetM(prhs[0]);
    An = mxGetN(prhs[0]);
    Bm = mxGetM(prhs[1]);
    Bn = mxGetN(prhs[1]);
    Lm = mxGetM(prhs[2]);
    Ln = mxGetN(prhs[2]);
    if( Am != Lm || Bn != Ln || An != Bm ) {
        mexErrMsgTxt("Matrix sizes incompatible");
    }
    /* Get sparse indexing pointers */
    Ir = mxGetIr(prhs[2]);
    Jc = mxGetJc(prhs[2]);
    /* Create output array same size and sparsity as L */
    // mexCallMATLAB(plhs,1,(mxArray **)(prhs+2),1,"double"); /* slower? will need to manually 0 the data spots */
    plhs[0] = mxCreateSparse(Lm, Ln, Jc[Ln], mxREAL); /* will need to manually fill in Ir and Jc */
    /* Get data pointers */
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(plhs[0]);
    Cir = mxGetIr(plhs[0]);
    Cjc = mxGetJc(plhs[0]);
    /* Loop through the logical L matrix indexing, C has exact same structure */
    for( j=0; j<Ln; j++) {
        *Cjc++ = Jc[j];
        nrow = Jc[j+1] - Jc[j]; /* number of elements in this column */
        for( i=0; i<nrow; i++ ) {
            *Cir++ = *Ir;
            a = A + *Ir++;
            b = B + Bm * j;
            for( k=0; k<An; k++ ) {
                *C += *a * *b++;
                a += Am;
            }
            C++;
        }
    }
    *Cjc = Jc[Ln];
}
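For readers who want to experiment with the core loop outside MATLAB, here is a minimal standalone sketch of the same idea using the compressed-column layout that the Ir and Jc arrays describe. It is not the MEX routine itself: the function name masked_product_csc and its signature are hypothetical, and plain size_t arrays stand in for the mwIndex arrays of a MATLAB sparse matrix.

```c
#include <stddef.h>

/* Hypothetical standalone kernel. L is m-by-n in compressed-column form:
 * column j owns positions jc[j] .. jc[j+1]-1, and ir[p] is the zero-based
 * row index of stored entry p. A is m-by-v and B is v-by-n, column-major.
 * vals[p] receives the dot product of row ir[p] of A with column j of B. */
void masked_product_csc(const double *A, size_t m, size_t v,
                        const double *B,
                        const size_t *jc, const size_t *ir, size_t n,
                        double *vals)
{
    for (size_t j = 0; j < n; j++) {
        const double *bcol = B + v * j;       /* column j of B is contiguous */
        for (size_t p = jc[j]; p < jc[j + 1]; p++) {
            const double *a = A + ir[p];      /* first element of row ir[p] of A */
            double d = 0.0;
            for (size_t k = 0; k < v; k++) {
                d += a[k * m] * bcol[k];      /* stride m along the row of A */
            }
            vals[p] = d;
        }
    }
}
```

The result shares ir and jc with L, so only the value array needs to be filled. As in ABL.c, any dot product that happens to be exactly 0 is still stored.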
James Tursa
on 22 Feb 2025
Edited: James Tursa
on 22 Feb 2025
*** UPDATE ***
Here is a more CPU cache friendly version of the code. By transposing A first, the A elements used for the dot products are contiguous in memory (the B elements already are) and this makes much better use of the CPU cache for the dot product operations and gives better performance. This version would also be more friendly to multi-threading (although I don't think I will bother attempting that at the moment ... this can get tricky to figure out the best way to split things up depending on matrix sizes). Keep in mind that you have to pass in the transpose of A to this mex routine. I.e., the following are all mathematically equivalent (but not numerically because of different order of operations):
(A*B).*L
ABL(A,B,L)
ATBL(A.',B,L) % <-- Pass in the transpose of A!
That is, ABL(A,B,L) should give the exact same result as ATBL(A.',B,L) because the order of operations is exactly the same, but neither of these would be expected to give the exact same result as (A*B).*L because of the different order of operations involved, even though they are all mathematically equivalent.
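As a small illustration of why the order of operations matters numerically, here is a sketch with two dot product routines that differ only in the direction of accumulation. The function names are hypothetical. With x = [1, 1e16, -1e16] and y = [1, 1, 1], strict IEEE 754 double arithmetic gives 0 for the forward sum (the 1 is lost when added to 1e16) and 1 for the reverse sum (the large terms cancel first).

```c
#include <stddef.h>

/* Accumulates x[0]*y[0] first, x[n-1]*y[n-1] last. */
double dot_forward(const double *x, const double *y, size_t n)
{
    double d = 0.0;
    for (size_t k = 0; k < n; k++) {
        d += x[k] * y[k];
    }
    return d;
}

/* Same products, accumulated from the last element to the first. */
double dot_reverse(const double *x, const double *y, size_t n)
{
    double d = 0.0;
    for (size_t k = n; k-- > 0; ) {
        d += x[k] * y[k];
    }
    return d;
}
```

An optimized matrix multiply is free to block and reorder its sums, so small differences of this kind between (A*B).*L and the mex routines are expected and do not indicate a bug.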
The mex code:
/* File: ATBL.c
*
* C = ATBL(A,B,L)
*
* Performs the function
*
* C = ((A.')*B).*L
*
* Where
*
* A = Full P x M real double matrix
* B = Full P x N real double matrix
* L = Sparse M x N logical matrix
* C = Sparse M x N real double matrix
*
* Brute force algorithm, no attempt at multi-threading, etc.
*
* Building: mex ATBL.c -R2018a
*
* Programmer: James Tursa
*/
#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize Am, An, Bm, Bn, Lm, Ln, i, j, k;
    mwIndex nrow;
    mwIndex *Lir, *Ljc, *Cir, *Cjc;
    double *A, *B, *C, *a, *b, d;
    /* Check inputs */
    if( nrhs != 3 ||
        !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) || mxGetNumberOfDimensions(prhs[0]) != 2 ||
        !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) || mxGetNumberOfDimensions(prhs[1]) != 2 ||
        !mxIsLogical(prhs[2]) || !mxIsSparse(prhs[2]) ) {
        mexErrMsgTxt("Invalid Inputs");
    }
    Am = mxGetM(prhs[0]);
    An = mxGetN(prhs[0]);
    Bm = mxGetM(prhs[1]);
    Bn = mxGetN(prhs[1]);
    Lm = mxGetM(prhs[2]);
    Ln = mxGetN(prhs[2]);
    if( An != Lm || Bn != Ln || Am != Bm ) {
        mexErrMsgTxt("Matrix sizes incompatible");
    }
    /* Get sparse indexing pointers */
    Lir = mxGetIr(prhs[2]);
    Ljc = mxGetJc(prhs[2]);
    /* Create output array same size and sparsity as L */
    plhs[0] = mxCreateSparse(Lm, Ln, Ljc[Ln], mxREAL); /* will need to manually fill in Ir and Jc */
    /* Get data pointers */
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(plhs[0]);
    Cir = mxGetIr(plhs[0]);
    Cjc = mxGetJc(plhs[0]) + 1; /* 2nd value is for 1st column */
    /* Loop through the logical L matrix indexing */
    /* C has same basic structure, but we need to clean the 0's as we go */
    for( j=0; j<Ln; j++) { /* for each column of L */
        *Cjc = *(Cjc-1); /* Copy running total from last column */
        nrow = Ljc[j+1] - Ljc[j]; /* number of elements in this column of L */
        for( i=0; i<nrow; i++ ) { /* for each index in this column of L */
            a = A + Am * *Lir; /* point to correct column of A */
            b = B; /* point to this column of B */
            d = 0.0; /* initialize dot product value */
            for( k=0; k<Am; k++ ) { /* dot product */
                d += *a++ * *b++;
            }
            if( d ) { /* only include in sparse result if non-zero */
                *Cir++ = *Lir; /* copy row index and increment pointer */
                *Cjc += 1; /* bump up running total */
                *C++ = d; /* store the value and increment pointer */
            }
            Lir++; /* point to next row value */
        }
        B += Bm; /* point to next column of B */
        Cjc++; /* point to next value of non-zero running total */
    }
}
Also note that this version of the code is production quality, meaning that it checks for 0's along the way and only stores explicit non-zeros in the sparse result. The ABL.c code posted above does not do that. E.g., a run where there are forced 0 results:
>> A = reshape(1:12,3,4); A = [A,fliplr(A)]
A =
1 4 7 10 10 7 4 1
2 5 8 11 11 8 5 2
3 6 9 12 12 9 6 3
>> B = [ones(4,4);-ones(4,4)]
B =
1 1 1 1
1 1 1 1
1 1 1 1
1 1 1 1
-1 -1 -1 -1
-1 -1 -1 -1
-1 -1 -1 -1
-1 -1 -1 -1
>> L = sparse(rand(3,4)<.5)
L =
3×4 sparse logical array
(1,1) 1
(1,2) 1
(2,2) 1
(3,2) 1
(1,3) 1
(2,3) 1
(3,3) 1
(2,4) 1
>> (A*B).*L
ans =
All zero sparse: 3×4
>> sarek(ans)
SAREK -- Sparse Analyzer Real Et Komplex , by James Tursa
Compiled in version R2020a (with -R2018a option)
Running in version R2020a
Matrix is double ...
Matrix is real ...
M = 3 >= 0 OK ...
N = 4 >= 0 OK ...
Nzmax = 1 >= 1 OK ...
Jc[0] == 0 OK ...
Jc[N] = 0 <= Nzmax = 1 OK ...
Jc array OK ...
Ir array OK ...
All stored elements nonzero OK ...
All sparse integrity checks OK
ans =
0
>>
>> ABL(A,B,L)
ans =
(1,1) 0
(1,2) 0
(2,2) 0
(3,2) 0
(1,3) 0
(2,3) 0
(3,3) 0
(2,4) 0
>> sarek(ans)
SAREK -- Sparse Analyzer Real Et Komplex , by James Tursa
Compiled in version R2020a (with -R2018a option)
Running in version R2020a
Matrix is double ...
Matrix is real ...
M = 3 >= 0 OK ...
N = 4 >= 0 OK ...
Nzmax = 8 >= 1 OK ...
Jc[0] == 0 OK ...
Jc[N] = 8 <= Nzmax = 8 OK ...
Jc array OK ...
Ir array OK ...
ERROR: There are 8 explicit 0's in matrix
TO FIX: B = 1*A;
There were ERRORS found in matrix!
ans =
1
>>
>> ATBL(A.',B,L)
ans =
All zero sparse: 3×4
>> sarek(ans)
SAREK -- Sparse Analyzer Real Et Komplex , by James Tursa
Compiled in version R2020a (with -R2018a option)
Running in version R2020a
Matrix is double ...
Matrix is real ...
M = 3 >= 0 OK ...
N = 4 >= 0 OK ...
Nzmax = 8 >= 1 OK ...
Jc[0] == 0 OK ...
Jc[N] = 0 <= Nzmax = 8 OK ...
Jc array OK ...
Ir array OK ...
All stored elements nonzero OK ...
All sparse integrity checks OK
ans =
0
You can see that the ABL mex routine stores explicit 0's in the sparse result, whereas the ATBL mex routine checks for this and does not store them. The sarek function is a mex routine that checks sparse matrices for integrity. It can be found here:
As a side issue, this 0 check makes it trickier to multi-thread the code. You can't have each thread work completely independently on part of the result since the resulting memory locations of one part depend on the number of actual non-zeros of the other part(s). What you could do is multi-thread where each thread assumes no 0's in previous threads (that way each thread knows where to put its results in memory), then in a second pass clean the 0's if there were any detected. That would not be a bad strategy, particularly if you don't expect to have 0's in the resulting sparse spots (e.g., if you were typically working with arbitrary non-integer real values). You would only have to pass through the memory twice to fix things up if you detected a 0 in the first pass.
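The second pass described above could look something like the following. This is a minimal single-threaded sketch in plain C, assuming the first pass has already filled the value array using the sparsity pattern of L. The function name csc_drop_zeros is hypothetical, and size_t arrays stand in for the mwIndex arrays of a MATLAB sparse matrix. The threaded first pass itself is not shown.

```c
#include <stddef.h>

/* Hypothetical cleanup pass: removes explicit zeros from a compressed-column
 * matrix in place. vals and ir hold the stored values and zero-based row
 * indices, jc has n+1 entries. Returns the new number of stored entries. */
size_t csc_drop_zeros(double *vals, size_t *ir, size_t *jc, size_t n)
{
    size_t w = 0;            /* next write position */
    size_t start = jc[0];    /* old start of the current column */
    for (size_t j = 0; j < n; j++) {
        size_t end = jc[j + 1];  /* old end, read before jc is overwritten */
        jc[j] = w;               /* new start of column j */
        for (size_t p = start; p < end; p++) {
            if (vals[p] != 0.0) {    /* keep only nonzero entries */
                vals[w] = vals[p];
                ir[w] = ir[p];
                w++;
            }
        }
        start = end;
    }
    jc[n] = w;
    return w;
}
```

Since the write position never passes the read position, the compaction can be done in place, and if the first pass detected no zeros this pass can be skipped entirely.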