VAE import custom data

Glacial Claw on 8 Feb 2023
Edited: Glacial Claw on 20 Feb 2023
I was able to implement this line by line. That said, I am now trying to load my own custom data for training.
I have two RAR files of color images (1024x1024, which will be resized). The first contains many images for training; the second contains a few images for testing the algorithm.
The issue is that the VAE in the link uses a function called "processImagesMNIST". I have read through the function, but I am not sure how to structure my training and testing data so that the function accepts them.
Alternatively, is there another way to load my own training data into the VAE without using this function?
function X = processImagesMNIST(filename)
    % The MNIST processing functions extract the data from the downloaded IDX
    % [What are IDX files???]
    % files into MATLAB arrays. The processImagesMNIST function performs these
    % operations: Check if the file can be opened correctly. Obtain the magic
    % number by reading the first four bytes. The magic number is 2051 for
    % image data, and 2049 for label data. Read the next 3 sets of 4 bytes,
    % which return the number of images, the number of rows, and the number of
    % columns. Read the image data. Reshape the array and swap the first two
    % dimensions, because the data is read in column-major format. Ensure
    % the pixel values are in the range [0,1] by dividing them all by 255, and
    % convert the 3-D array to a 4-D array. Close the file.
    % [What do I need to do to make my dataset applicable for this function? Or is there another way to implement my own data?]
    dataFolder = fullfile(tempdir,'mnist');
    gunzip(filename,dataFolder)
    [~,name,~] = fileparts(filename);
    [fileID,errmsg] = fopen(fullfile(dataFolder,name),'r','b');
    if fileID < 0
        error(errmsg);
    end
    magicNum = fread(fileID,1,'int32',0,'b');
    if magicNum == 2051
        fprintf('\nRead MNIST image data...\n')
    end
    numImages = fread(fileID,1,'int32',0,'b');
    fprintf('Number of images in the dataset: %6d ...\n',numImages);
    numRows = fread(fileID,1,'int32',0,'b');
    numCols = fread(fileID,1,'int32',0,'b');
    X = fread(fileID,inf,'unsigned char');
    X = reshape(X,numCols,numRows,numImages);
    X = permute(X,[2 1 3]);
    X = X./255;
    X = reshape(X, [28,28,1,size(X,3)]);
    fclose(fileID);
end
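On the "What are IDX files?" question in the comments above: going only by what the function reads, an IDX image file is a small big-endian header (magic number 2051, number of images, rows, columns, each a 4-byte integer) followed by the raw pixel bytes stored row by row. The sketch below is only an illustration of that layout. It writes a tiny made-up file and reads it back the same way processImagesMNIST does, minus the gunzip step. The file name and the 2x3 "images" are hypothetical.

```matlab
% Hypothetical demo: write a tiny IDX-style image file, then read it back.
numImages = 2; numRows = 2; numCols = 3;
% Image 1 holds 1..6 and image 2 holds 7..12, numbered row by row.
imgs = permute(reshape(uint8(1:12),[numCols numRows numImages]),[2 1 3]);

idxFile = fullfile(tempdir,'tiny-idx3-ubyte');        % hypothetical file name
fid = fopen(idxFile,'w','b');                         % 'b' = big-endian
fwrite(fid,[2051 numImages numRows numCols],'int32'); % magic number and sizes
fwrite(fid,permute(imgs,[2 1 3]),'uint8');            % pixels stored row by row
fclose(fid);

fid = fopen(idxFile,'r','b');
header = fread(fid,4,'int32')';                       % magic, count, rows, cols
X = fread(fid,inf,'unsigned char');                   % all pixel bytes as doubles
fclose(fid);

% MATLAB fills arrays column by column, so reshape as cols-by-rows and then
% swap the first two dimensions to recover rows-by-cols images.
X = permute(reshape(X,header(4),header(3),header(2)),[2 1 3]);
```

This suggests that custom JPEG or PNG images do not need to be converted to IDX at all. The function only exists to decode that one file layout.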

Answers (1)

Image Analyst on 8 Feb 2023
I would just extract all your images to regular image files (like PNG images) and then instead of calling
X = processImagesMNIST(filename)
just use imread()
X = imread(filename)
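A minimal sketch of that idea, assuming the archives have already been extracted to a folder of ordinary image files. The folder name and the three random stand-in images are hypothetical and only there so the snippet runs on its own. It builds the same kind of height-by-width-by-channel-by-image array that processImagesMNIST returns, scaled to [0,1].

```matlab
% Sketch: build an H-by-W-by-C-by-N array from ordinary image files.
imageFolder = fullfile(tempdir,'vaeDemoImages');      % hypothetical folder
if ~isfolder(imageFolder), mkdir(imageFolder); end
for k = 1:3                                           % stand-in color images
    imwrite(rand(64,64,3),fullfile(imageFolder,sprintf('img%d.png',k)));
end

files = dir(fullfile(imageFolder,'*.png'));
targetSize = [28 28];
X = zeros([targetSize 3 numel(files)],'single');      % preallocate, 3 channels
for k = 1:numel(files)
    img = imread(fullfile(files(k).folder,files(k).name));
    img = imresize(img,targetSize);                   % same size for every image
    X(:,:,:,k) = im2single(img);                      % uint8 0..255 -> single 0..1
end
```

Preallocating and indexing along the fourth dimension keeps the images in file order and avoids growing the array with cat on every iteration.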
  18 Comments
Glacial Claw on 16 Feb 2023
OK, I was able to get it to read the images; the console now shows each image being read.
However, I now get an error saying that the concatenation dimensions are incorrect.
The "processImagesMNIST" function contains these commands (I wonder if they have anything to do with it):
numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');
X = fread(fileID,inf,'unsigned char');
X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
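For what it is worth, those lines only parse the IDX header and reorder the pixels, so they are probably unrelated to a concatenation error. One common cause of that kind of error, which may or may not be the one here, is stacking images whose first three dimensions differ, for example a color image with a grayscale one or a resized image with an unresized one. A small sketch with made-up arrays:

```matlab
% Sketch: cat(4,...) needs identical sizes in dimensions 1 to 3.
a = zeros(28,28,3,'uint8');        % stand-in for a color image
b = zeros(28,28,'uint8');          % stand-in for a grayscale image

caught = false;
try
    cat(4,a,b);                    % 28x28x3 with 28x28x1 cannot be stacked
catch err
    caught = true;
    disp(err.message)
end

b3 = repmat(b,[1 1 3]);            % replicate the gray channel three times
stack = cat(4,a,b3);               % now 28x28x3x2
```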
Glacial Claw on 17 Feb 2023
Edited: Glacial Claw on 20 Feb 2023
So, I was able to fix the concatenation error by swapping the "Train_datastored" and "thisImage" variables. However, this has led to a new issue.
Is there a way to change the image channel dimension?
This is the entire code I am using. I have redacted the file paths, as they are not necessary and do not cause any issues.
filePattern = fullfile('(redacted)', '*.jpg');
theFiles = dir(filePattern);
for k = 1 : length(theFiles)
    [filepath,name,ext] = fileparts(filePattern);
    baseFileName = theFiles(k).name;
    fullFileName = fullfile(theFiles(k).folder, baseFileName);
    fprintf('Now reading %s\n', fullFileName);
    thisImage = imread(fullFileName);
    thisImage_resized = imresize(thisImage, [28 28]);
    if k == 1
        Train_datastored = thisImage_resized;
    else
        Train_datastored = cat(4,thisImage_resized, Train_datastored);
    end
end
XTrain = Train_datastored;
filePattern_test = fullfile('(redacted)', '*.jpg');
theFiles_test = dir(filePattern_test);
for k = 1 : length(theFiles_test)
    [filepath_test,name,ext_test] = fileparts(filePattern_test);
    baseFileName_test = theFiles_test(k).name;
    fullFileName_test = fullfile(theFiles_test(k).folder, baseFileName_test);
    fprintf('Now reading %s\n', fullFileName_test);
    thisImage_test = imread(fullFileName_test);
    thisImage_test_resized = imresize(thisImage_test, [28 28]);
    if k == 1
        Test_datastored = thisImage_test_resized;
    else
        Test_datastored = cat(4,thisImage_test_resized, Test_datastored);
    end
end
XTest = Test_datastored;
numLatentChannels = 16;
imageSize = [28 28 1];
layersE = [
    imageInputLayer(imageSize,Normalization="none")
    convolution2dLayer(3,32,Padding="same",Stride=2)
    reluLayer
    convolution2dLayer(3,64,Padding="same",Stride=2)
    reluLayer
    fullyConnectedLayer(2*numLatentChannels)
    samplingLayer];
projectionSize = [7 7 64];
numInputChannels = size(imageSize,1);
layersD = [
    featureInputLayer(numLatentChannels)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(3,64,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,32,Cropping="same",Stride=2)
    reluLayer
    transposedConv2dLayer(3,numInputChannels,Cropping="same")
    sigmoidLayer];
netE = dlnetwork(layersE);
netD = dlnetwork(layersD);
numEpochs = 120;
miniBatchSize = 50;
learnRate = 1e-3;
dsTrain = arrayDatastore(XTrain,IterationDimension=4);
numOutputs = 1;
mbq = minibatchqueue(dsTrain,numOutputs, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");
trailingAvgE = [];
trailingAvgSqE = [];
trailingAvgD = [];
trailingAvgSqD = [];
numObservationsTrain = size(XTrain,4);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");
epoch = 0;
iteration = 0;
dsTest = arrayDatastore(XTest,IterationDimension=4);
numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    MiniBatchSize = miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
YTest = modelPredictions(netE,netD,mbqTest);
err = mean((XTest-YTest).^2,[1 2 3]);
figure
histogram(err)
numImages = 2;
ZNew = randn(numLatentChannels,numImages);
ZNew = dlarray(ZNew,"CB");
YNew = predict(netD,ZNew);
YNew = extractdata(YNew);
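On the channel question: the code above declares imageSize = [28 28 1], while a color JPEG read by imread is 28-by-28-by-3 after resizing and is stored as uint8 in the range 0 to 255. One possible way to make the data and the network agree is sketched below. The random array is a stand-in for a real image, and this is only an outline of the two options, not a tested fix for the full script.

```matlab
% Sketch: match the image channels to the network's declared input size.
thisImage = uint8(randi([0 255],[1024 1024 3]));   % stand-in for imread output
imageSize = [28 28 1];

img = imresize(thisImage,imageSize(1:2));
if imageSize(3) == 1 && size(img,3) == 3
    img = im2gray(img);            % option 1: collapse RGB to one channel
end
img = im2single(img);              % scale uint8 0..255 to single 0..1

% Option 2 (keep color): set imageSize = [28 28 3] and skip im2gray.
% In both cases the decoder's last layer can follow the third entry:
numInputChannels = imageSize(3);   % size(imageSize,1) is always 1 for a row vector
```

Scaling to [0,1] matters here because the decoder ends in a sigmoid layer, so its output can never match raw 0 to 255 pixel values, and the error computed from XTest and YTest assumes both are on the same scale.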

Release

R2022b
