Policy weights and output becomes NaN after some iterations
charlesjsun opened this issue · 5 comments
Issue Overview
After some training iterations, the policy starts outputting NaN and all of the policy weights become NaN.
Package versions
- Latest commit of softlearning
- tensorflow version 2.2.0rc2
- tfp-nightly version 0.11.0.dev20200424
Preliminary debugging
I think the issue might be caused by one of three things: the Tanh bijector, a large learning rate (unlikely, since I'm using the default 3e-4), or the alpha training (most likely alpha).
Tanh
The policy sometimes outputs actions that are exactly 1 or -1 (I checked, and it never outputs values greater than 1 in magnitude). This could make the inverse evaluate to +inf or -inf, which may or may not be a problem depending on whether the inverse is ever used (edit: the inverse is called when calculating log_prob https://github.com/tensorflow/probability/blob/dd3a555ef37fc31c6ad04f3236942e3dbc0f4228/tensorflow_probability/python/distributions/transformed_distribution.py#L509). It could also be caused by tensorflow/probability#840, but that is apparently fixed by computing the action and the log prob together.
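A small NumPy sketch of why exact ±1 outputs are worrying: in float32, tanh rounds to exactly 1.0 for moderately large inputs, and the inverse (atanh) of 1.0 is infinite. The input value 20.0 is an arbitrary illustrative choice, not taken from the run above.

```python
import numpy as np

x = np.float32(20.0)        # a large pre-tanh action value (illustrative)
y = np.tanh(x)              # rounds to exactly 1.0 in float32
with np.errstate(divide="ignore"):
    x_back = np.arctanh(y)  # inverting tanh: atanh(1.0) is +inf
print(y, x_back)
```

This is why recovering the pre-tanh value through the bijector's inverse can blow up, whereas reusing the cached pre-tanh value avoids the inverse entirely.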
Alpha
This is most likely the issue. Here are the diagnostics I logged inside _do_training_repeats a few training steps before the policy failed:
diagnostics: OrderedDict([('Q_value-mean', 3.2876506), ('Q_loss-mean', 0.04909911), ('policy_loss-mean', -3.1032994), ('alpha', nan), ('alpha_loss-mean', -inf)])
diagnostics: OrderedDict([('Q_value-mean', 3.2876506), ('Q_loss-mean', 0.04909911), ('policy_loss-mean', -3.1032994), ('alpha', nan), ('alpha_loss-mean', -inf)])
diagnostics: OrderedDict([('Q_value-mean', 3.3472314), ('Q_loss-mean', nan), ('policy_loss-mean', nan), ('alpha', nan), ('alpha_loss-mean', nan)])
diagnostics: OrderedDict([('Q_value-mean', 3.3472314), ('Q_loss-mean', nan), ('policy_loss-mean', nan), ('alpha', nan), ('alpha_loss-mean', nan)])
diagnostics: OrderedDict([('Q_value-mean', 3.3472314), ('Q_loss-mean', nan), ('policy_loss-mean', nan), ('alpha', nan), ('alpha_loss-mean', nan)])
We can see that alpha was the first to fail, and the failure then propagated to the Q-functions and the policy. I also noticed that alpha sometimes became negative during training, whereas, from my understanding of automatic entropy adjustment, alpha should always be non-negative.
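When the logs are long, a tiny helper makes "which diagnostic failed first" mechanical. This is a hypothetical helper (`first_nonfinite_keys` is not part of softlearning); the two entries below are copied from the log lines above.

```python
import math
from collections import OrderedDict

def first_nonfinite_keys(history):
    """Return (index, keys) for the first diagnostics entry with a NaN/inf value."""
    for index, diagnostics in enumerate(history):
        bad = [key for key, value in diagnostics.items()
               if not math.isfinite(value)]
        if bad:
            return index, bad
    return None  # every logged value was finite

nan, inf = float("nan"), float("inf")
history = [
    OrderedDict([("Q_value-mean", 3.2876506), ("Q_loss-mean", 0.04909911),
                 ("policy_loss-mean", -3.1032994), ("alpha", nan),
                 ("alpha_loss-mean", -inf)]),
    OrderedDict([("Q_value-mean", 3.3472314), ("Q_loss-mean", nan),
                 ("policy_loss-mean", nan), ("alpha", nan),
                 ("alpha_loss-mean", nan)]),
]
print(first_nonfinite_keys(history))  # (0, ['alpha', 'alpha_loss-mean'])
```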
After digging through the SAC training step, I noticed these lines (softlearning/softlearning/algorithms/sac.py, lines 247 to 252 in 84d7589), which differ from the old TF1 implementation, which uses log_alpha instead (softlearning/softlearning/algorithms/sac.py, lines 210 to 211 in bd30e33).
The SAC paper uses alpha as the multiplier in the loss function instead of log_alpha, so was the old implementation an oversight? However, the old code did store log alpha as the trained variable.
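For reference, here are the two loss forms side by side, with $\bar{\mathcal{H}}$ standing for self._target_entropy and the stop-gradient on the bracketed term left implicit. The first line is the temperature objective as written in the SAC paper; the second is the log-alpha-multiplier variant. The gradients with respect to $\log\alpha$ are a short derivation added for illustration, not taken from either implementation:

```latex
\begin{aligned}
J(\alpha) &= \mathbb{E}_{a_t \sim \pi_t}\!\left[-\alpha \log \pi_t(a_t \mid s_t) - \alpha \bar{\mathcal{H}}\right]
           = -\alpha\, \mathbb{E}\!\left[\log \pi_t(a_t \mid s_t) + \bar{\mathcal{H}}\right], \\
J_{\log}(\log\alpha) &= -\log\alpha \; \mathbb{E}\!\left[\log \pi_t(a_t \mid s_t) + \bar{\mathcal{H}}\right], \\
\frac{\partial J}{\partial \log\alpha} &= -\alpha\, \mathbb{E}\!\left[\log \pi_t(a_t \mid s_t) + \bar{\mathcal{H}}\right],
\qquad
\frac{\partial J_{\log}}{\partial \log\alpha} = -\mathbb{E}\!\left[\log \pi_t(a_t \mid s_t) + \bar{\mathcal{H}}\right].
\end{aligned}
```

Since $\alpha = e^{\log\alpha} > 0$, both forms push $\log\alpha$ in the same direction and differ only in step size. If $\alpha$ itself is the trained variable, however, its gradient is $-\mathbb{E}[\log \pi_t + \bar{\mathcal{H}}]$ and nothing keeps $\alpha$ above zero.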
Or the issue might be something else: for example, what caused the alpha loss to become -inf in the first place? Perhaps log_pis became -inf, which would mean actions_and_log_probs was the problem? I don't know enough about the implementation to say for sure.
Let me know if you want more logs, the program's output, or anything else. This was run on my own environment in a fork of this repo:
Hey @externalhardrive, thanks a lot for the very thorough report! This is very interesting.
Tanh
This is actually a great guess at the cause, and in general I'd consider this the most NaN-prone part of our code. However, in its current state I believe it's unlikely to be the culprit: the log prob computation should be numerically very stable because (as you mention) we compute the log probs together with the actual actions. Effectively, the numerical issues caused by tanh are bypassed by the caching. It's not impossible that this is the cause, but I doubt it. Also, if it were, I don't really know how this could be made more numerically stable without introducing some hacks.
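To illustrate why having the pre-tanh value matters: the tanh log-det-Jacobian term log(1 - tanh(x)^2) underflows to log(0) = -inf when computed from the squashed value, but the identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)) stays finite when computed from x directly. As far as I know this is the form TFP's Tanh bijector uses, but treat that as an assumption; the function names below are illustrative.

```python
import numpy as np

def naive_log_det(x):
    """log(1 - tanh(x)^2) computed from the squashed value tanh(x)."""
    y = np.tanh(x)
    with np.errstate(divide="ignore"):
        return np.log(1.0 - y ** 2)

def stable_log_det(x):
    """Same quantity from the pre-tanh value: 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))

x = np.array([0.0, 5.0, 20.0], dtype=np.float32)
naive = naive_log_det(x)    # last entry is -inf: tanh(20) rounds to 1.0
stable = stable_log_det(x)  # all finite; about -38.61 at x = 20
print(naive)
print(stable)
```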
Alpha
Good observation about log vs. non-log alpha. Neither the use of log alpha in the old code nor the switch to non-log alpha was an oversight; both were conscious choices. We initially decided to use the log version, mainly because of the generally nicer numerical properties of log values. However, this was a rather confusing implementation detail (see e.g. #37), the paper does not mention anything about it, and in our testing there was no difference at all between the log and non-log versions, so we decided to switch to the direct (non-log) alpha version. Again, I think it's possible, but very unlikely, that alpha itself is the issue. My guess is that the NaNs first pop up in the log probabilities used in the alpha loss.
When exactly do you see these NaNs? Do they always happen very early in training? I have recently seen some cases where some of the observations change wildly at the very beginning of training. One good example is the quadruped env in dm_control, where the IMU observations sporadically jump up to ~300 from typical values of <10. I believe (though I haven't verified this yet) that this pushes the data so far from the previously seen data distribution that the action log probs become NaNs. One way to check whether this happens in your case too would be the following:
- Wrap these lines in SimpleSampler.sample() in a try-except block like:
try:
    next_observation, reward, terminal, info = self.environment.step(
        action)
except Exception as e:
    from pprint import pprint
    import numpy as np
    import tree
    # Per-dimension statistics of every observation in the replay pool.
    all_observations = self.pool.last_n_batch(self.pool.size)['observations']
    print("min:")
    pprint(tree.map_structure(
        lambda x: np.min(x, axis=0), all_observations))
    print()
    print("max:")
    pprint(tree.map_structure(
        lambda x: np.max(x, axis=0), all_observations))
    print()
    print("mean:")
    pprint(tree.map_structure(
        lambda x: np.mean(x, axis=0), all_observations))
    print()
    print("std:")
    pprint(tree.map_structure(
        lambda x: np.std(x, axis=0), all_observations))
    print()
    breakpoint()
    pass
- I'm pretty sure the code will stop in the except block. Check the values: are there any surprisingly high-magnitude values?
You can also check the values that introduce the nans directly by adding a breakpoint in SAC._update_alpha and doing something like this (I didn't test this):
import tensorflow as tf
import tree
from pprint import pprint

# Assumes alpha_loss still has a leading batch dimension (per-sample losses).
if not all(tree.flatten(tree.map_structure(
        lambda x: tf.reduce_all(tf.math.is_finite(x)),
        (alpha_loss, alpha_gradients),
))):
    alpha_loss_nan_index = tf.where(~tf.math.is_finite(alpha_loss))
    nan_causing_observation = tree.map_structure(
        lambda x: tf.gather(x, alpha_loss_nan_index[:, 0]), observations)
    pprint(nan_causing_observation)
Don't forget to run with --debug=True --run-eagerly=True so that the debugger behaves nicely and TF graph mode is disabled!
I hope this is the issue. If so, one easy (if possibly temporary) fix would be to scale and squash the problem-causing observations, for example problematic_observation = 50.0 * np.tanh(problematic_observation / 200.0), or something like that. If, on the other hand, this is not the case, I can dive a bit deeper into this by running your environment.
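A quick sketch of what that squashing does, using the same constants (the `scale`/`width` parameter names are hypothetical). The output is bounded to (-50, 50), an outlier of 300 maps to about 45.26, and near zero the map is roughly linear with slope 50 / 200 = 0.25, so typical values are rescaled as well, not only the outliers.

```python
import numpy as np

def squash(x, scale=50.0, width=200.0):
    """Smoothly bound x to (-scale, scale); approximately (scale / width) * x near 0."""
    return scale * np.tanh(x / width)

observations = np.array([-300.0, -5.0, 0.0, 5.0, 300.0])
squashed = squash(observations)
print(squashed)
```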
There doesn't appear to be any problem with the observations. All of them are within the expected range: pixel values are between 0 and 255 (uint8), and velocity values are between -1 and 1 (float32). I tried the untested code you gave above, but it doesn't work.
However, I think the problem may have been the alpha training after all. Instead of training self._alpha directly, I stored the trainable variable as self._log_alpha and replaced all instances of self._alpha with tf.exp(self._log_alpha), e.g.:
# __init__
self._log_alpha = tf.Variable(0.0, name='log_alpha')
# _update_alpha
alpha_losses = -1.0 * (tf.exp(self._log_alpha) * tf.stop_gradient(log_pis + self._target_entropy))
After making this change, I no longer encounter the issue (I'm still training the experiment, so errors may still arise later). Looking at the logs, all the experiments that failed had periods where alpha became negative. This could have disrupted the policy training, leading the policy to output +inf or -inf, which then led to the problem above. Storing _log_alpha instead constrains alpha to always be positive.
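A toy, framework-free sketch of that positivity argument, under loud assumptions: plain SGD instead of whatever optimizer the real code uses, and a constant `signal` standing in for mean(log_pis + target_entropy). The function names are hypothetical. With a persistently negative signal (entropy above target), stepping alpha directly walks it linearly past zero, while stepping log_alpha only shrinks alpha towards zero.

```python
import math

def train_alpha_directly(signal, steps, lr, alpha=1.0):
    """Plain SGD on alpha; loss = -alpha * signal, so d(loss)/d(alpha) = -signal."""
    for _ in range(steps):
        alpha -= lr * (-signal)
    return alpha

def train_log_alpha(signal, steps, lr, log_alpha=0.0):
    """Plain SGD on log_alpha; loss = -exp(log_alpha) * signal."""
    for _ in range(steps):
        grad = -math.exp(log_alpha) * signal  # chain rule through exp
        log_alpha -= lr * grad
    return math.exp(log_alpha)  # positive unless exp underflows to 0.0

signal = -2.0  # entropy above target: both updates push alpha down
direct = train_alpha_directly(signal, steps=1000, lr=3e-3)
via_log = train_log_alpha(signal, steps=1000, lr=3e-3)
print(direct)   # about -5: alpha has gone negative
print(via_log)  # small but still positive
```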
Looking at the graph of the policy entropy during training, I saw this when training _alpha directly:
and this after switching to log_alpha:
Clearly something is off in the first graph (the entropy oscillates too much and even drops to around -50 twice), whereas the second graph looks as expected.
I read issue #37, and perhaps it doesn't make much difference whether log_alpha or alpha is used as the multiplier when training alpha itself, as long as the trained variable is log_alpha, since the policy loss uses alpha either way.
Also, I ran this with another environment a week ago and had no problems, but that may be because the other environment had denser rewards, so alpha never dropped too low.
As an aside, I also tried a hack by creating a ClippedTanh bijector (https://github.com/externalhardrive/mobilemanipulation-tf2/blob/master/softlearning/distributions/bijectors/clipped_tanh.py) that clips the output so it never outputs 1 or -1. I'm not sure if I created the bijector correctly, but this did not fix the issue.
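This is not the linked ClippedTanh bijector, just a NumPy sketch of the clipping idea with a hypothetical margin `EPS`. Clipping keeps the inverse finite, but once the clip is active the derivative is exactly zero, so no gradient flows back through saturated actions.

```python
import numpy as np

EPS = 1e-6  # hypothetical clipping margin

def clipped_tanh(x):
    """tanh with its output kept strictly inside (-1, 1)."""
    return np.clip(np.tanh(x), -1.0 + EPS, 1.0 - EPS)

def numerical_grad(f, x, h=1e-5):
    """Central finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

inverse_at_20 = np.arctanh(clipped_tanh(20.0))    # finite, about 7.25
grad_at_0 = numerical_grad(clipped_tanh, 0.0)     # about 1.0: unclipped region
grad_at_20 = numerical_grad(clipped_tanh, 20.0)   # 0.0: clip is active
print(inverse_at_20, grad_at_0, grad_at_20)
```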
I'm not sure if my hypothesis is correct, but I'm going to keep running my experiments as usual using my changes and see if any problems come up. Here are my changes if you want to look at them: https://github.com/externalhardrive/mobilemanipulation-tf2/blob/eb16f10844efaebd8c645c85508a46ca61436261/softlearning/algorithms/sac.py
Thanks!
Also, reading this again, shouldn't there be a tf.nn.compute_average_loss call before computing and applying the gradients (softlearning/softlearning/algorithms/sac.py, lines 214 to 233 in 84d7589)?
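A framework-free illustration of why the reduction matters: tf.GradientTape.gradient, given a non-scalar target, returns the gradient of the sum of its elements, so passing unreduced per-sample losses scales the gradient by the batch size compared with averaging first (tf.nn.compute_average_loss divides the summed per-example loss by the global batch size). Whether the referenced lines actually pass unreduced losses is an assumption here, and the batch size, target entropy, and log-prob distribution are made up.

```python
import numpy as np

def alpha_grad(log_pis, target_entropy, reduce):
    """Gradient w.r.t. alpha of the per-sample losses -alpha * (log_pi + target_entropy)."""
    per_example_grads = -(log_pis + target_entropy)
    return reduce(per_example_grads)

rng = np.random.default_rng(0)
log_pis = rng.normal(loc=-3.0, scale=1.0, size=256)  # made-up batch of log probs
target_entropy = 2.0                                  # made-up target

summed = alpha_grad(log_pis, target_entropy, np.sum)     # unreduced losses on the tape
averaged = alpha_grad(log_pis, target_entropy, np.mean)  # averaged loss
print(summed / averaged)  # equals the batch size, 256
```

With an adaptive optimizer such as Adam, a constant rescaling of the gradient is largely normalized away, so the practical impact depends on the optimizer in use.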
Wow, this is really interesting! Thanks for digging into this. I'll try to replicate this on my end, and if it indeed turns out to be the difference between log alpha and alpha, we should probably move back to using the log alpha formulation.
Re clipped tanh bijector: I think the implementation looks right. But ultimately I think it's better to fix the underlying issue than to patch it at this level. The clipping is particularly annoying because you lose the gradient signal. If using log alpha works, I'd recommend using that and removing the clipping.
Re average loss computation: I think you're actually right here. I'll open a PR soon to make all these losses consistent across the algorithm. Thanks for spotting this!
Looking into this again, the alpha should definitely be constrained to be positive, which I'm completely ignoring in the current implementation. I'll push a fix shortly.