Use only the ResNet-50 architecture without pre-trained weights

I want to use ResNet-50 in Deep Learning Toolbox without pre-trained weights. How can I do this?


Sameer on 13 Sep 2024
Hi Wenlin,
From my understanding, you want to use the "ResNet-50" architecture in MATLAB's "Deep Learning Toolbox" without utilizing pre-trained weights.
This involves defining the "ResNet-50" model architecture from scratch, initializing it with random weights, and then training it on your dataset. Here's how you can achieve this:
1. Define the ResNet-50 Architecture: Specify the layers and how they connect, including the skip (shortcut) connections of the residual blocks. A plain layer array cannot express these connections, so you build a layer graph and add them explicitly.
2. Initialize Weights: When you create layers from scratch, their weights are empty until training starts, at which point MATLAB initializes them (by default with the Glorot initializer for convolutional and fully connected layers). No pre-trained weights are involved in this setup.
3. Train the Network: Once the architecture is defined, you can train the network using your dataset with the "trainNetwork" function.
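If your MATLAB release supports it, you may not need to write every layer by hand: resnet50('Weights','none') is documented to return the untrained ResNet-50 layer graph (layers and skip connections, no learned weights), and the untrained model does not require the pretrained-model support package. The sketch below assumes a placeholder numClasses = 10 and only swaps the 1000-class ImageNet head for one sized for your data; the final layers are located by class rather than by name so the code does not depend on their exact names.

```matlab
% Number of classes in your dataset (placeholder value; set it for your data)
numClasses = 10;

% Untrained ResNet-50 architecture: layers and skip connections are defined,
% but no pre-trained weights are loaded
lgraph = resnet50('Weights','none');

% Locate the final fully connected and classification layers by class
fcName = '';
outName = '';
allLayers = lgraph.Layers;
for k = 1:numel(allLayers)
    if isa(allLayers(k),'nnet.cnn.layer.FullyConnectedLayer')
        fcName = allLayers(k).Name;
    elseif isa(allLayers(k),'nnet.cnn.layer.ClassificationOutputLayer')
        outName = allLayers(k).Name;
    end
end

% Replace the 1000-class ImageNet head with one sized for your classes
lgraph = replaceLayer(lgraph,fcName,fullyConnectedLayer(numClasses,'Name','fc'));
lgraph = replaceLayer(lgraph,outName,classificationLayer('Name','output'));

% Optionally inspect the architecture before training
% analyzeNetwork(lgraph)
```

The resulting lgraph can be passed to trainNetwork in place of a hand-built graph.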
Here’s a basic outline:
% Number of classes in your dataset (placeholder value; set it to match your data)
numClasses = 10;

% Define the layers of ResNet-50 (stem and classification head only)
layers = [
    imageInputLayer([224 224 3],'Name','input')
    convolution2dLayer(7,64,'Stride',2,'Padding','same','Name','conv1')
    batchNormalizationLayer('Name','bn_conv1')
    reluLayer('Name','conv1_relu')
    maxPooling2dLayer(3,'Stride',2,'Padding','same','Name','pool1')
    % Add the remaining ResNet-50 layers here: 16 bottleneck residual
    % blocks arranged in four stages of 3, 4, 6 and 3 blocks
    globalAveragePooling2dLayer('Name','avg_pool')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')
];

% Create a layer graph from the layers
lgraph = layerGraph(layers);

% Skip connections cannot be expressed in a plain layer array.
% Use addLayers, connectLayers and disconnectLayers on lgraph to add them,
% joining each block's main path and shortcut with an additionLayer.

% Define training options
options = trainingOptions('sgdm', ...
    'MaxEpochs',30, ...
    'InitialLearnRate',0.01, ...
    'Verbose',true, ...
    'Plots','training-progress');

% Load your dataset (loadYourDataFunction is a placeholder for your own code):
% trainImages is a 224-by-224-by-3-by-N array, trainLabels an N-by-1 categorical
% [trainImages, trainLabels] = loadYourDataFunction();

% Train the network
% net = trainNetwork(trainImages, trainLabels, lgraph, options);
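To show how the skip connections mentioned in the outline can be added, here is a minimal sketch of the stem plus the first bottleneck block of ResNet-50 (1x1, 3x3 and 1x1 convolutions with 64, 64 and 256 filters). Because the channel count changes from 64 to 256, this block's shortcut uses a 1x1 projection convolution instead of an identity connection. The layer names and numClasses = 10 are illustrative placeholders; the full network repeats this pattern for all 16 blocks, using identity shortcuts where input and output sizes match and stride-2 projection shortcuts at the start of stages 3 to 5.

```matlab
numClasses = 10;  % placeholder number of classes

% Stem: 224x224x3 input -> 56x56x64 after conv1 and pool1
stem = [
    imageInputLayer([224 224 3],'Name','input')
    convolution2dLayer(7,64,'Stride',2,'Padding','same','Name','conv1')
    batchNormalizationLayer('Name','bn_conv1')
    reluLayer('Name','conv1_relu')
    maxPooling2dLayer(3,'Stride',2,'Padding','same','Name','pool1')
];

% Main path of one bottleneck block, followed by the classification head
mainPath = [
    convolution2dLayer(1,64,'Name','res2a_branch2a')
    batchNormalizationLayer('Name','bn2a_branch2a')
    reluLayer('Name','res2a_branch2a_relu')
    convolution2dLayer(3,64,'Padding','same','Name','res2a_branch2b')
    batchNormalizationLayer('Name','bn2a_branch2b')
    reluLayer('Name','res2a_branch2b_relu')
    convolution2dLayer(1,256,'Name','res2a_branch2c')
    batchNormalizationLayer('Name','bn2a_branch2c')
    additionLayer(2,'Name','res2a')  % main path feeds input 'in1'
    reluLayer('Name','res2a_relu')
    globalAveragePooling2dLayer('Name','avg_pool')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','output')
];

% layerGraph connects the array's layers one after another
lgraph = layerGraph([stem; mainPath]);

% Projection shortcut: 1x1 convolution to 256 channels plus batch norm
shortcut = [
    convolution2dLayer(1,256,'Name','res2a_branch1')
    batchNormalizationLayer('Name','bn2a_branch1')
];
lgraph = addLayers(lgraph,shortcut);

% Branch off after pool1 and join the main path at the addition layer
lgraph = connectLayers(lgraph,'pool1','res2a_branch1');
lgraph = connectLayers(lgraph,'bn2a_branch1','res2a/in2');
```

Running analyzeNetwork(lgraph) afterwards is a quick way to confirm that every layer is connected and that the two inputs of the addition layer have matching sizes (56x56x256 here).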
Hope this helps!
