Prevent lambda from becoming stuck at 1.0
rickstaa commented
The lambda Lagrange multiplier seems to get stuck at 1.0 sometimes.
Solution
This problem seems to be fixed if we use `log_lambda` in the `lambda_loss` function instead of `lambda`.
(Figures comparing the two cases: "Use lambda" and "Use log_lambda".)
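To make the proposed change concrete, here is a minimal sketch of the two variants. The variable names and the exact form of the constraint term are assumptions for illustration, not the repository's actual code:

```python
import torch

# Minimal sketch of the two lambda-loss variants (illustrative names, not the repo's code).
log_labda = torch.zeros(1, requires_grad=True)   # learnable log of the Lagrange multiplier
labda = torch.clamp(log_labda.exp(), 0.0, 1.0)   # lambda clipped to the range 0.0-1.0

constraint = torch.tensor(0.1)                   # stand-in for the (detached) constraint term

# Variant 1: use lambda itself in the lambda loss (can get stuck once the clip is active).
labda_loss_lambda = -(labda * constraint).mean()

# Variant 2: use log_lambda in the lambda loss (the fix described above).
labda_loss_log_lambda = -(log_labda * constraint).mean()
```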
Testing 3
Let's perform a few final quick tests to see if the behaviour changes if we:
- Start with lambda at 1.0.
  - This causes the problem of lambda being stuck to occur more often when `lambda` is used in the lambda loss. This is not the case when we use `log_lambda` in the actor loss.
- Check what lambda does if we don't clip it to be below 1.0. We do have to keep the lower clip at 0.0, otherwise we get the problem that `log_lambda` becomes -inf.
  - The problem goes away if lambda is not clipped to stay below 1.0 and `lambda` is used in the lambda loss (I used a range of 0.0-2.0). A sketch of the two clip ranges is shown below.
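A minimal sketch of the clip ranges compared in these tests; the names and bounds are illustrative assumptions, not the repository's configuration:

```python
import torch

# Illustrative clip ranges from the tests above.
log_labda = torch.zeros(1, requires_grad=True)

# Original setup: lambda clipped to the range 0.0-1.0.
labda_clipped = torch.clamp(log_labda.exp(), 0.0, 1.0)

# Test setup: lambda no longer clipped to stay below 1.0 (a 0.0-2.0 range was used instead).
labda_wider = torch.clamp(log_labda.exp(), 0.0, 2.0)
```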
Possible causes
- As can be seen from the discussion in #32, this might be because `log_lambda` is more numerically stable. Looking at [this Stack Exchange question](https://datascience.stackexchange.com/questions/25024/strange-behavior-with-adam-optimizer-when-training-for-too-long), the reasoning might be as follows:
  Use of `lambda` -> higher variance in the lambda loss -> higher variance in the gradients -> this saturates the Adam optimizer, which keeps exponential moving averages of recent gradients and of their squares.
- The problem is caused by the fact that we both use `lambda` in the `loss_function` and clamp lambda using `torch.clamp`. This causes the gradients outside the accepted range (in our case 0.0-1.0) to become zero. This means that when `log_lambda` becomes bigger than 0.0, the gradients vanish. As a result, the `lambda_loss` no longer influences `log_lambda`, and the optimizer has no way to optimize `log_lambda` and the related `lambda` (also see this PyTorch issue). A small demonstration is given below the figure.
Figure taken from this PyTorch forum post.
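A minimal, self-contained sketch (not the repository's code) that reproduces the vanishing gradient caused by `torch.clamp`:

```python
import torch

# Demonstration that torch.clamp zeroes the gradient once its input leaves the clamp range.
log_labda = torch.tensor(0.5, requires_grad=True)  # exp(0.5) ~ 1.65, i.e. above the upper clip of 1.0

labda = torch.clamp(log_labda.exp(), 0.0, 1.0)     # clipped down to 1.0
loss = -labda * 0.1                                # stand-in for the lambda loss
loss.backward()

print(log_labda.grad)                              # zero gradient -> the loss can no longer move log_lambda
```

The zero gradient appears regardless of what the lambda loss looks like, because the clamp cuts the computation graph as soon as lambda sits outside the clip range.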
Conclusion
@panweihit Based on my debugging session, explanation 2 seems to be the most likely. If we look at the gradients, they indeed become 0.0 when `log_lambda` becomes bigger than 0.0. I therefore think it is wise to use `log_lambda` instead of the `lambda` defined in the paper (see eq. 14 of Han et al. 2020). While doing this, we have to document why the implementation differs from the paper.
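For reference, a hedged sketch of what the resulting lambda update could look like; the names, learning rate, and constraint term are assumptions, not the final implementation:

```python
import torch

# Sketch of the proposed setup: optimize log_lambda directly and build the lambda loss from it.
log_labda = torch.zeros(1, requires_grad=True)          # learnable log of the Lagrange multiplier
labda_optimizer = torch.optim.Adam([log_labda], lr=3e-4)

def labda_loss(constraint_violation: torch.Tensor) -> torch.Tensor:
    # Use log_lambda here instead of the lambda of eq. 14 in Han et al. 2020.
    return -(log_labda * constraint_violation.detach()).mean()

loss = labda_loss(torch.tensor(0.1))                    # stand-in constraint value
labda_optimizer.zero_grad()
loss.backward()
labda_optimizer.step()

labda = log_labda.exp()                                 # lambda used elsewhere; no upper clip at 1.0
```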