MATLAB Answers

How to turn the data in mini batch into a deep learning array

1 view (last 30 days)
Maik Noungoua
Maik Noungoua on 8 Aug 2020
Answered: Divya Gaddipati on 11 Aug 2020
I am using a deep learning network to predict the distanes of the codewords of a for wireless communications systems. I am trying to update the parameters of my network with the adam algorithm.
I need help of how to turn the data inside the mini batch into data for the dl array.
Thank you in advance
clc
clear all
S=100E3;
SNR=10;
s_1=-1i; % these are used as the labels
s_2=-1;
s_3=1i;
s_4=1;
m=round(rand(1,2*S));
A=reshape(m,2,S); % the 1x2S row vector is transformed into a 2xS matrix A
x = zeros(1,6);
s = 0;
s = s+(sum(abs(A-[0;0]*ones(1,S)))<0.1)*s_1;
s = s+(sum(abs(A-[0;1]*ones(1,S)))<0.1)*s_2;
s = s+(sum(abs(A-[1;1]*ones(1,S)))<0.1)*s_3;
s = s+(sum(abs(A-[1;0]*ones(1,S)))<0.1)*s_4; % label keeps track of what is originally transmitted
s_= reshape(s,2,S/2);
%% Generation of the complex channel
H = (randn(2,S)+1i*randn(2,S))/2; % Normalised 2x2 Rayleigh MIMO fading channel
n=(randn(2,S/2)+1i*randn(2,S/2))/sqrt(2*10^(SNR/10));
for (k=1:S/2)
nchr_deta= [s_1 s_2 s_3 s_4];
H_ = H(:,2*k-1:2*k);
% x(:,k)=H_*s_(:,k); % This line has been added to compute the numerical SNR: 10*log10(var(reshape(x,1,S))/var(reshape(n,1,S))) (not necessary otherwi
y = H_*s_(:,k)+ n(:,k);
y_arrayReal(:,k) = real(y); %real values of y
y_arrayImag(:,k) = imag(y); %imaginary values of y
H_arrayReal(:,k) = real(H_(1,1));
H_arrayImag(:,k) = imag(H_(1,1));
%% x processsing
x(k,1:2) = [y_arrayReal(1,k) y_arrayReal(2,k)];
x(k,3:4) = [y_arrayImag(1,k) y_arrayImag(2,k)];
x(k,5) = H_arrayReal(k);
x(k,6) = H_arrayImag(k);
%Gerenation of the list of radius for the DNN using MMSE
W_MMSE=(H_'*H_+10^(-SNR/10)*eye(2,2))^-1*H_'; % MMSE equalisation matrix
s_hat_MMSE(:,k)=W_MMSE*y;
outDNN(k,:) = s_hat_MMSE(:,k); %output of the DNN
end
out1 = [real(outDNN(:,1)),imag(outDNN(:,1)),real(outDNN(:,2)),imag(outDNN(:,2))]; %outg of the DNN
%% Defining the network
layers = [
imageInputLayer([50000 6],'Name','Input layer','Mean',ones([50000 6]))
convolution2dLayer([10000 6],5,'Name','conv1')
reluLayer('Name', 'relu1')
% convolution2dLayer([10 6], 1,'Name','conv2')
% reluLayer('Name','relu2')
% convolution2dLayer([1000 6],1,'Name','conv3')
% reluLayer('Name','relu3')
fullyConnectedLayer(5,'Name','Output layer');
]
lgraph = layerGraph(layers);
%% Creation of the dl network
dlnet = dlnetwork(lgraph);
%% Specifying the training options
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(x);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
%% Initialising average gradient, squared average gradient and iteration counter
averageGrad = [];
averageSqGrad = [];
iteration = 1;
%% Initialise the training progress plot
plots = "training-progress";
if plots == "training-progress"
figure
lineLossTrain = animatedline;
xlabel("Total Iterations")
ylabel("Loss")
end
%% Train the network while updating its parameters using adamupdate funtion
for epoch = 1:numEpochs
% Convert mini-batch of data to dlarray.
X = dlarray(x(:,:,:,idx),'SSCB');
% Evaluate the model gradients using dlfeval and the
% modelGradients function defined below.
gradients = dlfeval(@modelGradients,net,X,(out1:,idx));
%Update the network using adam optimiser
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration);
% Display the training progress.
if plots == "training-progress"
addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
title("Loss During Training: Epoch - " + epoch + "; Iteration - ")
drawnow
end
% Increment the iteration counter.
iteration = iteration + 1;
end

  0 Comments

Sign in to comment.

Answers (1)

Divya Gaddipati
Divya Gaddipati on 11 Aug 2020
dlarray is used to convert the data to deep learning array, which you are already doing here:
X = dlarray(x(:,:,:,idx),'SSCB');
You can refer to the following link for more information

  0 Comments

Sign in to comment.

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!