- “softmaxLayer” function: https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.softmaxlayer.html
- “pixelClassificationLayer” function: https://www.mathworks.com/help/vision/ref/nnet.cnn.layer.pixelclassificationlayer.html
How can I create a multi-output U-Net model?
Hi, I would like to know whether an image regression network built with a U-Net architecture can be trained to predict multiple outputs. If so, how can I do it?
For my model, I created and trained the network using the "Train Convolutional Neural Network for Regression" example as a guideline, which has only one output. I now need to create a model with the same architecture but with more outputs. The outputs I need to train are created from the original output by splitting it into groups based on the values in the array, as shown in the attached picture.
Note: My old model predicts the original image. The new model I want to try should predict output 1 to output 4.
Answers (1)
Aishwarya
on 3 Nov 2023
Hello,
As I understand it, you have created a U-Net model using the documentation example mentioned in your question as a reference. Now you want to create a multi-class U-Net model and split the output mask into one image per class.
Because the referenced example builds a simple convolutional neural network for regression that outputs a single continuous value, I assume you have added up-sampling layers and skip connections to turn it into a U-Net architecture. To turn the network into a multi-class U-Net model, consider replacing the layers after the last “relu” layer with the layers shown in the example code below:
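numClasses = 4; % number of output classes, e.g., output 1 to output 4
finalConv = convolution2dLayer(1, numClasses, 'Name', 'finalConv'); % 1x1 convolution: one channel per class
softmaxOut = softmaxLayer('Name', 'softmax'); % per-pixel class probabilities
outputLayer = pixelClassificationLayer('Name', 'labels'); % per-pixel cross-entropy loss
% Connect these layers after the last "relu" layer in place of the regression output layers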
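Note that pixelClassificationLayer is trained against a categorical label for every pixel, so the grouped outputs described in the question would need to be encoded as one label image per training sample. Below is a minimal sketch, assuming the groups are defined by fixed value thresholds. The toy target map Y, the boundaries in edges, and the class names are hypothetical placeholders; discretize assigns each value to the index of the bin it falls in.

```matlab
% Hypothetical continuous target map (e.g., the original regression output)
Y = [0.05 0.30 0.55; 0.80 0.95 0.10];
% Assumed group boundaries: 4 bins [0,0.25), [0.25,0.5), [0.5,0.75), [0.75,1]
edges = [0 0.25 0.5 0.75 1];

% Map each value to its group index (1..4); values outside edges become NaN
classIdx = discretize(Y, edges);

% Convert the indices to a categorical label image for pixelClassificationLayer
classNames = {'output1', 'output2', 'output3', 'output4'};
labels = categorical(classIdx, 1:numel(classNames), classNames);
disp(labels)
```

Values outside the range of edges end up as <undefined> in labels, so the boundaries should cover the full range of the data. If the label images are saved to disk, pixelLabelDatastore is one way to load them for training, although that step is not shown here.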
After getting the output mask (“labeled_mask”) from the U-Net model, you can separate each class into its own image using the example code below.
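% Extract each class into a separate image.
% Assumes img is the input image and labeled_mask is an HxW array of numeric
% class indices (1..num_classes), e.g., from semanticseg with 'OutputType' set to 'uint8'.
num_classes = 4;
for i = 1:num_classes
    % Create a binary mask for the current class
    class_mask = (labeled_mask == i);
    % Apply the mask to the input image; cast(..., 'like', img) matches the
    % image's data type, and implicit expansion also handles RGB images
    class_img = img .* cast(class_mask, 'like', img);
    % Show the output image for this class
    figure;
    imshow(class_img);
    title(sprintf('Class %d', i));
end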
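If the four outputs are needed as arrays rather than as figures, the loop can be vectorized with implicit expansion (R2016b or later). This is a minimal sketch on made-up toy data, assuming a grayscale img and a numeric labeled_mask with the same height and width.

```matlab
% Toy inputs (hypothetical): a 2x3 grayscale image and a predicted label map
img = uint8([10 20 30; 40 50 60]);
labeled_mask = [1 2 3; 4 4 1];
num_classes = 4;

% Compare against 1..num_classes placed along the 3rd dimension; implicit
% expansion yields an H x W x num_classes logical stack, one page per class
classIds = reshape(1:num_classes, 1, 1, []);
masks = (labeled_mask == classIds);

% Multiply the image by every mask page at once to get one image per class
outputs = img .* cast(masks, 'like', img);

% outputs(:,:,k) keeps only the pixels assigned to class k (others are 0)
disp(outputs(:,:,4))
```

For an RGB image, the loop version is simpler to use, because the third dimension would then hold color channels rather than classes.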
For more details about the functions used, please refer to the MathWorks documentation linked above.
I hope this helps!