Inconsistent computation of weight_decay and grad_residual among adabelief_pytorch versions
sjscotti opened this issue · 5 comments
Hi
I was looking at the various versions you have in the pypi_packages folder and noticed that the order of computation of weight decay (which for some options modifies grad) and of grad_residual (which uses grad) differs between versions. In adabelief_pytorch0.0.5, adabelief_pytorch0.2.0, and adabelief_pytorch0.2.1, weight decay is done before computing grad_residual, but in adabelief_pytorch0.1.0 it is done afterwards. It seems that adabelief_pytorch0.1.0 more closely follows what your paper describes for the second-order momentum computation. Shouldn't the others be changed to align with adabelief_pytorch0.1.0?
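For concreteness, a minimal hypothetical sketch of the two orderings being compared (simplified, not the packages' exact code), using the coupled L2-penalty case where weight decay modifies grad in place:

import torch

def ordering_decay_first(p, grad, exp_avg, exp_avg_var, beta1=0.9, beta2=0.999, wd=1e-2):
    # ordering as in 0.0.5 / 0.2.0 / 0.2.1: L2 penalty folded into grad before the residual
    grad = grad.add(p, alpha=wd)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    grad_residual = grad - exp_avg      # residual of the *decayed* gradient
    exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

def ordering_decay_after(p, grad, exp_avg, exp_avg_var, beta1=0.9, beta2=0.999, wd=1e-2):
    # ordering as in 0.1.0: residual of the raw gradient; L2 penalty applied afterwards
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    grad_residual = grad - exp_avg      # residual of the *raw* gradient
    exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
    grad = grad.add(p, alpha=wd)        # decay no longer affects this step's residual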
Thanks for pointing it out. This part is a bit tricky if weight decay is not decoupled. Currently, the Adam and AdamW implementations in PyTorch also apply weight decay before the update; I have not gotten into the details and just followed that convention. But I think it would be very interesting to perform a careful comparison.
Thanks for the speedy reply. I had looked through your project page, and in the presentation there was a comment from a user that gradient clipping caused problems for the algorithm, but that it worked well when he turned gradient clipping off. So I was thinking that you don't want to affect the stored gradient values in any way.
Also, I have a bit of a newbie question WRT PyTorch. I noticed that you apply .add_(group['eps']) to a number of different quantities in the code (BTW I am using adabelief_pytorch0.2.1), which, as I understand it, modifies those quantities in place. So even though it is a small number, did you intend (for example) to modify the value of exp_avg_var that will be used in future iterations with these function applications?
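For readers less familiar with PyTorch in-place ops, a tiny sketch (made-up values) of the behavior in question:

import torch

eps = 1e-16
exp_avg_var = torch.zeros(3)            # persistent optimizer state

denom = exp_avg_var.add_(eps).sqrt()    # in-place: exp_avg_var itself now holds +eps
print(exp_avg_var)                      # tensor([1e-16, 1e-16, 1e-16]) carries into the next step

denom2 = (exp_avg_var + eps).sqrt()     # out-of-place: exp_avg_var is left unchanged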
Hi, thanks for the comments. Gradient clipping could indeed be problematic, since it could, say, suppress the difference between the gradient and its exponential moving average.
For the add_(group['eps']), it's because there are multiple configurations such as weight_decouple and rectify, so I need to make sure they all apply the same update. Ultimately, ignoring all the if branches, eps appears twice, as in Algo. 2 in the paper.
The eps in-place add is slightly different from Adam, and the default is eps=1e-16 for AdaBelief but 1e-8 for Adam. In AdaBelief, eps is added both inside and outside the sqrt, but the one outside the sqrt can be safely ignored, since sqrt(1e-16) >> 1e-16.
As for the in-place add meaning that eps is added to exp_avg_var in every step, it would require some more validation to determine whether it causes a big difference. With eps=1e-16, it would take a huge step count t before the accumulated eps is large enough to take effect numerically. For a larger eps it could make a difference, but for the default 1e-16 I don't think it would matter much unless you train for a very, very long time.
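A quick, self-contained sketch (assumed beta2 and step count) of how the in-place eps adds accumulate in the state tensor across steps, with the residual term set to zero so only the eps contribution remains:

import torch

eps, beta2, steps = 1e-16, 0.999, 100_000
exp_avg_var = torch.zeros(1, dtype=torch.float64)
zero_residual = torch.zeros(1, dtype=torch.float64)
for _ in range(steps):
    exp_avg_var.mul_(beta2).addcmul_(zero_residual, zero_residual, value=1 - beta2)
    exp_avg_var.add_(eps)          # the in-place add that persists into the next step
# the EMA decay keeps the accumulated eps bounded near eps / (1 - beta2), i.e. ~1e-13 here
print(exp_avg_var.item())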
Currently, I would suggest keeping eps=1e-16 (at most 1e-12) for most cases. Recently I tested the default setting on ResNet again, and I found that with proper warmup, a smoothly decayed lr schedule, and decoupled weight decay, AdaBelief with the default eps can perform no worse than SGD. I also ran a quick test on ImageNet with ViT, where the default eps also outperforms AdamW. For NLP tasks with transformers, I have not tested on very big models yet, but I guess a small eps should also work.
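As a sketch of that suggested configuration (assuming the adabelief_pytorch 0.2.x constructor keywords and a placeholder model and schedule, not an exact recipe from the experiments above):

import torch
from adabelief_pytorch import AdaBelief

model = torch.nn.Linear(128, 10)   # placeholder model

# small eps, decoupled weight decay, rectification on
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-16,
    weight_decay=1e-2,
    weight_decouple=True,   # decoupled weight decay as in AdamW
    rectify=True,
)

# a smoothly decayed lr schedule; warmup would be layered on top of this
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)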
PS: I noticed your comments on the Newton-style modification. It looks very interesting to me. I'll take a careful look and perhaps test it on larger NLP tasks later, but for now I'm a bit busy with other things and cannot pursue it very deeply. But you are very welcome to post a comment here.
Thanks for helping me understand the application of eps in AdaBelief.
WRT the Newton-style modification in my other reply and my question about the initialization of s_0 (since you can't form a finite difference until step 2), I have modified my version of step to initialize s_0 to zero, but not to start the exponential-moving-average recursion until step 2 (see code below). I also updated the second-moment bias correction and the RAdam calculations to account for skipping the first step when doing the Newton-style correction, by subtracting 1 from the value used for step. I also decided to use the code in pytorch/optim as a model for the RAdam implementation rather than the original (it looked cleaner to me), and hardwired using SGDM for the first step and whenever the RAdam criterion rho_t > 5 isn't met. I am now running the code to train the blenderbot2 400M transformer model. (BTW, because this model can be trained in fp16 precision, I am grateful that you included the recasting of the precision of the parameters and gradients in step.)
def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue

            # cast data type
            half_precision = False
            if p.data.dtype == torch.float16:
                half_precision = True
                p.data = p.data.float()
                p.grad = p.grad.float()

            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError(
                    'AdaBelief does not support sparse gradients, please consider SparseAdam instead')

            amsgrad = group['amsgrad']
            fd_adahessian = self.fd_adahessian
            state = self.state[p]
            beta1, beta2 = group['betas']

            # State initialization
            first_step = False
            if len(state) == 0:
                first_step = True
                state['step'] = 0
                print('state initialization in adabeliefoff2_1_hess')
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)
                if fd_adahessian:  # create p_old
                    state['p_old'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)

            # get current state variables
            exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

            state['step'] += 1
            bias_correction1 = 1 - beta1 ** state['step']
            step = state['step']
            if fd_adahessian:
                # don't form the second-order momentum until the first delta m / delta p
                # finite difference is available at state['step'] == 2
                step = step - 1
            bias_correction2 = 1 - beta2 ** step

            if fd_adahessian:
                # first calculate delta m --- uses present grad and previous m
                delta_m = (grad - exp_avg) * (1 - beta1)  # SJS new
                delta_m.div_(torch.sub(p.data, state['p_old']).add_(group['eps']))  # approximates delta m / delta p
                if first_step:
                    # can't get delta m / delta p on the first step, so zero delta_m so the exp_avg_var update is zero
                    delta_m.zero_()

            # Update first and second moment running averages
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            if not fd_adahessian:
                delta_m = grad - exp_avg  # SJS original AdaBelief... this uses current m
            exp_avg_var.mul_(beta2).addcmul_(delta_m, delta_m, value=1 - beta2)  # will be zero for the first fd_adahessian step
            del delta_m  # free up memory

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']), out=state['max_exp_avg_var'])  # SJS want add_ here? # changed to be similar to pytorch/optim code
                # Use the max. for normalizing running avg. of gradient
                denom = (state['max_exp_avg_var'].sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  # SJS want add_ here?
            else:
                denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  # SJS want 2x add_ here?

            if fd_adahessian:
                state['p_old'] = p.data.clone()  # p_old goes with present m for next step

            # perform weight decay if weight_decay is non-zero, check if decoupled weight decay
            # SJS moved here from earlier because it can modify grad, which was needed for delta_m
            if group['weight_decay'] != 0:
                if self.weight_decouple:
                    # decoupled weight decay as in AdamW
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    # L2 penalty
                    grad.add_(p.data, alpha=group['weight_decay'])

            # update
            if fd_adahessian and first_step:
                # update using an SGDM step on the first step since the finite-difference Hessian can't be computed yet
                p.data.add_(exp_avg, alpha=-group['lr'] / bias_correction1)
            else:
                if not self.rectify:
                    # Default update
                    p.data.addcdiv_(exp_avg, denom, value=-group['lr'] / bias_correction1)
                else:  # Rectified update, forked from RAdam in pytorch/optim
                    # NOTE: because the assumption for fd_adahessian is that the change in the gradient WRT p
                    # over a step is primarily due to the Hessian diagonal term for p, we do not skip updates
                    # for a parameter p not meeting the rho_t > 5 criterion, but ALWAYS do an SGDM step instead
                    # maximum length of the approximated SMA
                    rho_inf = 2 / (1 - beta2) - 1
                    # compute the length of the approximated SMA
                    rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2
                    if rho_t > 5.:
                        # Compute the variance rectification term and update parameters accordingly
                        rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
                        p.data.addcdiv_(exp_avg, denom, value=-rect * group['lr'] / bias_correction1)
                    else:
                        # use SGDM for initial steps not meeting the criterion
                        p.data.add_(exp_avg, alpha=-group['lr'] / bias_correction1)

            del denom  # free up memory

            if half_precision:
                p.data = p.data.half()
                p.grad = p.grad.half()

    return loss
Hi
I think I found an issue with forming a finite-difference diagonal Hessian in the Newton-style modification to AdaBelief that I described above. Digging into the code used for my particular application, I believe it computes a new random dropout of the elements of the neural network (which, as I understand it, sets the dropped-out elements to zero) on every forward call of the network. Since the dropout elements will change between the computations that are used to form the finite-difference Hessian approximation, this degrades the accuracy of the approximation. And since dropout is a recommended approach for preventing overfitting of a network, it probably should not be eliminated.
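A tiny sketch (a standalone dropout layer, not the actual application code) showing that consecutive forward calls in training mode draw different random masks, which is what breaks the consistency the finite-difference approximation needs:

import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)   # a freshly constructed module is in training mode
x = torch.ones(8)

out1 = drop(x)             # one random mask
out2 = drop(x)             # a different random mask
print(torch.equal(out1, out2))        # almost always False: the two evaluations differ

drop.eval()                # eval mode disables the random masking entirely
print(torch.equal(drop(x), drop(x)))  # True: deterministic, but dropout is effectively off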