Optimising Nearest Neighbor Program
Hi Guys,
I am trying to optimise this code so that it runs in under 10 seconds for N=20k.
It currently takes around 40 seconds to run.
I think that I need to vectorise some or all of the loops so that the calculations are done at the same time, but I cannot figure out how to do it.
Any help would be much appreciated, as we are learning from home with next to no support from the University.
Here is the code.
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b=1:N
            off=0;
            if (pos(n,a)<=0.25 && pos(n,b) >= 0.75)
                off=1;
            elseif pos(n,a)>=0.75 && pos(n,b)<=0.25
                off=-1;
            end
            s(a,b)=s(a,b)+(pos(n,a)-pos(n,b)+off)^2.0;
            if n == 3
                s(a,b)=sqrt(s(a,b));
            end
        end
    end
end
match = zeros(1,N);
for a=1:N
    mindist=1e10;
    for c=1:N
        if (a~=c)
            mindist=min(s(a,c),mindist);
            if (mindist==s(a,c))
                match(a)=c;
            end
        end
    end
end
4 Comments
KSSV
on 30 Oct 2020
Refer to this; it might help you: https://in.mathworks.com/matlabcentral/fileexchange/44334-nearest-neighboring-particle-search-using-all-particles-search
Accepted Answer
Bruno Luong
on 31 Oct 2020
I am posting my comments as an answer here, so that you can accept it if it helps.
N=20000;
nd=3;
pos=rand(nd,N);
pos_r = reshape(pos,[nd 1 N]);
s = zeros(N);
for b=1:N
    posb = pos_r(:,:,b);
    off = (pos<=0.25 & posb>=0.75) - ...
        (pos>=0.75 & posb<=0.25);
    sb = sum((pos-posb+off).^2,1);
    s(:,b) = sb(:);
end
s = sqrt(s);
s(1:N+1:end) = Inf;
[~,match] = min(s,[],2);
match = match.'; % row vector
Still, if one doesn't have to deal with the odd offset, the Delaunay approach is much faster.
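One way to gain confidence in the column-by-column loop is to compare it against the original triple loop on a small problem. The following is only a sketch: the seed and N = 200 are assumptions chosen so that the slow reference finishes quickly, and the implicit expansion used here requires R2016b or later.

```matlab
% Sanity check (a sketch, not part of the original answer): compare the
% column-by-column computation with the triple loop from the question.
rng(0);                 % assumed seed, only for reproducibility
N  = 200;               % small N so the triple loop finishes quickly
nd = 3;
pos = rand(nd,N);

% Reference: triple loop from the question (sqrt moved after the loops)
sRef = zeros(N,N);
for n = 1:nd
    for a = 1:N
        for b = 1:N
            off = 0;
            if pos(n,a) <= 0.25 && pos(n,b) >= 0.75
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            sRef(a,b) = sRef(a,b) + (pos(n,a) - pos(n,b) + off)^2;
        end
    end
end
sRef = sqrt(sRef);

% Vectorised: one column of s per iteration
s = zeros(N);
for b = 1:N
    posb = pos(:,b);    % nd-by-1, expands against the nd-by-N array pos
    off  = (pos <= 0.25 & posb >= 0.75) - (pos >= 0.75 & posb <= 0.25);
    s(:,b) = sum((pos - posb + off).^2, 1).';
end
s = sqrt(s);

% Differences, if any, should be at the level of floating-point rounding
maxAbsDiff = max(abs(s(:) - sRef(:)));
fprintf('Largest difference: %g\n', maxAbsDiff);
```

If the two versions agree for several small N and seeds, the same loop can then be run at N = 20000, keeping in mind that a full 20000-by-20000 double matrix needs about 3.2 GB.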
0 Comments
More Answers (3)
KSSV
on 30 Oct 2020
Edited: KSSV
on 30 Oct 2020
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;   % N-by-N index matrices covering all point pairs
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is needed here: && only accepts scalar operands
    off(P(a)<=0.25 & P(b)>=0.75) = +1 ;
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
9 Comments
KSSV
on 30 Oct 2020
How did you measure the time? Check the code below:
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
t1 = tic ;
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b=1:N
            off=0;
            if (pos(n,a)<=0.25 && pos(n,b) >= 0.75)
                off=1;
            elseif pos(n,a)>=0.75 && pos(n,b)<=0.25
                off=-1;
            end
            s(a,b)=s(a,b)+(pos(n,a)-pos(n,b)+off)^2.0;
            if n == 3
                s(a,b)=sqrt(s(a,b));
            end
        end
    end
end
t1 = toc(t1) ;
t2 = tic ;
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is needed here: && only accepts scalar operands
    off(P(a)<=0.25 & P(b)>=0.75) = +1 ;
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
t2 = toc(t2) ;
Now check t1 and t2 for inputs of different sizes. My bet is that the second code will always be faster. I am comparing only the first part of the code.
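The index matrices returned by meshgrid each occupy N-by-N storage. As a possible variation, and only a sketch assuming R2016b or later for implicit expansion, the same pairwise terms can be formed directly from a column vector and a row vector. The names Pa and Pb are introduced here for illustration.

```matlab
% Sketch: pairwise offsets and distances without meshgrid index matrices.
N = 1000;
rng(1234)
pos = rand(3,N);

S = zeros(N,N);
for n = 1:3
    Pa = pos(n,:).';    % N-by-1 column: index a runs down the rows
    Pb = pos(n,:);      % 1-by-N row: index b runs across the columns
    % Logical arrays subtract to +1, 0 or -1, matching the if/elseif rule
    off = (Pa <= 0.25 & Pb >= 0.75) - (Pa >= 0.75 & Pb <= 0.25);
    S = S + (Pa - Pb + off).^2;   % expands to N-by-N
end
S = sqrt(S);
```

Here S(a,b) holds the distance between points a and b. Because the offset changes sign when a and b are swapped, the matrix should come out symmetric with a zero diagonal, which is a quick check to run.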
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I removed your handling of the offset (I am not sure what its purpose is); here is a much faster method using Delaunay triangulation:
clear
N=20000;
nd=3;
pos=rand(nd,N);
tic
s = zeros(N,N);
for n = 1:3
    for a = 1:N
        count = 0;
        for b=1:N
            s(a,b)=s(a,b)+(pos(n,a)-pos(n,b))^2.0;
            if n == 3
                s(a,b)=sqrt(s(a,b));
            end
        end
    end
end
match = zeros(1,N);
for a=1:N
    mindist=1e10;
    for c=1:N
        if (a~=c)
            mindist=min(s(a,c),mindist);
            if (mindist==s(a,c))
                match(a)=c;
            end
        end
    end
end
toc % Elapsed time is 56.414344 seconds.
% Find the nearest neighbour of each point within the same set of points, in 2D or 3D
% INPUT: pos is array of size (nd x N), coordinates of N points in R^nd
tic
T = delaunay(pos.');
p = nchoosek(1:size(T,2),2);
P = T(:,p);
P = reshape(P,[],2);
P = unique(sort(P,2),'rows');
P1 = P(:,1);
P2 = P(:,2);
d2 = sum((pos(:,P2)-pos(:,P1)).^2,1);
A = [P1(:), d2(:), P2(:);
P2(:), d2(:), P1(:)];
A = sortrows(A,[1 2]);
b = [true; diff(A(:,1),1)>0];
A = A(b,:);
nn = A(:,3).'; % index of nearest neighbour
d = sqrt(A(:,2)).'; % corresponding distance
toc % Elapsed time is 0.349317 seconds.
isequal(match,nn) % 1
This may not help with your exercise, but I am still posting it here for future readers who are looking for a fast method.
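The edge-extraction step in the Delaunay code (nchoosek, column indexing, reshape, unique) is compact, so a toy example may make it easier to follow. The two tetrahedra below are written by hand for illustration and are an assumption, not output from delaunay; E is a name introduced here.

```matlab
% Sketch: how the list of unique edges is built from a tetrahedron list.
T = [1 2 3 4;          % tetrahedron 1 (vertex indices)
     2 3 4 5];         % tetrahedron 2, sharing the face 2-3-4

p = nchoosek(1:size(T,2),2);   % the 6 vertex pairs within one tetrahedron
P = T(:,p);                    % columns of T picked in the order p(:)
P = reshape(P,[],2);           % one row per edge: [start, end]
E = unique(sort(P,2),'rows');  % sort each pair, then drop duplicates

disp(E)
```

Each tetrahedron contributes 6 edges, and the 3 edges of the shared face appear twice, so 9 unique edges remain. Only these candidate pairs need a distance computation, which is why the approach scales so much better than the full N-by-N matrix.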
4 Comments
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I ran it 10 times and isequal(match,nn) returned TRUE each time, so the answers matched in all 10 runs with random points.
> for k=1:10; benchnntest; end
Elapsed time is 49.127755 seconds.
Elapsed time is 0.516644 seconds.
ans =
logical
1
Elapsed time is 51.189459 seconds.
Elapsed time is 0.459055 seconds.
ans =
logical
1
Elapsed time is 50.431960 seconds.
Elapsed time is 0.465887 seconds.
ans =
logical
1
Elapsed time is 50.426246 seconds.
Elapsed time is 0.454885 seconds.
ans =
logical
1
Elapsed time is 50.651649 seconds.
Elapsed time is 0.567084 seconds.
ans =
logical
1
Elapsed time is 50.889422 seconds.
Elapsed time is 0.461514 seconds.
ans =
logical
1
Elapsed time is 50.678441 seconds.
Elapsed time is 0.491820 seconds.
ans =
logical
1
Elapsed time is 50.476219 seconds.
Elapsed time is 0.451430 seconds.
ans =
logical
1
Elapsed time is 52.659327 seconds.
Elapsed time is 0.443114 seconds.
ans =
logical
1
Elapsed time is 52.004992 seconds.
Elapsed time is 0.459873 seconds.
ans =
logical
1
>>
Image Analyst
on 30 Oct 2020
MATLAB stores arrays in column-major order, which means that elements differing only in the leftmost index are adjacent in memory, so iterating over the leftmost index fastest is generally quicker. MATLAB goes down the rows of one column first, then moves over to the next column and goes down its rows. So this slow code
for row = 1 : rows
    for col = 1 : columns
        s(row, col) = whatever; % col iterates fastest
    end
end
will be (or may be) slower than this fast code
for col = 1 : columns
    for row = 1 : rows
        s(row, col) = whatever; % row iterates fastest
    end
end
Note that, in your code, n is the leftmost index of your arrays, yet you have the n loop as the outermost loop, which is the slowest possible way to do it. If possible, see if you can move n to an inner loop. I've had luck in the past getting nested loops to speed up by doing that.
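The effect of loop order can be measured directly. The snippet below is a rough sketch with an assumed array size of 1500-by-1500 and a placeholder assignment; the size of the difference depends on the MATLAB release and the machine, and may be small.

```matlab
% Sketch: time the two loop orders on the same array.
rows = 1500;
columns = 1500;

s1 = zeros(rows, columns);
tic
for row = 1 : rows
    for col = 1 : columns
        s1(row, col) = row + col;   % col iterates fastest (strided access)
    end
end
tRowOuter = toc;

s2 = zeros(rows, columns);
tic
for col = 1 : columns
    for row = 1 : rows
        s2(row, col) = row + col;   % row iterates fastest (contiguous access)
    end
end
tColOuter = toc;

fprintf('Row loop outside: %.4f s, column loop outside: %.4f s\n', ...
    tRowOuter, tColOuter);
```

Both loops fill the array with the same values, so any timing difference comes from the access pattern alone.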
8 Comments
Image Analyst
on 30 Oct 2020
Samuel, going by your description, I'd try something like this to find the closest point.
N = 1000;
xyz = rand(N, 3); % Get N randomly located points in 3-D.
for k = 1 : N
    % Get the squared distance of point k to every other point in the array.
    distancesSquared = ((xyz(k, 1) - xyz(:, 1)) .^2 + ...
        (xyz(k, 2) - xyz(:, 2)) .^2 + ...
        (xyz(k, 3) - xyz(:, 3)) .^2);
    % We don't want to consider the distance of the point to itself, so set any zeros to infinity.
    distancesSquared(distancesSquared==0) = inf;
    % Find the min value and the index of that min for the other points.
    [minDist2, index] = min(distancesSquared);
    % Print it out.
    fprintf('Point %d at (%.2f, %.2f, %.2f) is closest to point %d at (%.2f, %.2f, %.2f).\n',...
        k, xyz(k, 1), xyz(k, 2), xyz(k, 3), index, xyz(index, 1), xyz(index, 2), xyz(index, 3));
end
It prints out stuff like:
Point 1 at (0.70, 0.37, 0.03) is closest to point 312 at (0.70, 0.37, 0.04).
Point 2 at (0.09, 0.71, 0.91) is closest to point 918 at (0.10, 0.72, 0.94).
Point 3 at (0.53, 0.95, 0.47) is closest to point 50 at (0.54, 0.93, 0.43).
etc.
Point 998 at (0.03, 0.67, 0.34) is closest to point 972 at (0.05, 0.59, 0.39).
Point 999 at (0.99, 0.87, 0.16) is closest to point 592 at (0.94, 0.93, 0.16).
Point 1000 at (0.54, 0.34, 0.42) is closest to point 540 at (0.55, 0.33, 0.42).
It doesn't do the special handling near the edges of the square, though.
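The per-point loop above could be extended with the offset rule from the question. The following is a sketch under assumptions: it uses the same 0.25/0.75 thresholds as the question, needs R2016b or later for implicit expansion, and excludes the point itself by index rather than by testing for a zero distance. The names match and minDist are chosen here for illustration.

```matlab
% Sketch: nearest neighbour per point, with the question's offset rule,
% storing only one column of distances at a time (no N-by-N matrix).
rng(0);                 % assumed seed, only for reproducibility
N = 500;
xyz = rand(N, 3);       % N points, one per row

match   = zeros(1, N);  % index of the nearest neighbour of each point
minDist = zeros(1, N);  % corresponding distance
for k = 1 : N
    pk = xyz(k, :);     % 1-by-3, expands against the N-by-3 array xyz
    % +1 where point k is near the low edge and the other point near the
    % high edge, -1 in the opposite case, 0 otherwise
    off = (pk <= 0.25 & xyz >= 0.75) - (pk >= 0.75 & xyz <= 0.25);
    d2 = sum((pk - xyz + off).^2, 2);   % squared distances, N-by-1
    d2(k) = Inf;                        % exclude the point itself
    [m, match(k)] = min(d2);
    minDist(k) = sqrt(m);
end
```

Excluding the point by index avoids discarding a genuine neighbour that happens to coincide with point k. Memory use stays proportional to N, at the cost of a loop over all points.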