ArcFace loss, SphereFace loss

Hi
Is there any implementation of the ArcFace loss or SphereFace loss in MATLAB Deep Learning Toolbox?
Best Regards
Sameed

Answers (2)

Hi,
Currently there is no built-in support for ArcFace or SphereFace loss in MATLAB.
xingxingcui on 21 Nov 2020
Edited: xingxingcui on 13 Jan 2021

Hi Cui,
I developed my own ArcFace layer, which works well for an image retrieval task. I used a pretrained ResNet-101 and removed its FC and loss layers. Then I added my own FC layer and ArcFace loss layer to lgraph using the following script, where num_classes is the number of classes in the training data:
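num_classes = 20000;                 % number of classes in the training data
in  = 2048;                          % length of the pooled ResNet-101 feature vector
out = num_classes;
sc = sqrt(3/(1*1*in));               % uniform init bound: U(-sc, sc) has variance 1/in
weights = (rand(out, in, 'single')*2 - 1)*sc;
CFC = cosineFC('CFC', weights);
CFC = setLearnRateFactor(CFC, 'Weights', 10);   % train the new head 10x faster
factor1 = getLearnRateFactor(CFC, 'Weights');   % check: returns 10
lgraph = addLayers(lgraph, CFC);
lgraph = connectLayers(lgraph, 'gap', 'CFC');   % 'gap' is the global average pooling layer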
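The bound sc = sqrt(3/(1*1*in)) follows from the variance of a uniform distribution: U(-a, a) has variance a^2/3, so a = sqrt(3/in) gives each weight a variance of 1/in. The short script below is a sketch, not part of the original post, that checks this numerically; out = 1000 is an illustrative value chosen smaller than the 20000 classes above so it runs quickly.

```matlab
% Sanity check of the uniform initialization used for the cosineFC weights.
% out = 1000 is illustrative (smaller than the 20000 classes in the post).
rng(0);                                        % reproducible random numbers
in  = 2048;
out = 1000;
sc  = sqrt(3/(1*1*in));                        % bound a with a^2/3 = 1/in
weights = (rand(out, in, 'single')*2 - 1)*sc;  % uniform on (-sc, sc)

maxAbs     = max(abs(weights(:)));             % should not exceed the bound
empiricalV = var(double(weights(:)));          % sample variance over all entries
targetV    = 1/in;                             % variance the bound is chosen for
fprintf('max |w| = %.5f (bound %.5f)\n', maxAbs, sc);
fprintf('var(w)  = %.3e (target %.3e)\n', empiricalV, targetV);
```

With about two million samples the empirical variance lands very close to 1/in.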
The cosineFC layer computes the dot product between L2-normalized weights and L2-normalized features, that is, the cosine similarity between each feature vector and each class weight vector:
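classdef cosineFC < nnet.layer.Layer
    % Custom fully connected layer that outputs the cosine similarity
    % between the input features and each class weight vector.
    properties (Learnable)
        % Layer learnable parameters
        Weights   % num_classes-by-in matrix, one class weight vector per row
    end
    methods
        function layer = cosineFC(name, Weights)
            % Set layer name and description.
            layer.Name = name;
            layer.Description = "cosineFC";
            % Initialize the class weight matrix.
            layer.Weights = Weights;
        end
        function Z = predict(layer, X)
            % Z = predict(layer, X) forwards the input data X through the
            % layer and outputs the cosine similarity to each class.
            X = L2_Norm(X, 3);                 % normalize features along the channel dimension
            W = L2_Norm(layer.Weights, 2);     % normalize each class weight vector (row)
            bias = zeros(size(W, 1), 1);
            Z = fullyconnect(X, W, bias, 'DataFormat', 'SSCB');
            Z = permute(Z, [3 4 1 2]);         % C-by-B back to 1-by-1-by-C-by-B (SSCB)
        end
    end
end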
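The thread does not show the L2_Norm helper that cosineFC calls. A plausible definition, assumed here and not taken from the original post, divides each slice by its L2 norm along the given dimension, with a small constant to avoid division by zero. The sketch below writes it as an anonymous function so it runs as a single script (for use inside the class it would need to be a function file, L2_Norm.m, on the path) and checks that the normalized dot product equals the cosine of the angle between a feature vector and each class weight.

```matlab
% Hypothetical L2_Norm: scale every slice of X along dimension dim to unit L2 norm.
% The small constant (1e-12) guards against division by zero for all-zero slices.
L2_Norm = @(X, dim) X ./ sqrt(sum(X.^2, dim) + 1e-12);

rng(1);
W = randn(5, 8);              % 5 class weight vectors (rows), 8 features each
x = randn(8, 1);              % one feature vector

Wn = L2_Norm(W, 2);           % normalize each row, as cosineFC does with dim = 2
xn = L2_Norm(x, 1);           % normalize the feature vector
cosines = Wn * xn;            % what the fully connected step computes

% Reference: cosine of the angle between x and each row of W, computed directly
ref = zeros(5, 1);
for k = 1:5
    ref(k) = (W(k, :) * x) / (norm(W(k, :)) * norm(x));
end
disp([cosines ref]);
```

Because both operands have unit length, every output lies in [-1, 1], which is what the ArcFace margin arithmetic expects.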
Finally, I modified a custom regression output layer to optimize the ArcFace loss:
classdef arcloss < nnet.layer.RegressionLayer
    % Custom output layer implementing the additive angular margin
    % (ArcFace) loss on the cosine similarities produced by cosineFC.
    methods
        function layer = arcloss(name)
            % layer = arcloss(name) creates an ArcFace loss layer
            % and specifies the layer name.
            layer.Name = name;
            % Set layer description.
            layer.Description = 'arcface loss';
        end
        function loss = forwardLoss(layer, cosine, Y)
            % cosine: 1-by-1-by-C-by-B cosine similarities; Y: one-hot targets
            cosine = dlarray(squeeze(cosine), 'CB');
            Y = squeeze(Y);
            m = 0.3;                   % additive angular margin (radians)
            s = 30;                    % logit scale
            cos_m = cos(m);
            sin_m = sin(m);
            th = cos(pi - m);          % cosine < th means theta + m > pi
            mm = sin(pi - m) * m;      % penalty used in that case
            cs = (1 - cosine.^2);
            cs_limit = max(min(cs, 1), 0);
            sine = sqrt(cs_limit);
            phi = cosine * cos_m - sine * sin_m;   % cos(theta + m)
            cos_th = cosine - mm;
            idx = cosine < th;
            phi(idx) = cos_th(idx);
            % Apply the margin only to the target-class logit, then scale.
            output = (Y .* phi) + ((1 - Y) .* cosine);
            output = output * s;
            dlYPred = softmax(output);
            loss = crossentropy(dlYPred, Y);
        end
    end
end
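To see what forwardLoss does to the logits, the sketch below repeats the same margin arithmetic on plain MATLAB arrays. It uses a hand-written softmax and cross-entropy in place of the toolbox softmax/crossentropy calls (an assumption made so it runs without Deep Learning Toolbox), and the cosine values are made up for illustration. The second sample's target cosine is chosen below th so the fallback branch is exercised.

```matlab
% Toy batch: 3 classes (rows) x 2 samples (columns), values chosen for illustration.
cosine = [0.80 -0.97;
          0.10  0.20;
         -0.30  0.30];
Y = [1 1;                  % one-hot targets: class 1 for both samples
     0 0;
     0 0];

m = 0.3;  s = 30;          % same margin and scale as arcloss
th = cos(pi - m);
mm = sin(pi - m) * m;
sine = sqrt(max(min(1 - cosine.^2, 1), 0));
phi = cosine * cos(m) - sine * sin(m);   % cos(theta + m)
cos_th = cosine - mm;
idx = cosine < th;                       % theta + m would exceed pi
phi(idx) = cos_th(idx);

% Column-wise softmax (max subtracted for numerical stability) and mean
% cross-entropy over the batch, taken at the target entries only.
softmaxCols = @(Z) exp(Z - max(Z, [], 1)) ./ sum(exp(Z - max(Z, [], 1)), 1);
xent = @(P, T) -sum(log(P(T == 1))) / size(T, 2);

logitsArc   = s * (Y .* phi + (1 - Y) .* cosine);
logitsPlain = s * cosine;                % no margin, for comparison
lossArc   = xent(softmaxCols(logitsArc), Y);
lossPlain = xent(softmaxCols(logitsPlain), Y);
fprintf('loss with margin: %.4f, without: %.4f\n', lossArc, lossPlain);
```

Because the margin only lowers the target-class logit, the loss with margin is always larger than the plain scaled-cosine loss; minimizing it pushes each feature to be closer in angle to its own class weight by at least m.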
It works very well on the Google Landmarks v2 retrieval dataset.
How do you format the labels passed to the "trainNetwork" function? A regression network requires numeric response values, so I don't know how to supply the class labels. Could you post your code here? Thank you!
Hi Cui,
Yes, we have to create our own datastore for the regression network. Here is sample code:
% delete old parallel pool
poolobj = gcp('nocreate');
delete(poolobj);
%% convert net to layer graph
lgraph = layerGraph(net); clear net
%% Build a datastore with one-hot responses for the regression setup
imds = imageDatastore('landmarks', 'LabelSource', 'foldernames', 'IncludeSubfolders', true);
Predictor = imds.Files;
lb = single(imds.Labels);               % class index of each image
num_classes = size(unique(lb), 1);
Y = zeros(length(lb), num_classes, 'single');
for i = 1:length(lb)
    Y(i, lb(i)) = 1;                     % one-hot row for image i
end
responseName = Y;
T = table(Predictor, responseName);      % file names plus one-hot response matrix
auimds = augmentedImageDatastore([224 224 3], T, 'responseName');
%% train network
miniBatchSize = 32;
options = trainingOptions('sgdm', ...
'Momentum',0.9, ...
'InitialLearnRate',1e-3, ...
'MaxEpochs',5, ...
'MiniBatchSize',miniBatchSize, ...
'ExecutionEnvironment','multi-gpu', ...
'CheckpointPath','/scratch/');
net = trainNetwork(auimds,lgraph,options);
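The loop that builds Y writes a single 1 per row, at the column given by the image's class index. The sketch below uses made-up label indices rather than a real datastore and shows an equivalent vectorized form based on implicit expansion, checking that it matches the loop. Like the loop, it assumes the class indices run from 1 to num_classes with no gaps.

```matlab
% Illustrative class indices for 6 images (stand-ins for single(imds.Labels)).
lb = single([3; 1; 2; 3; 1; 4]);
num_classes = numel(unique(lb));

% Loop version, as in the training script.
Yloop = zeros(length(lb), num_classes, 'single');
for i = 1:length(lb)
    Yloop(i, lb(i)) = 1;
end

% Vectorized version: compare each label with 1..num_classes (implicit expansion).
Yvec = single(lb(:) == (1:num_classes));

disp(Yvec);
```

For large datasets the vectorized form avoids the per-image loop; both produce the same single-precision one-hot matrix.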
@Syed Sameed Husain
Thank you very much for your useful code!
In your cosineFC layer, what is the L2_Norm function? May I have the code? Thank you.


Release

R2020a