juntang-zhuang/Adabelief-Optimizer

Inconsistent computation of weight_decay and grad_residual among pytorch versions

sjscotti opened this issue · 5 comments

Hi
I was looking at the various versions you have in the pypi_packages folder and noticed that the order in which weight decay (which, for some options, modifies grad) and grad_residual (which uses grad) are computed differs between versions. In adabelief_pytorch0.0.5, adabelief_pytorch0.2.0, and adabelief_pytorch0.2.1, weight decay is applied before grad_residual is computed, but in adabelief_pytorch0.1.0 it is applied afterwards. adabelief_pytorch0.1.0 seems to follow more closely what your paper describes as the second-order momentum computation. Shouldn't the others be changed to align with adabelief_pytorch0.1.0?
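
To make the ordering difference concrete, here is a minimal scalar sketch in plain Python rather than PyTorch. The helper `moments` and the numbers are hypothetical and not taken from any package version. With a non-decoupled L2 penalty, adding `weight_decay * p` to the gradient before the moment updates changes both `m` and the residual that feeds `s`; applying it afterwards leaves them untouched.

```python
def moments(grad, p, m, s, beta1=0.9, beta2=0.999, weight_decay=0.0, decay_first=True):
    """One AdaBelief moment update for a single scalar parameter (hypothetical helper).

    decay_first=True:  the L2 term is added to grad before m and s are updated.
    decay_first=False: m and s see the raw gradient.
    """
    if decay_first:
        grad = grad + weight_decay * p           # L2 penalty folded into the gradient
    m = beta1 * m + (1 - beta1) * grad           # first moment
    residual = grad - m                          # grad_residual = g_t - m_t
    s = beta2 * s + (1 - beta2) * residual ** 2  # belief (second) moment
    # With decay_first=False, an L2 term applied to grad after this point
    # can no longer influence m or s.
    return m, s


m_before, s_before = moments(1.0, 2.0, 0.0, 0.0, weight_decay=0.1, decay_first=True)
m_after, s_after = moments(1.0, 2.0, 0.0, 0.0, weight_decay=0.1, decay_first=False)
print(m_before, s_before)  # computed from grad = 1.2
print(m_after, s_after)    # computed from grad = 1.0
```

With `weight_decay=0` both orderings give identical results, so the question only matters when the L2 (non-decoupled) option is used.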

Thanks for pointing it out. This part is a bit tricky if weight decay is not decoupled. The Adam and AdamW implementations in PyTorch also apply weight decay before the update; I have not gone into the details and just followed that convention. But I think a careful comparison would be very interesting.
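
For reference, the two weight-decay modes used in the step code later in this thread (`weight_decouple`, with or without `fixed_decay`, versus a plain L2 penalty) can be summarized in a small scalar sketch. The function name is hypothetical; the formulas mirror those branches.

```python
def apply_weight_decay(p, grad, lr, weight_decay, decouple, fixed_decay=False):
    """Return (new_p, new_grad) for one weight-decay application (hypothetical helper).

    decouple=True  (AdamW style): shrink the parameter directly; grad is untouched.
    decouple=False (Adam/L2 style): add weight_decay * p to the gradient, so the
                   penalty then flows through the adaptive moments.
    """
    if decouple:
        factor = 1.0 - (weight_decay if fixed_decay else lr * weight_decay)
        return p * factor, grad
    return p, grad + weight_decay * p


print(apply_weight_decay(2.0, 1.0, lr=0.1, weight_decay=0.01, decouple=True))
print(apply_weight_decay(2.0, 1.0, lr=0.1, weight_decay=0.01, decouple=False))
```

Only the L2 branch changes grad, so it is the only mode whose position relative to the grad_residual computation matters.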

Thanks for the speedy reply. I had looked through your project page, and in the presentation there was a comment from a user that gradient clipping caused problems for the algorithm, but that it worked well once he turned gradient clipping off. So I was thinking that you don't want to affect the stored gradient values in any way.

Also, I have a somewhat newbie question about PyTorch. I noticed a number of .add_(group['eps']) calls on different quantities in the code (BTW, I am using adabelief_pytorch0.2.1). As I understand it, these modify the tensors they are applied to in place. So even though eps is a small number, did you intend, for example, to modify the value of exp_avg_var that will be used in future iterations?

Hi, thanks for the comments. Gradient clipping could be problematic: it can suppress the difference $g_t - m_t$, which shrinks $s_t$ and causes an undesirably large step size.
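
A rough scalar illustration of this effect, under assumptions that are not from the original comment: element-wise value clipping stands in for norm clipping, the gradient sequence is synthetic (alternating 8 and 12), and bias correction is included. Clipping every gradient to 1.0 removes the variation in g_t, so the residual g_t - m_t decays toward zero, s_t shrinks, and the normalized step m_hat / sqrt(s_hat) keeps growing.

```python
import math


def normalized_steps(grads, beta1=0.9, beta2=0.999, eps=1e-16):
    """Return |m_hat| / (sqrt(s_hat) + eps) at each step of a scalar AdaBelief run."""
    m = s = 0.0
    steps = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        s = beta2 * s + (1 - beta2) * (g - m) ** 2 + eps  # eps added to s, as in the in-place add
        m_hat = m / (1 - beta1 ** t)
        s_hat = s / (1 - beta2 ** t)
        steps.append(abs(m_hat) / (math.sqrt(s_hat) + eps))
    return steps


raw = [10.0 + 2.0 * (-1) ** t for t in range(1, 5001)]  # noisy but consistent gradients
clipped = [max(-1.0, min(1.0, g)) for g in raw]          # every value clipped to 1.0

print(normalized_steps(raw)[-1], normalized_steps(clipped)[-1])
```

With the raw gradients, the step settles near |mean| / spread; with the clipped ones it keeps increasing as s_t decays by roughly beta2 per step.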

As for add_(group['eps']), there are multiple configurations, such as weight_decouple and rectify, so I need to make sure they all perform the same update. Ignoring all the if branches, eps ultimately appears twice, as in Algorithm 2 of the paper.

The in-place eps add differs slightly from Adam, and the default eps is also different: 1e-16 for AdaBelief versus 1e-8 for Adam. In AdaBelief, eps is added both inside and outside the sqrt, but the one outside can safely be ignored (sqrt(1e-16) = 1e-8 >> 1e-16).
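
A scalar sketch of the denominator, mirroring the non-AMSGrad line in the step code of this thread (`exp_avg_var.add_(eps).sqrt() / math.sqrt(bias_correction2)` followed by `.add_(eps)`); plain floats stand in for tensors. Because bias_correction2 never exceeds 1, the inner term is at least sqrt(eps), so the outer eps changes the denominator by a relative amount of at most eps / sqrt(eps) = sqrt(eps), i.e. 1e-8 for the default eps.

```python
import math


def adabelief_denom(s, bias_correction2, eps=1e-16):
    """Denominator with eps inside and outside the sqrt (scalar stand-in for the tensor code)."""
    inner = math.sqrt(s + eps) / math.sqrt(bias_correction2)  # eps inside the sqrt
    return inner, inner + eps                                  # eps outside the sqrt


for s in (0.0, 1e-20, 1e-12, 1.0):
    inner, denom = adabelief_denom(s, bias_correction2=0.5)
    print(s, (denom - inner) / inner)  # relative effect of the outer eps
```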

With the in-place add, eps is added to exp_avg_var at every step; more validation would be needed to determine whether this causes a big difference. With eps=1e-16, it would take a huge step count t for t * eps to become large enough to have a numerical effect. A larger eps could make a difference, but for the default 1e-16 I don't think it matters much unless you train for a very, very long time.
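
One detail that may help bound this: exp_avg_var is multiplied by beta2 at every step before eps is added again, so the eps contributed by the in-place add does not grow like t * eps. It forms a geometric series that levels off at eps / (1 - beta2), about 1e-13 for eps=1e-16 and beta2=0.999. After dividing by bias_correction2 = 1 - beta2**t, that eps part is exactly eps / (1 - beta2) at every step. The sketch below, plain Python with a hypothetical helper name, tracks only that extra term.

```python
def eps_component(steps, eps=1e-16, beta2=0.999):
    """Track only the part of exp_avg_var that comes from the in-place eps adds.

    Each step does exp_avg_var.mul_(beta2)... and later exp_avg_var.add_(eps),
    so that part follows e_t = beta2 * e_{t-1} + eps (hypothetical helper).
    """
    e = 0.0
    history = []
    for _ in range(steps):
        e = beta2 * e + eps
        history.append(e)
    return history


hist = eps_component(20_000)
print(hist[-1], 1e-16 / (1 - 0.999))  # saturates near eps / (1 - beta2), far below 20_000 * eps
print(hist[0] / (1 - 0.999 ** 1), hist[99] / (1 - 0.999 ** 100))  # bias-corrected eps part
```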

Currently, I would suggest keeping eps=1e-16 (at most 1e-12) in most cases. I recently tested the default setting on ResNet again and found that, with proper warmup, a smoothly decayed lr schedule, and decoupled weight decay, AdaBelief with the default eps performs no worse than SGD. I also ran a quick test with ViT on ImageNet, where the default eps also outperforms AdamW. For NLP tasks with transformers, I have not tested very big models yet, but I expect a small eps to work there as well.
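
The comment does not say which warmup or decay schedule was used. As one common possibility, here is a minimal sketch assuming linear warmup followed by cosine decay; `warmup_cosine_lr`, `warmup_steps`, and `total_steps` are hypothetical names and the values are illustrative. In PyTorch, a per-step multiplier derived from this could be supplied through `torch.optim.lr_scheduler.LambdaLR`.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step`: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 499, 500, 5_000, 10_000):
    print(s, warmup_cosine_lr(s, base_lr=1e-3, warmup_steps=500, total_steps=10_000))
```

The cosine shape is just one smooth choice; a linear or polynomial decay would fit the description equally well.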

PS: I noticed your comments on the Newton-style modification; it looks very interesting to me. I'll take a careful look and perhaps test it on larger NLP tasks later, but for now I'm busy with other things and cannot dig into it deeply. You are very welcome to post comments here, though.

Thanks for helping me understand the application of eps in AdaBelief.

Regarding the Newton-style modification in my other reply, and my question about initializing s_0 (since a finite difference can't be formed until step 2): I have modified step to initialize s_0 to zero but not start the exponential-moving-average recursion until step 2 (see the code below). I also updated the second-moment bias correction and the RAdam calculations to account for skipping the first step when the Newton-style correction is enabled, by subtracting 1 from the step value they use.

I also decided to model the RAdam part on the code in pytorch/optim rather than on the original implementation (it looked cleaner to me), and hardwired an SGDM step for the first step and whenever the RAdam criterion rho_t > 5 is not met. I am currently using this code to train the blenderbot2 400M transformer model. (BTW, because this model can be trained in fp16 precision, I am grateful that you included recasting of the parameter and gradient precision in step.)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                # cast data type
                half_precision = False
                if p.data.dtype == torch.float16:
                    half_precision = True
                    p.data = p.data.float()
                    p.grad = p.grad.float()

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']
				
                fd_adahessian = self.fd_adahessian

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization
                first_step = False
                if len(state) == 0:
                    first_step = True
                    state['step'] = 0
                    print('state initialization in adabeliefoff2_1_hess')
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    if fd_adahessian: # create p_old
                        state['p_old'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                            if version_higher else torch.zeros_like(p.data) 
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                            if version_higher else torch.zeros_like(p.data)
               
                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                step = state['step']
                if fd_adahessian:
                    step = step - 1  # the second moment is not formed until the first delta m / delta p finite difference is available, at state['step'] == 2
                bias_correction2 = 1 - beta2 ** step

                if fd_adahessian:
                    # first calculate delta m, using the present grad and the previous m:
                    # (grad - m_{t-1}) * (1 - beta1) == m_t - m_{t-1}
                    delta_m = (grad - exp_avg) * (1 - beta1)  #SJS new
                    delta_m.div_(torch.sub(p.data, state['p_old']).add_(group['eps']))  # approximates delta m / delta p
                    if first_step:  # no delta m / delta p on the first step, so zero delta_m to make the exp_avg_var update zero
                        delta_m.zero_()
                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1)
                if not fd_adahessian:
                    delta_m = grad - exp_avg #SJS original adabelief...  this uses current m
                exp_avg_var.mul_(beta2).addcmul_( delta_m, delta_m, value = 1 - beta2)  #will be zero for first fd_adahessian step
                del delta_m # free up memory
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']), out=state['max_exp_avg_var'])  #SJS want add_ here? #changed to be similar to pytorch/optim code
                    # Use the max. for normalizing running avg. of gradient
                    denom = (state['max_exp_avg_var'].sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  #SJS want add_ here?
                else:
                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  #SJS want 2x add_ here?
                   
                if fd_adahessian:
                    state['p_old'] = p.data.clone() #  p_old goes with present m for next step
					
                # perform weight decay if weight_decay is non-zero, check if decoupled weight decay #SJS moved here from earlier because it can modify grad, which was needed for delta_m
                if group['weight_decay'] != 0:
                    if self.weight_decouple:
                        # decoupled weight decay as in AdamW
                        if not self.fixed_decay:
                            p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                        else:
                            p.data.mul_(1.0 - group['weight_decay'])
                    else:
                        # L2 penalty
                        # NOTE: exp_avg and exp_avg_var have already consumed grad above,
                        # so at this position the L2 term does not affect this step's update
                        grad.add_(p.data, alpha = group['weight_decay'])
						
                # update
                if fd_adahessian and first_step:
                    #update using SGDM step on first step since can't compute finite difference Hessian yet
                    p.data.add_(exp_avg, alpha = -group['lr'] / bias_correction1)
                else:
                    if not self.rectify:
                        # Default update						
                        p.data.addcdiv_(exp_avg, denom, value = -group['lr'] / bias_correction1)
                    else:  # Rectified update, forked from RAdam in pytorch optim
                        # NOTE: because the fd_adahessian assumption is that the change in the gradient w.r.t. p over a step is primarily due to the Hessian diagonal term for p,
                        #       we do not skip updates for a parameter p not meeting the rho_t > 5 criterion, but ALWAYS take an SGDM step
                        # maximum length of the approximated SMA
                        rho_inf = 2 / (1 - beta2) - 1
                        # compute the length of the approximated SMA
                        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2
                        if rho_t > 5.:
                            # Compute the variance rectification term and update parameters accordingly
                            rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
                            p.data.addcdiv_(exp_avg, denom, value = -rect * group['lr'] / bias_correction1)
                        else:
                            # use SGDM for initial steps not meeting criterion
                            p.data.add_(exp_avg, alpha = -group['lr'] / bias_correction1)
                del denom # free up memory
  
                
                if half_precision:
                    p.data = p.data.half()
                    p.grad = p.grad.half() 

        return loss
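
A scalar sketch of how the shifted step count in the code above feeds the second-moment bias correction and the RAdam length rho_t, using the same formulas (rho_inf = 2 / (1 - beta2) - 1, rho_t = rho_inf - 2 * step * beta2**step / bias_correction2). The helper name `second_moment_schedule` is hypothetical. With beta2 = 0.999, rho_t first exceeds 5 at step 6, so the shift delays the first rectified update from state['step'] == 6 to state['step'] == 7.

```python
def second_moment_schedule(state_step, beta2=0.999, fd_adahessian=True):
    """Return (bias_correction2, rho_t, rectified) for a given state['step'].

    With fd_adahessian the second moment starts one step late, so the effective
    step is state_step - 1; at effective step 0 there is no second moment yet
    (the code takes an SGDM step), and None is returned.
    """
    step = state_step - 1 if fd_adahessian else state_step
    if step == 0:
        return None
    bias_correction2 = 1 - beta2 ** step
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2
    return bias_correction2, rho_t, rho_t > 5.0


for n in range(1, 9):
    print(n, second_moment_schedule(n))
```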

Hi
I think I found an issue with forming a finite-difference diagonal Hessian in the Newton-style modification to AdaBelief that I described above. Digging into the code used for my particular application, I believe it draws a new random dropout of the network's elements (which I believe sets the dropped-out elements to zero) on every forward call. Since the dropped-out elements change between the computations used to form the finite-difference Hessian approximation, the accuracy of the approximation degrades. And since dropout is a recommended way to prevent overfitting, it probably should not be eliminated.
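
A toy scalar sketch of the concern, under simplifying assumptions: the loss is 0.5 * h * (d * p)**2, where d models a dropout mask entry (1 = kept, 0 = dropped, ignoring the usual 1/(1 - q) rescaling), and the curvature estimate is the secant (g1 - g0) / (p1 - p0) on raw gradients rather than the EMA-based delta m / delta p of the code above. With the same mask at both evaluations, the secant recovers the curvature h * d**2 exactly; if the mask changes between them, the estimate is dominated by the gradient jump divided by the small parameter step.

```python
def toy_grad(p, h, d):
    """Gradient of the toy loss 0.5 * h * (d * p) ** 2 with respect to p."""
    return h * d * d * p


def secant_curvature(p0, p1, g0, g1):
    """Finite-difference (secant) estimate of the second derivative."""
    return (g1 - g0) / (p1 - p0)


h, p0, p1 = 1.0, 1.0, 0.99

same_mask = secant_curvature(p0, p1, toy_grad(p0, h, 1.0), toy_grad(p1, h, 1.0))
new_mask = secant_curvature(p0, p1, toy_grad(p0, h, 1.0), toy_grad(p1, h, 0.0))
print(same_mask)  # recovers h
print(new_mask)   # roughly 100x too large: the gradient jump divided by a tiny step
```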