Similarity to AdaHessian
davda54 opened this issue · 7 comments
Hi, first of all, thank you very much for sharing the code for AdaBelief, it looks like a very promising optimizer! :) Have you considered comparing it to AdaHessian? I feel like AdaHessian is using the same trick as you (but they do it less efficiently).
Thanks for the comment. I quickly skimmed their paper; the idea is roughly similar, but very different in terms of implementation. AdaHessian directly computes the Hessian diagonal, while ours approximates it by the change in gradient. I have not compared to AdaHessian in practice. From the paper, they do a back-prop through the back-prop, which seems quite slow (2 to 3 times slower than Adam), while ours is similar to Adam in terms of speed. As for the trick, AdaHessian uses a block-averaging trick, which is rarely discussed in previous work, but I feel it might be helpful for AdaBelief and other optimizers too.
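For concreteness, the extra cost in AdaHessian comes from estimating the Hessian diagonal with a second backward pass (Hutchinson's method). Below is a minimal sketch of that kind of estimator, not code from either repository; `loss` and `params` are assumed to come from an ordinary forward pass.

```python
import torch

def hutchinson_hessian_diag(loss, params):
    # First backward pass, keeping the graph so we can differentiate again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Rademacher probe vectors (+1/-1 with equal probability)
    zs = [torch.randint_like(p, high=2) * 2.0 - 1.0 for p in params]
    # Second backward pass: Hessian-vector products H z
    hzs = torch.autograd.grad(grads, params, grad_outputs=zs)
    # z * (H z) is an unbiased estimate of the Hessian diagonal
    return [z * hz for z, hz in zip(zs, hzs)]
```

By contrast, the (g_t - m_t)^2 term in AdaBelief is built from quantities the optimizer already tracks, so no second backward pass is needed.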
Thanks for your quick answer. What I meant is that AdaBelief defines s_t as EMA[(g_t - m_t)^2] (very informally), while AdaHessian defines it as EMA[h_t^2] (where h_t is the Hessian trace). Otherwise, both optimizers are the same. I believe that (g_t - m_t)^2 should be highly correlated with h_t^2, as they both reflect the amount of local curvature.
So I am interested if there is a performance-speed tradeoff between these two optimizers, or if AdaBelief is strictly better than AdaHessian in both the speed and the accuracy.
I don't know if AdaBelief is strictly better than AdaHessian; in fact, I guess there will be cases where each outperforms the other. It's hard to determine without extensive experiments. I want to point out that even in theory they are not the same: h_t is not the trace but the magnitude of the diagonal elements of the Hessian, if I'm correct. Whether EMA[(g_t - m_t)^2] is a good approximation to the Hessian is hard to say. Just as the convergence proof of Adam is so much more trouble than that of RMSProp, a single modification in practice can result in a big difference in theory.
In terms of empirical results, I think the implementation matters, even if the algorithms look very similar. So I'm sorry I don't have a conclusion yet.
I see :) Anyway, thanks again for your great work!
I know this is closed, but I wanted to agree that if an element of EMA[(g_t - m_t)^2] were divided by the square of the change in the corresponding model parameter, it would be a close finite-difference approximation of a (squared) diagonal Hessian element - which is what AdaHessian uses. Have you tried including this division in your routine to see if it improves results?
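As a toy illustration of that finite-difference view (the numbers here are made up, and there is only one scalar parameter): for f(theta) = 0.5 * a * theta^2 the gradient is a * theta, so dividing the change in gradient by the change in parameter recovers the curvature a.

```python
import torch

# f(theta) = 0.5 * a * theta^2  ->  grad = a * theta,  second derivative = a
a = 3.0
grad = lambda theta: a * theta

theta_prev, theta_cur = torch.tensor(1.0), torch.tensor(1.2)
fd_hessian = (grad(theta_cur) - grad(theta_prev)) / (theta_cur - theta_prev)
print(fd_hessian)  # ~ tensor(3.), matches the true curvature a
```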
Thanks for the comment; I have not tried that in my experiments. It sounds like a very interesting idea. I think it would help, but there is one more thing to consider: the gradient w.r.t. the parameters depends on the data, so we might need to use one batch twice or somehow find an approximation. I'll try to pursue it later.
Hi,
I made an attempt at implementing my suggestion above as another option in AdaBelief, used when the flag fd_adahessian (which stands for "finite-difference Hessian, used as in AdaHessian") is set to True. Instead of g_t - m_t in the exponential moving average s_t (using your paper's notation), it uses the finite difference of the momentum m (the exponential moving average of the gradient g w.r.t. the parameter theta) with respect to the corresponding model parameter theta, taken between the present and previous step of the optimizer. The assumption is that the change in this momentum m for parameter theta between steps is primarily due to the change in the corresponding parameter theta. This is similar to what AdaHessian assumes when it uses only the diagonal of the Hessian matrix in the update to its version of s_t. I am unsure of the best initialization s_0 for this version of AdaBelief, since you can't form a finite difference until step 2; any suggestions you have would be appreciated. I am training a model with it at present, and it appears to be no worse than the default version of AdaBelief. I have time to check the code more carefully while it is training, since I am not sure it doesn't have mistakes. Below is the code for this version if you would like to comment on it or try it yourself.
Regards
-Steve
```python
def step(self, closure=None):
    """Performs a single optimization step.
    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue

            # cast data type
            half_precision = False
            if p.data.dtype == torch.float16:
                half_precision = True
                p.data = p.data.float()
                p.grad = p.grad.float()

            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError(
                    'AdaBelief does not support sparse gradients, please consider SparseAdam instead')

            amsgrad = group['amsgrad']
            fd_adahessian = self.fd_adahessian

            state = self.state[p]

            beta1, beta2 = group['betas']

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)
                if fd_adahessian:  # create p_old
                    state['p_old'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_var'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)

            # get current state variables
            exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
            p_old = state['p_old'] if fd_adahessian else None  # only stored for the finite-difference option

            state['step'] += 1
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            if fd_adahessian:
                '''
                if state['step'] == 1:
                    #SJS zero out for first step
                    delta_m = torch.zeros_like(p.data, memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                else:
                    #SJS below assume this also works for the first step assuming previous m is zero at p_old of zero
                    #SJS the step = 1 code above will do a divide by sqrt eps which may blow up the routine
                '''
                # first calculate delta m --- uses present grad and previous m
                delta_m = (grad - exp_avg) * (1 - beta1)  # SJS new: equals m_t - m_(t-1)
                delta_m.div_(torch.sub(p.data, p_old).add_(group['eps']))  # approximates delta m / delta p

            # Update first and second moment running average
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            if not fd_adahessian:
                delta_m = grad - exp_avg  # SJS original AdaBelief... this uses the current m, which is exp_avg
            exp_avg_var.mul_(beta2).addcmul_(delta_m, delta_m, value=1 - beta2)

            if amsgrad:
                max_exp_avg_var = state['max_exp_avg_var']
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)  # SJS want add_ here?
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  # SJS want add_ here?
            else:
                denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])  # SJS want 2x add_ here?

            if fd_adahessian:
                state['p_old'] = p.data.clone()  # p_old goes with the present m for the next step

            # perform weight decay, check if decoupled weight decay
            # SJS moved here from earlier because it can modify grad, which was needed for grad_residual
            if self.weight_decouple:
                if not self.fixed_decay:
                    p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                else:
                    p.data.mul_(1.0 - group['weight_decay'])
            else:
                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

            # update
            if not self.rectify:
                # Default update
                step_size = group['lr'] / bias_correction1
                p.data.addcdiv_(exp_avg, denom, value=-step_size)
            else:  # Rectified update, forked from RAdam
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                if N_sma >= 5:
                    denom = exp_avg_var.sqrt().add_(group['eps'])
                    p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                elif step_size > 0:
                    p.data.add_(exp_avg, alpha=-step_size * group['lr'])

            if half_precision:
                p.data = p.data.half()
                p.grad = p.grad.half()

    return loss
```
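If you want to try it, the wiring I have in mind is along these lines. Note that the constructor argument is an assumption here, since only step() is shown above; the remaining arguments follow the standard AdaBelief signature.

```python
# Hypothetical usage sketch: fd_adahessian is assumed to be accepted by the
# constructor and stored as self.fd_adahessian (only step() is shown above).
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=0, weight_decouple=True,
                      rectify=False, fd_adahessian=True)
```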
UPDATE: I did a little derivation and found that the term in AdaBelief, g_t - m_t, is equal to beta1 * (m_t - m_(t-1)) / (1 - beta1).
So AdaBelief already has the numerator part of the momentum finite difference mentioned above.
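A quick numerical check of that identity (arbitrary example values):

```python
import torch

beta1 = 0.9
torch.manual_seed(0)
g = torch.randn(5)        # current gradient g_t
m_prev = torch.randn(5)   # previous momentum m_(t-1)

m = beta1 * m_prev + (1 - beta1) * g          # m_t
lhs = g - m                                   # AdaBelief term
rhs = beta1 * (m - m_prev) / (1 - beta1)      # rewritten form
print(torch.allclose(lhs, rhs))               # True
```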