
Error "Arrays are not compatible for addition" when using a modified U-Net for prediction

I am trying to add a ViT module to a U-Net built with the updated unet3d function in MATLAB R2024a. Training runs normally, and I have checked the model's performance after training for a while. analyzeNetwork reports no errors, and the data before and after the position embedding is 65x1024x1 (SCB), which is the serialized (patch-embedded) image. However, predict fails with the following error:
Error using dlnetwork/predict (line 658)
Execution failed during layer 'Transformer-PositionEmbedding, Encoder-Stage-4-Add-1'.
Error in Unet3dTrain (line 288)
predictedLabel = predict(net, image);
Caused by:
Error using matlab.internal.path.cnn.MLFusedNetwork/forwardExampleInputs
Arrays are not compatible for addition.
The problem seems to occur when the embeddings from before the position embedding layer are added to the output of the position embedding layer.
Custom layers that print the input size work without errors when inserted before and after this layer.
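The shapes involved in the failing addition can be worked out by hand. The sketch below is plain MATLAB with no toolbox calls. It assumes that the Stage-3 feature map entering Transformer-PatchEmbedding is 32x32x2 (the 128x128x8 input after two 2x2x2 max-pooling stages) and uses random arrays as stand-ins for the real activations. It shows why the sequence is 65 tokens long, that a 65x1024 token array and a 65x1024 position embedding add cleanly through implicit expansion, and that a mismatch in the number of positions raises an incompatible-sizes error.

```matlab
% Sketch: token count and the addition around the position embedding.
% Assumption: the Stage-3 feature map is the 128x128x8 input after two
% 2x2x2 max-pooling stages, i.e. 32x32x2 spatially.
inputSize  = [128 128 8 2];
featSize   = inputSize(1:3) ./ 2^2;            % [32 32 2]
patchSize  = [4 4 2];                          % as in patchEmbeddingLayer([4 4 2], 1024)
embedDim   = 1024;

numPatches = prod(featSize ./ patchSize);      % 8*8*1 = 64 patches
numTokens  = numPatches + 1;                   % +1 class token from embeddingConcatenationLayer

% Random stand-ins for the activations (tokens x channels x batch, batch size 1)
tokens = rand(numTokens, embedDim, 1);         % plays the role of Transformer-EmbeddingConcatenation output
posEmb = rand(numTokens, embedDim);            % one embedding vector per position

% Matching sizes: implicit expansion handles the singleton batch dimension
summed = tokens + posEmb;

% Mismatched number of positions: MATLAB refuses the addition
mismatchFailed = false;
try
    wrong = tokens + rand(numPatches, embedDim);   %#ok<NASGU>
catch err
    mismatchFailed = true;
    disp(err.message)
end
```

In the real network the position embedding holds learned weights rather than random values; the sketch only shows which dimensions have to agree for the addition to succeed.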
Here is part of the structure of the network.
My native language is not English and I am using translation software, so please forgive any errors. The code follows.
Code:
clc; clear;
rng(1);
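% ========== Data loading and datastore creation ==========
% Location of the image and label files
imageDir = 'X:\BaiduDownload\brats2021\ProcessedData';
labelDir = 'X:\BaiduDownload\brats2021\ProcessedData'; % labels are stored in the same folder
% Class names and corresponding label IDs
categories = ["background", "necrotic_tumor_core", "peritumoral_edema", "enhancing_tumor"]; % 4 classes
labelIDs = [0, 1, 2, 4]; % label IDs of the classes above, in the same order
% Input volumes are 128x128x8 with 2 modalities; there are 4 classes
inputSize = [128 128 8 2]; % the last dimension is the number of modalities (2)
numClasses = 4; % number of classes (background and three tumor subregions)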
% Create the image and pixel-label datastores
imds = imageDatastore(imageDir, 'FileExtensions','.mat', 'ReadFcn', @customReadData);
pxds = pixelLabelDatastore(labelDir, categories, labelIDs, 'FileExtensions','.mat', 'ReadFcn', @customReadLabels);
% Split into training and validation sets
numFiles = numel(imds.Files);
idx = randperm(numFiles); % random permutation of file indices
% numFiles = round(0.001 * numFiles); % use a small subset for quick testing
numTrain = round(0.9 * numFiles); % use 90% of the data for training
% Split the data using the indices
trainImds = subset(imds, idx(1:numTrain));
trainPxds = subset(pxds, idx(1:numTrain));
valImds = subset(imds, idx(numTrain+1:end));
valPxds = subset(pxds, idx(numTrain+1:end));
% Combine images and labels into training and validation datastores
dsTrain = combine(trainImds, trainPxds);
dsVal = combine(valImds, valPxds);
% Helper functions (in a script file, local functions must be placed at the end of the file)
function labels = customReadLabels(filename)
    fileContent = load(filename);
    % Keep slices 74-81
    segmentation = fileContent.segmentation(:,:,74:81);
    % Assuming the original size is 240x240, compute the center-crop offsets
    cropSize = 200;
    startCrop = (size(segmentation,1) - cropSize) / 2 + 1;
    endCrop = startCrop + cropSize - 1;
    % Crop evenly on all sides to 200x200
    croppedSegmentation = segmentation(startCrop:endCrop, startCrop:endCrop, :);
    % Resize to 128x128 in-plane (depth unchanged) with nearest-neighbor interpolation
    segmentationResized = imresize3(croppedSegmentation, [128, 128, size(croppedSegmentation, 3)], 'Method', 'nearest');
    % Convert to categorical data with the correct class names
    labels = categorical(segmentationResized, [0, 1, 2, 4], {'background', 'necrotic_tumor_core', 'peritumoral_edema', 'enhancing_tumor'});
end
function data = customReadData(filename)
    fileContent = load(filename);
    % Extract slices 74-81 of modalities 1 and 3
    originalData = squeeze(fileContent.combinedData(:,:,74:81,[1, 3]));
    % Compute the center-crop offsets the same way
    cropSize = 200;
    startCrop = (size(originalData,1) - cropSize) / 2 + 1;
    endCrop = startCrop + cropSize - 1;
    % Crop evenly on all sides to 200x200
    croppedData = originalData(startCrop:endCrop, startCrop:endCrop, :, :);
    % Preallocate a 4-D array for the resized data
    resizedData = zeros(128, 128, size(croppedData, 3), size(croppedData, 4));
    % Process each channel (modality)
    for i = 1:size(croppedData, 4)
        % Rescale the channel to [0, 1] with mat2gray and resize it to 128x128 in-plane
        resizedData(:,:,:,i) = imresize3(mat2gray(croppedData(:,:,:,i)), [128, 128, size(croppedData, 3)]);
    end
    % Return the processed data
    data = resizedData;
end
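% Create the 3-D U-Net
net = unet3d(inputSize, numClasses, EncoderDepth = 3);
% ========== U-Net modification stage ==========
% Add residual (ResBlock) connections to the encoder stages
% Stage 1
% Add a 1x1x1 convolution to match the number of channels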
adjustConvLayer = convolution3dLayer([1, 1, 1], 64, 'Name', 'Encoder-Stage-1-Conv-Ident-1', 'Padding', 'same');
adjustBnLayer = batchNormalizationLayer('Name', 'Encoder-Stage-1-BN-Ident-1');
addLayer = additionLayer(2, 'Name', 'Encoder-Stage-1-Add-1');
% Add the layers to the network graph
net = addLayers(net, adjustConvLayer);
net = addLayers(net, adjustBnLayer);
net = addLayers(net, addLayer);
% Connect the new layers
net = connectLayers(net, 'encoderImageInputLayer', 'Encoder-Stage-1-Conv-Ident-1');
net = connectLayers(net, 'Encoder-Stage-1-Conv-Ident-1', 'Encoder-Stage-1-BN-Ident-1');
net = disconnectLayers(net,'Encoder-Stage-1-BN-2', 'Encoder-Stage-1-ReLU-2');
net = connectLayers(net, 'Encoder-Stage-1-BN-2', 'Encoder-Stage-1-Add-1/in1');
net = connectLayers(net, 'Encoder-Stage-1-BN-Ident-1', 'Encoder-Stage-1-Add-1/in2');
net = connectLayers(net, 'Encoder-Stage-1-Add-1', 'Encoder-Stage-1-ReLU-2');
% Stage 2
% Add a 1x1x1 convolution to match the number of channels
adjustConvLayer2 = convolution3dLayer([1, 1, 1], 128, 'Name', 'Encoder-Stage-2-Conv-Ident-1', 'Padding', 'same');
adjustBnLayer2 = batchNormalizationLayer('Name', 'Encoder-Stage-2-BN-Ident-1');
addLayer2 = additionLayer(2, 'Name', 'Encoder-Stage-2-Add-1');
% Add the layers to the network graph
net = addLayers(net, adjustConvLayer2);
net = addLayers(net, adjustBnLayer2);
net = addLayers(net, addLayer2);
% Connect the new layers
net = connectLayers(net, 'Encoder-Stage-1-MaxPool', 'Encoder-Stage-2-Conv-Ident-1');
net = connectLayers(net, 'Encoder-Stage-2-Conv-Ident-1', 'Encoder-Stage-2-BN-Ident-1');
net = disconnectLayers(net,'Encoder-Stage-2-BN-2', 'Encoder-Stage-2-ReLU-2');
net = connectLayers(net, 'Encoder-Stage-2-BN-2', 'Encoder-Stage-2-Add-1/in1');
net = connectLayers(net, 'Encoder-Stage-2-BN-Ident-1', 'Encoder-Stage-2-Add-1/in2');
net = connectLayers(net, 'Encoder-Stage-2-Add-1', 'Encoder-Stage-2-ReLU-2');
% Stage 3
% Add a 1x1x1 convolution to match the number of channels
adjustConvLayer3 = convolution3dLayer([1, 1, 1], 256, 'Name', 'Encoder-Stage-3-Conv-Ident-1', 'Padding', 'same');
adjustBnLayer3 = batchNormalizationLayer('Name', 'Encoder-Stage-3-BN-Ident-1');
addLayer3 = additionLayer(2, 'Name', 'Encoder-Stage-3-Add-1');
% Add the layers to the network graph
net = addLayers(net, adjustConvLayer3);
net = addLayers(net, adjustBnLayer3);
net = addLayers(net, addLayer3);
% Connect the new layers
net = connectLayers(net, 'Encoder-Stage-2-MaxPool', 'Encoder-Stage-3-Conv-Ident-1');
net = connectLayers(net, 'Encoder-Stage-3-Conv-Ident-1', 'Encoder-Stage-3-BN-Ident-1');
net = disconnectLayers(net,'Encoder-Stage-3-BN-2', 'Encoder-Stage-3-ReLU-2');
net = connectLayers(net, 'Encoder-Stage-3-BN-2', 'Encoder-Stage-3-Add-1/in1');
net = connectLayers(net, 'Encoder-Stage-3-BN-Ident-1', 'Encoder-Stage-3-Add-1/in2');
net = connectLayers(net, 'Encoder-Stage-3-Add-1', 'Encoder-Stage-3-ReLU-2');
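% Replace batch normalization with group normalization
% Get the names of all layers in the network
layerNames = {net.Layers.Name};
% Loop over all layer names and find the layers whose names contain "BN"
for i = 1:length(layerNames)
    if contains(layerNames{i}, 'BN')
        % Create a new group normalization layer with the same name
        gnLayer = groupNormalizationLayer(4, 'Name', layerNames{i});
        % Replace the existing BN layer
        net = replaceLayer(net, layerNames{i}, gnLayer);
    end
end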
% Add the Vision Transformer layers
PatchEmbeddingLayer1 = patchEmbeddingLayer([4 4 2], 1024, 'Name', 'Transformer-PatchEmbedding');
EmbeddingConcatenationLayer1 = embeddingConcatenationLayer('Name', 'Transformer-EmbeddingConcatenation');
PositionEmbeddingLayer1 = positionEmbeddingLayer(1024, 1024, 'Name', 'Transformer-PositionEmbedding');
addLayer4 = additionLayer(2, 'Name', 'Encoder-Stage-4-Add-1');
addLayer5 = additionLayer(2, 'Name', 'Encoder-Stage-4-Add-2');
addLayer6 = additionLayer(2, 'Name', 'Encoder-Stage-4-Add-3');
dropoutLayer1 = dropoutLayer(0.1, 'Name', 'Transformer-DropOut-1');
dropoutLayer2 = dropoutLayer(0.1, 'Name', 'Transformer-DropOut-2');
LayerNormalizationLayer1 = layerNormalizationLayer('Name','Transformer-LN-1');
LayerNormalizationLayer2 = layerNormalizationLayer('Name','Transformer-LN-2');
SelfAttentionLayer = selfAttentionLayer(8, 32, 'Name', 'Transformer-SelfAttention');
FullyConnectedLayer = fullyConnectedLayer(1024, 'Name', 'Transformer-fc');
ReshapeLayer = reshapeLayer('Transformer-reshape');
index1dLayer = indexing1dLayer('Name', 'Transformer-index1d');
% printShapeLayer1 = functionLayer(@printShape, ...
% 'Name', 'printShapeLayer1', ...
% 'NumInputs', 1, ...
% 'NumOutputs', 1, ...
% 'InputNames', {'in'}, ...
% 'OutputNames', {'out'});
% printShapeLayer2 = functionLayer(@printShape, ...
% 'Name', 'printShapeLayer2', ...
% 'NumInputs', 1, ...
% 'NumOutputs', 1, ...
% 'InputNames', {'in'}, ...
% 'OutputNames', {'out'});
% printShapeLayer3 = functionLayer(@printShape, ...
% 'Name', 'printShapeLayer3', ...
% 'NumInputs', 1, ...
% 'NumOutputs', 1, ...
% 'InputNames', {'in'}, ...
% 'OutputNames', {'out'});
net = addLayers(net, PatchEmbeddingLayer1);
net = addLayers(net, EmbeddingConcatenationLayer1);
net = addLayers(net, PositionEmbeddingLayer1);
net = addLayers(net, addLayer4);
net = addLayers(net, addLayer5);
net = addLayers(net, addLayer6);
net = addLayers(net, dropoutLayer1);
net = addLayers(net, dropoutLayer2);
net = addLayers(net, LayerNormalizationLayer1);
net = addLayers(net, LayerNormalizationLayer2);
net = addLayers(net, SelfAttentionLayer);
net = addLayers(net, FullyConnectedLayer);
net = addLayers(net, ReshapeLayer);
net = addLayers(net, index1dLayer);
% net = addLayers(net, printShapeLayer1);
% net = addLayers(net, printShapeLayer2);
% net = addLayers(net, printShapeLayer3);
% net = disconnectLayers(net, 'encoderImageInputLayer', 'Encoder-Stage-1-Conv-1');
% net = disconnectLayers(net, 'encoderImageInputLayer', 'Encoder-Stage-1-BN-Ident-1');
% net = connectLayers(net, 'encoderImageInputLayer', 'printShapeLayer3');
% net = connectLayers(net, 'printShapeLayer3', 'Encoder-Stage-1-Conv-1');
% net = connectLayers(net, 'printShapeLayer3', 'Encoder-Stage-1-BN-Ident-1');
net = disconnectLayers(net,'Encoder-Stage-3-DropOut', 'Encoder-Stage-3-MaxPool');
% net = connectLayers(net,'Encoder-Stage-3-ReLU-2', 'printShapeLayer1');
% net = connectLayers(net,'printShapeLayer1', 'Transformer-PatchEmbedding');
net = connectLayers(net,'Encoder-Stage-3-DropOut', 'Transformer-PatchEmbedding');
net = connectLayers(net, 'Transformer-PatchEmbedding', 'Transformer-EmbeddingConcatenation');
net = connectLayers(net, 'Transformer-EmbeddingConcatenation', 'Transformer-PositionEmbedding');
net = connectLayers(net, 'Transformer-PositionEmbedding', 'Encoder-Stage-4-Add-1/in1');
net = connectLayers(net, 'Transformer-EmbeddingConcatenation', 'Encoder-Stage-4-Add-1/in2');
net = connectLayers(net, 'Encoder-Stage-4-Add-1', 'Transformer-DropOut-1');
net = connectLayers(net, 'Transformer-DropOut-1', 'Transformer-LN-1');
net = connectLayers(net, 'Transformer-LN-1', 'Transformer-SelfAttention');
net = connectLayers(net, 'Transformer-SelfAttention', 'Transformer-DropOut-2');
net = connectLayers(net, 'Transformer-DropOut-2', 'Encoder-Stage-4-Add-2/in1');
net = connectLayers(net, 'Transformer-DropOut-1', 'Encoder-Stage-4-Add-2/in2');
net = connectLayers(net, 'Encoder-Stage-4-Add-2', 'Transformer-LN-2');
net = connectLayers(net, 'Transformer-LN-2', 'Transformer-index1d');
net = connectLayers(net, 'Transformer-index1d', 'Transformer-fc');
net = connectLayers(net, 'Transformer-fc', 'Encoder-Stage-4-Add-3/in1');
net = connectLayers(net, 'Encoder-Stage-4-Add-2', 'Encoder-Stage-4-Add-3/in2');
net = connectLayers(net, 'Encoder-Stage-4-Add-3', 'Transformer-reshape');
% net = connectLayers(net, 'Transformer-reshape', 'Encoder-Stage-3-DropOut');
net = connectLayers(net, 'Transformer-reshape', 'Encoder-Stage-3-MaxPool');
% net = connectLayers(net, 'Transformer-reshape', 'encoderDecoderSkipConnectionCrop3/in');
% net = disconnectLayers(net, 'Encoder-Stage-3-MaxPool', 'LatentNetwork-Bridge-Conv-1');
% net = connectLayers(net, 'Encoder-Stage-3-MaxPool', 'printShapeLayer2');
net = removeLayers(net, 'Encoder-Stage-3-MaxPool');
net = connectLayers(net, 'Transformer-reshape', 'LatentNetwork-Bridge-Conv-1');
% net = connectLayers(net, 'Encoder-Stage-3-MaxPool', 'LatentNetwork-Bridge-Conv-1');
% Add attention gates
relulayer1 = reluLayer('Name', 'AttentionGate-Stage-1-relu');
relulayer2 = reluLayer('Name', 'AttentionGate-Stage-2-relu');
relulayer3 = reluLayer('Name', 'AttentionGate-Stage-3-relu');
sigmoidlayer1 = sigmoidLayer('Name','AttentionGate-Stage-1-sigmoid');
sigmoidlayer2 = sigmoidLayer('Name','AttentionGate-Stage-2-sigmoid');
sigmoidlayer3 = sigmoidLayer('Name','AttentionGate-Stage-3-sigmoid');
convolution3dlayer11 = convolution3dLayer(1, 512, 'Padding','same', 'Name','AttentionGate-Stage-1-conv-1');
convolution3dlayer12 = convolution3dLayer(1, 256, 'Padding','same', 'Name','AttentionGate-Stage-1-conv-2');
convolution3dlayer13 = convolution3dLayer(1, 256, 'Padding','same', 'Name','AttentionGate-Stage-1-conv-3');
convolution3dlayer21 = convolution3dLayer(1, 256, 'Padding','same', 'Name','AttentionGate-Stage-2-conv-1');
convolution3dlayer22 = convolution3dLayer(1, 128, 'Padding','same', 'Name','AttentionGate-Stage-2-conv-2');
convolution3dlayer23 = convolution3dLayer(1, 128, 'Padding','same', 'Name','AttentionGate-Stage-2-conv-3');
convolution3dlayer31 = convolution3dLayer(1, 128, 'Padding','same', 'Name','AttentionGate-Stage-3-conv-1');
convolution3dlayer32 = convolution3dLayer(1, 64, 'Padding','same', 'Name','AttentionGate-Stage-3-conv-2');
convolution3dlayer33 = convolution3dLayer(1, 64, 'Padding','same', 'Name','AttentionGate-Stage-3-conv-3');
net = addLayers(net, relulayer1);
net = addLayers(net, relulayer2);
net = addLayers(net, relulayer3);
net = addLayers(net, sigmoidlayer1);
net = addLayers(net, sigmoidlayer2);
net = addLayers(net, sigmoidlayer3);
net = addLayers(net, convolution3dlayer11);
net = addLayers(net, convolution3dlayer12);
net = addLayers(net, convolution3dlayer13);
net = addLayers(net, convolution3dlayer21);
net = addLayers(net, convolution3dlayer22);
net = addLayers(net, convolution3dlayer23);
net = addLayers(net, convolution3dlayer31);
net = addLayers(net, convolution3dlayer32);
net = addLayers(net, convolution3dlayer33);
net = disconnectLayers(net, 'Decoder-Stage-1-UpReLU', 'encoderDecoderSkipConnectionCrop3/ref');
net = disconnectLayers(net, 'Decoder-Stage-2-UpReLU', 'encoderDecoderSkipConnectionCrop2/ref');
net = disconnectLayers(net, 'Decoder-Stage-3-UpReLU', 'encoderDecoderSkipConnectionCrop1/ref');
net = disconnectLayers(net, 'Encoder-Stage-3-DropOut', 'encoderDecoderSkipConnectionCrop3/in');
net = disconnectLayers(net, 'Encoder-Stage-2-ReLU-2', 'encoderDecoderSkipConnectionCrop2/in');
net = disconnectLayers(net, 'Encoder-Stage-1-ReLU-2', 'encoderDecoderSkipConnectionCrop1/in');
net = connectLayers(net, 'Decoder-Stage-1-UpReLU', 'AttentionGate-Stage-1-conv-1');
net = connectLayers(net, 'Decoder-Stage-2-UpReLU', 'AttentionGate-Stage-2-conv-1');
net = connectLayers(net, 'Decoder-Stage-3-UpReLU', 'AttentionGate-Stage-3-conv-1');
net = connectLayers(net, 'Encoder-Stage-3-DropOut', 'AttentionGate-Stage-1-conv-2');
net = connectLayers(net, 'Encoder-Stage-2-ReLU-2', 'AttentionGate-Stage-2-conv-2');
net = connectLayers(net, 'Encoder-Stage-1-ReLU-2', 'AttentionGate-Stage-3-conv-2');
net = connectLayers(net, 'AttentionGate-Stage-1-conv-1', 'encoderDecoderSkipConnectionCrop3/ref');
net = connectLayers(net, 'AttentionGate-Stage-2-conv-1', 'encoderDecoderSkipConnectionCrop2/ref');
net = connectLayers(net, 'AttentionGate-Stage-3-conv-1', 'encoderDecoderSkipConnectionCrop1/ref');
net = connectLayers(net, 'AttentionGate-Stage-1-conv-2', 'encoderDecoderSkipConnectionCrop3/in');
net = connectLayers(net, 'AttentionGate-Stage-2-conv-2', 'encoderDecoderSkipConnectionCrop2/in');
net = connectLayers(net, 'AttentionGate-Stage-3-conv-2', 'encoderDecoderSkipConnectionCrop1/in');
net = disconnectLayers(net, 'encoderDecoderSkipConnectionCrop3', 'encoderDecoderSkipConnectionFeatureMerge3/in1');
net = disconnectLayers(net, 'encoderDecoderSkipConnectionCrop2', 'encoderDecoderSkipConnectionFeatureMerge2/in1');
net = disconnectLayers(net, 'encoderDecoderSkipConnectionCrop1', 'encoderDecoderSkipConnectionFeatureMerge1/in1');
net = connectLayers(net, 'encoderDecoderSkipConnectionCrop3', 'AttentionGate-Stage-1-relu');
net = connectLayers(net, 'encoderDecoderSkipConnectionCrop2', 'AttentionGate-Stage-2-relu');
net = connectLayers(net, 'encoderDecoderSkipConnectionCrop1', 'AttentionGate-Stage-3-relu');
net = connectLayers(net, 'AttentionGate-Stage-1-relu', 'AttentionGate-Stage-1-conv-3');
net = connectLayers(net, 'AttentionGate-Stage-3-relu', 'AttentionGate-Stage-3-conv-3');
net = connectLayers(net, 'AttentionGate-Stage-2-relu', 'AttentionGate-Stage-2-conv-3');
net = connectLayers(net, 'AttentionGate-Stage-1-conv-3', 'AttentionGate-Stage-1-sigmoid');
net = connectLayers(net, 'AttentionGate-Stage-3-conv-3', 'AttentionGate-Stage-3-sigmoid');
net = connectLayers(net, 'AttentionGate-Stage-2-conv-3', 'AttentionGate-Stage-2-sigmoid');
net = connectLayers(net, 'AttentionGate-Stage-1-sigmoid', 'encoderDecoderSkipConnectionFeatureMerge3/in1');
net = connectLayers(net, 'AttentionGate-Stage-2-sigmoid', 'encoderDecoderSkipConnectionFeatureMerge2/in1');
net = connectLayers(net, 'AttentionGate-Stage-3-sigmoid', 'encoderDecoderSkipConnectionFeatureMerge1/in1');
% Training options: Adam on the GPU with verbose output and other key training parameters
options = trainingOptions('adam', ...
    'InitialLearnRate', 1e-4, ...
    'LearnRateSchedule', 'piecewise', ... % learning rate schedule
    'LearnRateDropFactor', 0.5, ... % learning rate drop factor
    'LearnRateDropPeriod', 5, ... % drop the learning rate every 5 epochs
    'L2Regularization', 1e-4, ... % L2 regularization helps prevent overfitting
    'MaxEpochs', 10, ...
    'MiniBatchSize', 4, ...
    'Verbose', true, ...
    'ValidationData', dsVal, ...
    'ValidationFrequency', 5, ...
    'ValidationPatience', 20, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'gpu', ...
    'CheckpointPath', 'X:\MATLAB codes\StatisticsModeling2');
analyzeNetwork(net);
net = initialize(net);
% Read one sample from the validation datastore
[data, info] = read(dsVal);
image = data{1}; % image data
label = data{2}; % ground-truth label
image = double(image);
% Run prediction
predictedLabel = predict(net, image);
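For a network built from unet3d in R2024a, predict is expected to return per-voxel class scores (the softmax output) rather than labels, so the variable predictedLabel above would hold scores. The sketch below assumes the scores for one observation are a numeric 128x128x8x4 array whose channel k corresponds to the k-th class name in the order used above (if predict returns a dlarray, extractdata converts it first). The score array here is a hypothetical stand-in, not real network output.

```matlab
% Sketch: convert per-voxel class scores into a categorical label volume.
% Assumption: scores is a numeric H x W x D x numClasses array for one
% observation, with channel k corresponding to classNames(k).
classNames = ["background", "necrotic_tumor_core", "peritumoral_edema", "enhancing_tumor"];

% Hypothetical stand-in for the output of predict(net, image):
% every voxel favors "background" except one voxel that favors "enhancing_tumor".
scores = zeros(128, 128, 8, numel(classNames));
scores(:,:,:,1) = 1;
scores(1,1,1,4) = 2;

% Index of the highest-scoring class along the channel (4th) dimension
[~, classIdx] = max(scores, [], 4);

% Map class indices 1..4 to the class names
predictedLabel = categorical(classIdx, 1:numel(classNames), cellstr(classNames));
```

The sketch names the class list classNames because the script's variable categories shadows MATLAB's categories function.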

Answers (1)

Alan on 2 Jul 2024 at 10:42
Hi 泽宇,
I did not run the entire code since I do not have your exact dataset. Instead, I created just the network and ran analyzeNetwork, which reported an issue as follows:
I’m not sure whether this issue causes the error during prediction. I also checked the layers shown in your picture, and the layer “Encoder-Stage-4-Add-1” appears to receive inputs with valid formats and correct sizes. The valid input formats for the addition layer are listed here: https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.additionlayer.html#mw_d4293d0e-58e6-40ea-9ab7-a27776196922_head
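One way to check the embedding addition in isolation, separate from the rest of the U-Net, is to build a small dlnetwork containing only the patch embedding, embedding concatenation, position embedding, and addition layers, and to call predict on a formatted dlarray. This is a sketch assuming R2023b or later (for patchEmbeddingLayer, embeddingConcatenationLayer, and positionEmbeddingLayer) and a 32x32x2 Stage-3 feature map with 256 channels; the short layer names are hypothetical, not the ones in the question. If this small network predicts without error, the addition itself is less likely to be the cause, and the data passed to predict is a better place to look.

```matlab
% Sketch: the embedding + position-embedding addition in isolation.
% Assumptions: R2023b or later; Stage-3 feature map of size 32x32x2 with 256 channels.
featInputSize = [32 32 2 256];

layers = [
    image3dInputLayer(featInputSize, Normalization="none", Name="in")
    patchEmbeddingLayer([4 4 2], 1024, Name="patch")
    embeddingConcatenationLayer(Name="cat")
    positionEmbeddingLayer(1024, 1024, Name="pos")
    additionLayer(2, Name="add")];

% The layer array connects "pos" to "add/in1"; the second input is wired manually,
% mirroring Transformer-EmbeddingConcatenation -> Encoder-Stage-4-Add-1/in2.
net = dlnetwork(layers, Initialize=false);
net = connectLayers(net, "cat", "add/in2");
net = initialize(net);

% One observation, passed as a formatted dlarray so the dimensions are explicit
X = dlarray(rand([featInputSize 1], "single"), "SSSCB");
Y = predict(net, X);
disp(size(Y))
disp(dims(Y))
```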
Additionally, I suspect that the input image may not be in the expected format, which would lead to unexpected computations in the network and, in turn, the error in the addition layer. Make sure that the input image matches the input size of the network, [128 128 8 2].
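A quick way to follow this suggestion is to compare the size of the sample read from dsVal with the network input size before calling predict. Because size drops trailing singleton dimensions, the sketch below asks for the first four dimensions explicitly (size with a dimension vector needs R2019b or later). Both sample arrays are hypothetical stand-ins; the second represents a sample that lost its channel dimension.

```matlab
% Sketch: check that a sample matches the network input size before predict.
inputSize = [128 128 8 2];                 % [height width depth channels], from the question

% Hypothetical stand-ins for data{1} read from dsVal
goodSample = rand(128, 128, 8, 2);
badSample  = rand(128, 128, 8);            % e.g. only one modality was read

% size(A, 1:4) always returns four values, padding with 1s if needed
matchesInput = @(sample) isequal(size(sample, 1:4), inputSize);

goodOk = matchesInput(goodSample);         % true
badOk  = matchesInput(badSample);          % false: size(badSample, 1:4) is [128 128 8 1]
```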



Release

R2024a
