how to make cross attention use attentionlayer?
Show older comments
I want to replace the dual-branch merge section of the model in the following link with cross-attention for fusion, but it's not successful. Is my operation incorrect? I have written an example, but I still don't understand how to embed it into the model in the link.
net one:(failure, loss dont down)
initialLayers = [
sequenceInputLayer(1, "MinLength", numSamples, "Name", "input", "Normalization", "zscore", "SplitComplexInputs", true)
convolution1dLayer(7, 2, "stride", 1)
];
stftBranchLayers = [
stftLayer("TransformMode", "squaremag", "Window", hann(64), "OverlapLength", 52, "Name", "stft", "FFTLength", 256, "WeightLearnRateFactor", 0 )
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="stft_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "stft_conv_1")
layerNormalizationLayer("Name", "stft_layernorm_1")
reluLayer("Name", "stft_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "stft_conv_2")
layerNormalizationLayer("Name", "stft_layernorm_2")
reluLayer("Name", "stft_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "stft_conv_3")
layerNormalizationLayer("Name", "stft_layernorm_3")
reluLayer("Name", "stft_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_3")
flattenLayer("Name", "stft_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_stft")
];
cwtBranchLayers = [
cwtLayer("SignalLength", numSamples, "TransformMode", "squaremag", "Name","cwt", "WeightLearnRateFactor", 0);
functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="cwt_reformat")
convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "cwt_conv_1")
layerNormalizationLayer("Name", "cwt_layernorm_1")
reluLayer("Name", "cwt_relu_1")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_1")
convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "cwt_conv_2")
layerNormalizationLayer("Name", "cwt_layernorm_2")
reluLayer("Name", "cwt_relu_2")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_2")
convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "cwt_conv_3")
layerNormalizationLayer("Name", "cwt_layernorm_3")
reluLayer("Name", "cwt_relu_3")
maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_3")
flattenLayer("Name", "cwt_flatten")
dropoutLayer(0.5)
fullyConnectedLayer(100,"Name","fc_cwt")
];
finalLayers = [
attentionLayer(4,"Name","attention")
layerNormalizationLayer("Name","layernorm")
fullyConnectedLayer(48,"Name","fc_1")
fullyConnectedLayer(numel(waveformClasses),"Name","fc_2")
softmaxLayer("Name","softmax")
];
dlLayers2 = dlnetwork(initialLayers);
dlLayers2 = addLayers(dlLayers2, stftBranchLayers);
dlLayers2 = addLayers(dlLayers2, cwtBranchLayers);
dlLayers2 = addLayers(dlLayers2, finalLayers);
dlLayers2 = connectLayers(dlLayers2, "conv1d", "stft");
dlLayers2 = connectLayers(dlLayers2, "conv1d", "cwt");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/key");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/value");
dlLayers2 = connectLayers(dlLayers2,"fc_cwt","attention/query");
my example:(is it right ?)
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;
X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);
Y = rand(numChannels,numObservations,numTimeSteps);
Y = dlarray(Y);
numHeads = 8;
outputSize = numChannels*numHeads;
WQ = rand(outputSize, numChannels, 1, 1);
WK = rand(outputSize, numChannels, 1, 1);
WV = rand(outputSize, numChannels, 1, 1);
WO = rand(outputSize, outputSize, 1, 1);
Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
queries = WQ * X;
keys = WK * Y;
values = WV * Y;
A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
Z = WO * A;
end
Accepted Answer
More Answers (0)
Categories
Find more on Built-In Layers in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!