rickstaa/stable-learning-control

Prevent lambda from becoming stuck at 1.0


The lambda Lagrange multiplier seems to get stuck at 1.0 sometimes.

Solution

This problem seems to be fixed if we use log_lambda in the lambda_loss function instead of lambda.
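As a rough sketch of the difference (the parameter and function names below are illustrative and may not match the actual codebase):

```python
import torch

# log_lambda is assumed to be the trainable parameter, with
# lambda = clamp(exp(log_lambda), 0.0, 1.0) used elsewhere in the algorithm.
log_lambda = torch.nn.Parameter(torch.zeros(1))

def lambda_loss(constraint_violation, use_log_lambda=True):
    """Lagrange-multiplier loss for a batch of (detached) constraint terms."""
    if use_log_lambda:
        # Proposed fix: optimise log_lambda directly. The gradient is simply
        # the mean constraint violation and is never zeroed by the clamp.
        return -(log_lambda * constraint_violation.detach()).mean()
    # Original formulation: once exp(log_lambda) exceeds 1.0 the clamp saturates,
    # the gradient w.r.t. log_lambda becomes 0.0, and lambda stays stuck at 1.0.
    lam = torch.clamp(log_lambda.exp(), min=0.0, max=1.0)
    return -(lam * constraint_violation.detach()).mean()
```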

Use lambda

[image]

Use log_lambda

[image]

Testing 3

Let's perform some last quick tests to see if the behaviour changes if we:

  • Start with lambda at 1.0.
    • This causes lambda to get stuck more often when lambda is used in the lambda loss. This is not the case when we use log_lambda in the actor loss.
  • Check what lambda does if we don't clip it to be below 1.0. We still have to clip it at the lower bound of 0.0, otherwise log_lambda becomes -inf.
    • The problem goes away if lambda is not clipped to be lower than 1.0 and lambda is used in the lambda loss (I used a range of 0.0-2.0; see the sketch after this list).
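A minimal sketch of the two clamp configurations compared above (assuming lambda is obtained from exp(log_lambda); names are illustrative):

```python
import torch

log_lambda = torch.nn.Parameter(torch.zeros(1))  # start with lambda = exp(0) = 1.0

# Original bounds: lambda is clamped into [0.0, 1.0], so it can saturate at 1.0.
lam_original = torch.clamp(log_lambda.exp(), min=0.0, max=1.0)

# Relaxed bounds from the second test: keep the lower bound, but let lambda
# grow past 1.0 (a range of 0.0-2.0 was used here).
lam_relaxed = torch.clamp(log_lambda.exp(), min=0.0, max=2.0)
```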

Possible causes

  1. As can be seen from the discussion on #32, this might be because log_lambda is more numerically stable. Looking at [this Stack Exchange question](https://datascience.stackexchange.com/questions/25024/strange-behavior-with-adam-optimizer-when-training-for-too-long), the reasoning might be as follows:
     Use of lambda -> higher variance in the lambda loss -> higher variance in the gradients -> this saturates the Adam optimizer, which uses exponential moving averages of recent gradients and squared gradients.
  2. The problem is caused by the fact that we both use lambda in the loss function and clamp lambda using torch.clamp. This causes the gradients outside the accepted range (in our case 0.0-1.0) to become zero. This means that when log_lambda becomes bigger than 0.0, the gradients vanish. As a result, the lambda_loss no longer influences log_lambda, and the optimizer has no way to optimize log_lambda and the related lambda (also see this PyTorch issue; a minimal reproduction is sketched below the figure).

[image]

Figure taken from this PyTorch forum post.
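For reference, cause 2 can be reproduced in a few lines of plain PyTorch (the constraint value and parameter name below are just stand-ins, not the project's actual code):

```python
import torch

log_lambda = torch.nn.Parameter(torch.tensor(0.5))  # exp(0.5) ≈ 1.65 > 1.0

# Clamping lambda into [0.0, 1.0] makes the output constant above the range,
# so the gradient w.r.t. log_lambda is exactly zero there.
lam = torch.clamp(log_lambda.exp(), min=0.0, max=1.0)
loss = -(lam * 0.1)  # 0.1 stands in for a positive constraint violation
loss.backward()

print(log_lambda.grad)  # tensor(0.) -> the optimizer can no longer update lambda
```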

Conclusion

@panweihit Based on my debugging session, explanation 2 seems to be the most likely. If we look at the gradients, they indeed become 0.0 when log_lambda becomes bigger than 0.0. I therefore think it is wise to use log_lambda instead of the lambda defined in the paper (see Eq. 14 of Han et al. 2020). While doing this, we should document why our formulation differs from the paper.

Fixed in v0.5.0.