juntang-zhuang/Adabelief-Optimizer

Matlab implementation

pcwhy opened this issue · 8 comments

pcwhy commented

How about a Matlab test case?

I implemented a Matlab version of AdaBelief and compared it with SGD with momentum at
https://github.com/pcwhy/AdaBelief-Matlab
I found that AdaBelief sometimes does not converge to as good a solution as SGD with momentum reaches.

Hi, thanks for the feedback. I did not implement a Matlab version. Looking quickly through your repo, I found several possible differences from the PyTorch version:

  1. Decoupled weight decay does not seem to be enabled in your implementation; this is typically important for the Adam family.
  2. The way you use eps differs from the algorithm. If I read it correctly, eps is also set to 1e-3 in your code, while it is 1e-8 in Adam and 1e-16 in AdaBelief. Note that eps appears twice in AdaBelief, both inside and outside the sqrt.
  3. You disabled bias correction, while it is enabled in our algorithm. The short version of the algorithm omits it only for display purposes; the full algorithm is now in the updated version and the appendix.
  4. Rectification is not enabled, but this may not always help, so it is probably fine not to implement it for now.

Hopefully you will see an improvement after these modifications.
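The first three points can be sketched on plain arrays. This is only a minimal sketch under stated assumptions, not the reference implementation: the toy quadratic objective, the variable names, and the hyperparameter values (including `lambda` for the weight decay strength) are chosen purely for illustration. It shows eps entering twice, bias correction on both moments, and weight decay applied directly to the parameters.

```matlab
% Sketch of AdaBelief updates on a toy problem: minimize 0.5*sum(theta.^2).
% All names and values below are illustrative assumptions.
alpha  = 1e-3;    % learning rate
beta1  = 0.9;     % decay rate of the first moment
beta2  = 0.999;   % decay rate of the second moment
epsVal = 1e-16;   % eps value quoted for AdaBelief in this thread
lambda = 1e-2;    % decoupled weight decay strength (assumed)

theta  = [1; -2];             % parameters
theta0 = theta;               % kept only to compare norms afterwards
m = zeros(size(theta));       % moving average of the gradients
s = zeros(size(theta));       % moving average of (g - m).^2

for t = 1:200
    g = theta;                                  % gradient of the toy objective
    theta = theta .* (1 - alpha*lambda);        % decoupled weight decay
    m = beta1.*m + (1 - beta1).*g;
    s = beta2.*s + (1 - beta2).*(g - m).^2 + epsVal;   % eps inside the sqrt
    mHat = m ./ (1 - beta1^t);                  % bias correction
    sHat = s ./ (1 - beta2^t);
    theta = theta - alpha .* mHat ./ (sqrt(sHat) + epsVal);  % eps outside the sqrt
end
fprintf('norm before: %.4f, norm after: %.4f\n', norm(theta0), norm(theta));
```

Rectification (point 4) is left out here on purpose, matching the suggestion that it can be skipped at first.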

pcwhy commented

Thanks for the reply. The 0.001 is actually the learning rate alpha. And e was originally set to 1e-8 and has now been set to 1e-16.

I then ran the following experiments:

  1. I implemented decoupled weight decay, but it did not work any black magic. For comparison, I also disabled L2 regularization in SGD; I think it is fair to let the two models run freely.
  2. I implemented bias correction. It does help in regular CNNs and DNNs, but it does not work in my own DNN with a zero-bias (Zb) dense layer. Specifically, Zb is a differentiable template-matching layer that uses cosine similarity to increase model explainability.
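To make the distinction in item 1 concrete, here is a small sketch contrasting L2 regularization (the penalty is folded into the gradient, so it is rescaled by the adaptive denominator) with decoupled weight decay (the shrinkage is applied directly to the parameters). All numbers, including the per-parameter denominators `denom`, are made-up assumptions. For brevity, the sketch ignores that an L2 term would also feed into the moment estimates.

```matlab
% One update step under two regularization schemes (illustrative values only).
alpha  = 1e-3;            % learning rate
lambda = 1e-2;            % regularization strength (assumed)
theta  = [0.5; -1.5];     % parameters
g      = [0.2; -0.1];     % raw gradient of the loss
denom  = [0.05; 2.0];     % hypothetical adaptive denominators, sqrt(sHat) + eps

% L2 regularization: the penalty gradient is divided by denom as well,
% so parameters with a small denominator are decayed much more strongly.
thetaL2 = theta - alpha .* (g + lambda.*theta) ./ denom;

% Decoupled weight decay: every parameter shrinks by the same factor,
% independent of the adaptive scaling.
thetaDecoupled = theta .* (1 - alpha*lambda) - alpha .* g ./ denom;

% Relative shrink applied to each parameter by the regularizer alone.
shrinkL2        = alpha .* lambda ./ denom;               % varies per parameter
shrinkDecoupled = alpha .* lambda .* ones(size(theta));   % uniform
disp([shrinkL2, shrinkDecoupled]);
```

With plain SGD the denominator is 1 for every parameter and the two updates coincide, so the distinction only matters for adaptive methods such as Adam and AdaBelief.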

Thanks again for the reply. I have updated my repository, and hopefully it will be helpful to those who work in Matlab.

@pcwhy Thanks for the update. I quickly ran your code, and it seems the code does not use any real data. I clicked the link to the data and preprocessing code, but I'm confused by the code. Could you provide the full code to reproduce the figure you posted, so I can look into it?

pcwhy commented

I updated my repository (https://github.com/pcwhy/AdaBelief-Matlab). The entry point for the experiment is AdaBeliefOnRealData.m. I placed several checkpoints in the code so that readers do not get lost. This script is meant to serve as a minimal example. I manually altered the configuration in Checkpoints 1 to 3 to get the curves in the graph. If necessary, I can write a single script that produces all the curves in one run.

Thanks again.

Thank you so much! I'll take a look and try your code.

@pcwhy Please see the code below. Sorry, I don't have enough time to tune it carefully, and I'm not very familiar with Matlab. I only tuned the momentum a little, and I also found that your implementation differs from ours. Many thanks for your MATLAB implementation; if possible, we still recommend our implementations, since there are many details that require familiarity with both languages. Using the following code and running for 3 epochs, I got around 0.93 accuracy, which should be higher than before and could be better with more careful tuning; visually it is close to SGD in your figure. As for SGDM, I can't run your code; there seems to be a coding error, and I did not debug it.

clc;
close all;
clear;
rng default;

%%----------- Check Point 0:  

% Please download the dataset from
% https://drive.google.com/uc?export=download&id=1N-eImoAA3QFPu3cBJd-1WIUH0cqw2RoT
% Or
% https://ieee-dataport.org/open-access/24-hour-signal-recording-dataset-labels-cybersecurity-and-iot
load('adsb_records_qt.mat');
%%%%---------- End of Check Point 0

payloadMatrix = reshape(payloadMatrix', ...
    length(payloadMatrix)/length(msgIdLst), length(msgIdLst))';
rawIMatrix = reshape(rawIMatrix', ...
    length(rawIMatrix)/length(msgIdLst), length(msgIdLst))';
rawQMatrix = reshape(rawQMatrix', ...
    length(rawQMatrix)/length(msgIdLst), length(msgIdLst))';
rawCompMatrix = rawIMatrix + rawQMatrix.*1j;
if size(rawCompMatrix,2) < 1024
    appendingBits = (ceil(sqrt(size(rawCompMatrix,2))))^2 - size(rawCompMatrix,2);
    rawCompMatrix = [rawCompMatrix, zeros(size(rawCompMatrix,1), appendingBits)];
else
   rawCompMatrix = rawCompMatrix(:,1:1024); 
end
dateTimeLst = datetime(uint64(timeStampLst),'ConvertFrom','posixtime','TimeZone','America/New_York','TicksPerSecond',1e3,'Format','dd-MMM-yyyy HH:mm:ss.SSS');

uIcao = unique(icaoLst);
c = countmember(uIcao,icaoLst);
icaoOccurTb = [uIcao,c];
icaoOccurTb = sortrows(icaoOccurTb,2,'descend');
cond1 = icaoOccurTb(:,2)>=300;
cond2 = icaoOccurTb(:,2)<=5000;
cond = logical(cond1.*cond2);
selectedPlanes = icaoOccurTb(cond,:);
unknowPlanes = icaoOccurTb(~cond,:);
allTrainData = [icaoLst, abs(rawCompMatrix)];

selectedBasebandData = [];
selectedRawCompData = [];

unknownBasebandData = [];
unknownRawCompData = [];

minTrainChance = 100;
maxTrainChance = 500;

for i = 1:size(selectedPlanes,1)
    selection = allTrainData(:,1)==selectedPlanes(i,1);
    localBaseband = allTrainData(selection,:);
    localComplex = rawCompMatrix(selection,:);
    localAngles = (angle(localComplex(:,:))); 
    
    if size(localBaseband,1) < minTrainChance
        continue;
    elseif size(localBaseband,1) >= maxTrainChance
        rndSeq = randperm(size(localBaseband,1));
        rndSeq = rndSeq(1:maxTrainChance);
        localBaseband = localBaseband(rndSeq,:);
        localComplex = localComplex(rndSeq,:);
    else
        %Nothing to do
    end
    selectedBasebandData = [selectedBasebandData; localBaseband];
    selectedRawCompData = [selectedRawCompData; localComplex];    
end

for i = 1:size(unknowPlanes,1)
    selection = allTrainData(:,1)==unknowPlanes(i,1);
    localBaseband = allTrainData(selection,:);
    localComplex = rawCompMatrix(selection,:);
    localAngles = (angle(localComplex(:,:)));
    unknownBasebandData = [unknownBasebandData; localBaseband];
    unknownRawCompData = [unknownRawCompData; localComplex];    
end

randSeries = randperm(size(selectedBasebandData,1));
selectedBasebandData = selectedBasebandData(randSeries,:);
selectedRawCompData = selectedRawCompData(randSeries,:);

randSeries = randperm(size(unknownBasebandData,1));
unknownBasebandData = unknownBasebandData(randSeries,:);
unknownRawCompData = unknownRawCompData(randSeries,:);

% 
% for i = 1:size(selectedBasebandData,1)
%     mu = mean(selectedBasebandData(i,2:end));
%     sigma = std(selectedBasebandData(i,2:end));
%     
%     % Use probabilistic filtering to restore the rational signal
%     mask = selectedBasebandData(i,2:end) <= mu + 3*sigma;
%     cleanInput = selectedBasebandData(i,2:end).*mask;
%     selectedBasebandData(i,2:end) = cleanInput;
% end
% for i = 1:size(selectedRawCompData,1)
%     muReal = mean(real(selectedRawCompData(i,1:end)));
%     sigmaReal = std(real(selectedRawCompData(i,1:end)));
%     muImag = mean(imag(selectedRawCompData(i,1:end)));
%     sigmaImag = std(imag(selectedRawCompData(i,1:end)));
%     
%     % Use probabilistic filtering to restore the rational signal
%     maskReal = abs(real(selectedRawCompData(i,1:end))) <= muReal + 3*sigmaReal;
%     cleanInputReal = real(selectedRawCompData(i,1:end)).*maskReal;
%     
%     maskImag = abs(imag(selectedRawCompData(i,1:end))) <= muImag + 3*sigmaImag;
%     cleanInputImag = imag(selectedRawCompData(i,1:end)).*maskImag;
%     
%     selectedRawCompData(i,1:end) = cleanInputReal+cleanInputImag.*1j;    
% end


[X,cX,Y,cY] = makeDataTensor(selectedBasebandData,selectedRawCompData);
[uX,cuX,uY,cuY] = makeDataTensor(unknownBasebandData,unknownRawCompData);

lookUpTab = [unique([Y;cY]),[1:length(unique([Y;cY]))]'];
Y2 = Y;
for i = 1:numel(Y)
    Y2(i) = lookUpTab(find(lookUpTab(:,1) == Y(i)),2);
end
cY2 = cY;
for i = 1:numel(cY)
    cY2(i) = lookUpTab(find(lookUpTab(:,1) == cY(i)),2);
end

Y = Y2;
cY = cY2;

inputSize = [size(X,1) size(X,2) size(X,3)];
numClasses = size(unique(selectedBasebandData(:,1)),1);

layers = [
    imageInputLayer(inputSize, 'Name', 'input', 'Mean', 0)
    convolution2dLayer(5,10, 'Name', 'conv2d_1')
    batchNormalizationLayer('Name', 'batchNorm_1')
    reluLayer('Name', 'relu_1')
    convolution2dLayer(3, 10, 'Padding', 1, 'Name', 'conv2d_2')
    reluLayer('Name', 'relu_2')    
    convolution2dLayer(3, 10, 'Padding', 1, 'Name', 'conv2d_3')
    reluLayer('Name', 'relu_3')    
    depthConcatenationLayer(2,'Name','add_1')    
    fullyConnectedLayer(numClasses, 'Name', 'fc_bf_fp') % 11th
    
%%----------- Check Point 1:  
%% Here you can specify whether to use regular dense layer 
%% or zero-bias dense layer
     fullyConnectedLayer(numClasses, 'Name', 'Fingerprints') 
    %zeroBiasFCLayer(numClasses,numClasses,'Fingerprints',[])
%%----------- End of Check Point 1:         
    yxSoftmax('softmax_1')
    classificationLayer('Name', 'classify_1')
    ];

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, 'relu_1', 'add_1/in2');
plot(lgraph);

XTrain = X;
YTrain = Y;
numEpochs = 3;
miniBatchSize = 128;
plots = "training-progress";
executionEnvironment = "auto";
if plots == "training-progress"
    figure(10);
    lineLossTrain = animatedline('Color','#0072BD','lineWidth',1.5);
    lineClassificationLoss = animatedline('Color','#EDB120','lineWidth',1.5);
      
    ylim([-inf inf])
    xlabel("Iteration")
    ylabel("Loss")
    legend('Loss','classificationLoss');
    grid on;
    
    figure(11);  
    lineCVAccuracy = animatedline('Color','#D95319','lineWidth',1.5);
    ylim([0 1.1])
    xlabel("Iteration")
    ylabel("Accuracy")
    legend('CV Acc.');
    grid on;    
end
L2RegularizationFactor = 0.0;
initialLearnRate = 0.001;
decay = 0.01;
momentumSGD = 0.9;
velocities = [];
learnRates = [];
momentums = [];
gradientMasks = [];
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
iteration = 0;
start = tic;
classes = categorical(YTrain);
lgraph2 = lgraph; % No old weights
lgraph2 = lgraph2.removeLayers('classify_1');
dlnet = dlnetwork(lgraph2);

% Loop over epochs.
totalIters = 0;
for epoch = 1:numEpochs
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx); 
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        totalIters = totalIters + 1;
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        Xb = XTrain(:,:,:,idx);
        Yb = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Yb(c,YTrain(idx)==(c)) = 1;
        end
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(Xb),'SSCB');
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss,classificationLoss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb);
%         [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb);        
        dlnet.State = state;
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        % Update the network parameters using the SGDM optimizer.
        %[dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
        % Update the network parameters using the SGD optimizer.
        %dlnet = dlupdate(@sgdFunction,dlnet,gradients);
        if isempty(velocities)
            velocities = packScalar(gradients, 0);
            learnRates = packScalar(gradients, learnRate);
%             momentumSGDs = packScalar(gradients, momentumSGD);
            momentums = packScalar(gradients, 0);
            L2Foctors = packScalar(gradients, L2RegularizationFactor);            
            wd = packScalar(gradients, L2RegularizationFactor);  
            gradientMasks = packScalar(gradients, 1);   
%             % Let's lock some weights
%             for k = 1:2
%                 gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k})));
%             end
        end
%%%%----------- Check Point 2:  
%%%% Here you can specify which optimizer to use, 
%         [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ...
%             dlnet, gradients, velocities, ...
%             learnRates, momentumSGDs, L2Foctors, gradientMasks); % This is
%             % the famous SGD with momentum
        totalIterInPackage = packScalar(gradients, totalIters); % We have to make this...
                                                         % stupid data
                                                         % structure but it
                                                         % only contains
                                                         % the number of
                                                         % iterations
        [dlnet, velocities, momentums] = dlupdate(@adamFunction, ...
                    dlnet, gradients, velocities, ...
                    learnRates, momentums, wd, gradientMasks, ...
                    totalIterInPackage);        
%         [dlnet] = dlupdate(@sgdFunction, ...
%             dlnet, gradients); % the vanilla
%%%%-----------End of Check Point 2 

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            XTest = cX;
            YTest = categorical(cY);
            if mod(iteration,5) == 0 
                accuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment,0);
                addpoints(lineCVAccuracy,iteration, accuracy);
            end
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            addpoints(lineClassificationLoss,iteration,double(gather(extractdata(classificationLoss))));
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
accuracy = cvAccuracy(dlnet, cX, categorical(cY), miniBatchSize, executionEnvironment, 1)


function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment, confusionChartFlg)
    dlXTest = dlarray(XTest,'SSCB');
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlXTest = gpuArray(dlXTest);
    end
    dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize);
    [~,idx] = max(extractdata(dlYPred),[],1);
    YPred = categorical(idx);
    accuracy = mean(YPred(:) == YTest(:));
    if confusionChartFlg == 1
        figure
        confusionchart(YPred(:),YTest(:));
    end
end

function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize)
    numObservations = size(dlX,4);
    numIterations = ceil(numObservations / miniBatchSize);
    numClasses = size(dlnet.Layers(end-1).Weights,1);
    dlYPred = zeros(numClasses,numObservations,'like',dlX);
    for i = 1:numIterations
        idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
        dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
    end
end


function [gradients,state,loss,classificationLoss] = modelGradientsOnWeights(dlnet,dlX,Y)
%   %This is only used with softmax of matlab which only applies softmax
%   on 'C' and 'B' channels.
    [rawPredictions,state] = forward(dlnet,dlX,'Outputs', 'Fingerprints');
    dlYPred = softmax(dlarray(squeeze(rawPredictions),'CB'));
%     [dlYPred,state] = forward(dlnet,dlX);
    penalty = 0;
    scalarL2Factor = 0;
    if scalarL2Factor ~= 0
        paramLst = dlnet.Learnables.Value;
        for i = 1:size(paramLst,1)
            penalty = penalty + sum((paramLst{i}(:)).^2);
        end
    end
    
    classificationLoss = crossentropy(squeeze(dlYPred),Y) + scalarL2Factor*penalty;

    loss = classificationLoss;
%     loss = classificationLoss + 0.2*(max(max(rawPredictions))-min(max(rawPredictions)));
    gradients = dlgradient(loss,dlnet.Learnables);
    %gradients = dlgradient(loss,dlnet.Learnables(4,:));
end

function [params,velocityUpdates,momentumUpdate] = adamFunction(params, rawParamGradients,...
    velocities, learnRates, momentums, L2Foctors, gradientMasks, iters)
    % https://arxiv.org/pdf/2010.07468.pdf %%AdaBelief
    % https://arxiv.org/pdf/1711.05101.pdf  %%DeCoupled Weight Decay 
    b1 = 0.5; 
    b2 = 0.999;
    e = 1e-8;
    curIter = iters(:);
    curIter = curIter(1);
    
    gt = rawParamGradients;
    mt = (momentums.*b1 + ((1-b1)).*gt);
    vt = (velocities.*b2 + ((1-b2)).*((gt-mt).^2)) + e;

    momentumUpdate = mt;
    velocityUpdates = vt;
    h_mt = mt./(1-b1.^curIter);
    h_vt = vt./(1-b2.^curIter);
%%%%----------- Check Point 3:  
%%%% Here you can specify whether to use bias correction, 
%%%% or zero-bias dense layer 
%%%% in this test, we can just try to eliminate the effect of varying learning
%%%% rates
%    params = params - 0.001.*(mt./(sqrt(vt)+e)).*gradientMasks...
%        - wd.*params.*gradientMasks; %This works better for zero-bias dense layer
     params = params - learnRates.*(h_mt./(sqrt(h_vt)+e)).*gradientMasks...
         -2*learnRates .* L2Foctors.*params.*gradientMasks;
%%%%
%%%%-----------End of Check Point 3 
end

function param = sgdFunction(param,paramGradient)
    learnRate = 0.01;
    param = param - learnRate.*paramGradient;
end

function [params, velocityUpdates] = sgdmFunction(params, paramGradients,...
    velocities, learnRates, momentums)
% https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
%     velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
    velocityUpdates = momentums.*velocities+0.001.*paramGradients;
    params = params - velocityUpdates;
end

function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,...
    velocities, learnRates, momentums, L2Foctors, gradientMasks)
% https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
% https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261
    paramGradients = rawParamGradients + 2*L2Foctors.*params;
    velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
    params = params - (velocityUpdates).*gradientMasks;
end

function tabVars = packScalar(target, scalar)
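% dlupdate passes matching table entries to the update function, so a
% scalar must be expanded into a table shaped like `target` before it can
% be handed over together with the gradients.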
    tabVars = target;
    for row = 1:size(tabVars(:,3),1)
        tabVars{row,3} = {...
            dlarray(...
            ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar...
            )...
            };
    end
end

pcwhy commented

Thanks for the rapid reply. SGDM does not run because of my typo at:

% momentumSGDs = packScalar(gradients, momentumSGD);

Just uncomment that line and SGDM can run.

I realize that different initial learning rates (0.01 for SGDM, while the Matlab AdaBelief favors 0.001... LoL) and other hyperparameters are required to get optimal performance. That might be the real hardship...

I have updated my repository with the newest test results, including the code above. Hopefully it will be helpful.

Thanks for the update. Hyperparameters are definitely important, and the learning rate schedule also matters, but I don't have sufficient time to tune them one by one. BTW, since using SGDM or a smaller beta1 helps, I suspect the gradient variance is large, and it might be helpful to apply some rectification.
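For reference, one way to read the rectification suggestion is the variance rectification term from RAdam. The thread does not spell out which variant is meant, so treating it as RAdam-style rectification is an assumption, as are the variable names and the threshold of 4 taken from the RAdam paper. The sketch below only computes the rectification factor for a few step counts; it does not modify the training script above.

```matlab
% RAdam-style rectification factor as a function of the step count t.
% Names and the list of step counts are illustrative assumptions.
beta2  = 0.999;
rhoInf = 2/(1 - beta2) - 1;          % maximum length of the approximated SMA

steps = [1 5 10 100 1000 10000];
rho   = zeros(size(steps));
r     = zeros(size(steps));          % left at 0 while rho <= 4
for k = 1:numel(steps)
    t = steps(k);
    rho(k) = rhoInf - 2*t*beta2^t/(1 - beta2^t);
    if rho(k) > 4
        % Variance is considered tractable: the adaptive step is scaled by r.
        r(k) = sqrt(((rho(k)-4)*(rho(k)-2)*rhoInf) / ...
                    ((rhoInf-4)*(rhoInf-2)*rho(k)));
    end
end
disp([steps; rho; r]);
```

While `rho` is at most 4 (the very first steps), RAdam falls back to an update without the adaptive denominator; afterwards the factor `r` grows towards 1, so the adaptive step is damped most strongly early in training, when the second-moment estimate is least reliable.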