rwth-i6/returnn

Ignore a single broken gradient

JackTemaki opened this issue · 2 comments

In my current language model training I sometimes get "nan" gradients, which break the training. Surprisingly, just restarting the training from the last checkpoint is often enough uncertainty to resume training.

Here people discussed something like:

        # Check every parameter gradient and stop at the first NaN or Inf entry.
        valid_gradients = True
        for name, param in self.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                valid_gradients = False
                break
        if not valid_gradients:
            # Discard the broken gradients instead of applying them.
            print("detected inf or nan values in gradients. not updating model parameters")
            self.zero_grad()

I think it would be a good idea to have this as a configurable option for the updater. Preferably with a "limit", so that it still crashes after e.g. 5 broken updates.
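A minimal sketch of what such a limit could look like, independent of any framework. The class name `NonFiniteGradGuard` and the parameter `max_consecutive_skips` are hypothetical and not existing RETURNN options. It also assumes that the limit counts consecutive broken updates and resets after a good one, which is only one possible reading of "limit":

```python
import math


class NonFiniteGradGuard:
    """Hypothetical helper: tolerate a few broken updates, then fail loudly."""

    def __init__(self, max_consecutive_skips=5):
        self.max_consecutive_skips = max_consecutive_skips
        self.consecutive_skips = 0

    def should_update(self, grad_norm):
        """Return True if the update should be applied, False if it should be skipped.

        Raises RuntimeError once the number of consecutive broken updates
        reaches the limit.
        """
        if math.isfinite(grad_norm):
            # A healthy step resets the counter (assumption: the limit is on consecutive skips).
            self.consecutive_skips = 0
            return True
        self.consecutive_skips += 1
        if self.consecutive_skips >= self.max_consecutive_skips:
            raise RuntimeError(
                f"{self.consecutive_skips} consecutive updates with non-finite gradients"
            )
        return False
```

The caller would pass the global gradient norm of the current step and only run the optimizer step when `should_update` returns True.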

> I sometimes get "nan" gradients, which break the training.

You mean that after that, the model parameters themselves become nan, i.e. the model is broken?

> Surprisingly, just restarting the training from the last checkpoint is often enough uncertainty to resume training.

Uncertainty? You mean restarting from the last checkpoint solves the issue, i.e. getting nan is non-deterministic and rare, and after such a restart you are usually lucky and don't get nan anymore, or only much later?

We could also implement such automatic restart logic. That would be a different approach from the one you suggest afterwards, which is to always check for non-finite grads.

I'm not sure which approach would be better. It probably also depends on how often you get this. (E.g. I personally have never gotten this problem so far.)

I do not want such automatic restart logic, because it wastes computation. For testing, I have now implemented here that the update is skipped when the result of the gradient clipping is NaN or Inf. I do this for clip_grad_norm_ right now, and I am not sure how clip_grad_value_ behaves with Inf/NaN values (maybe it already corrects them automatically).
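Roughly, the idea can be sketched in plain PyTorch as follows. This is an illustration only, not the actual RETURNN change: the function name `clipped_step` is made up here. It relies on `torch.nn.utils.clip_grad_norm_` returning the total gradient norm, which is NaN or Inf as soon as any single gradient entry is:

```python
import torch


def clipped_step(model, optimizer, max_norm=1.0):
    """Clip gradients by global norm; skip the update if that norm is not finite.

    Returns True if the optimizer step was applied, False if it was skipped.
    """
    # clip_grad_norm_ returns the total norm computed before clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        # Broken gradients: drop them and leave the parameters untouched.
        optimizer.zero_grad()
        return False
    optimizer.step()
    optimizer.zero_grad()
    return True
```

The check costs almost nothing on top of the clipping, since the norm is computed for the clipping anyway.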

This specific problem appeared for me only for LSTM-LM training so far, not anywhere else (ASR, TTS).

In my current training, after 2 full epochs, I now get a skip every 40k update steps. Before that, it did not happen at all.