sanchit-gandhi/seq2seq-speech

Negative Losses in CTC Training

sanchit-gandhi opened this issue · 2 comments

Training the baseline CTC model on the Common Voice 9 (CV9) dataset, we observe that the training loss drops below zero after ~1.5k train steps: https://wandb.ai/sanchit-gandhi/commonvoice_9_0/runs/y593pwm4?workspace=user-sanchit-gandhi. The CTC loss is a negative log-likelihood and should therefore be nonnegative.

  • CV9 tokenizer: working as expected. Tested within the training script (both tokenising and decoding), and checked that all attributes are set correctly. Furthermore, the target strings in the wandb prediction logs are identical to the transcribed text in the training data -> the tokenizer is correctly tokenising and decoding.
  • Logits test: for the randomly initialised (unscanned) model, the PT-Flax equivalence test passes on CV9 for both the logits and losses. Loss is nonnegative. https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/tests/check_flax_ctc_cv9.py
  • Trained Flax model: using the 50k train steps checkpoint, the loss is negative.
  • We expect all log probabilities in the CTC loss function to be non-positive: the probabilities should lie in the range 0 to 1, and so their logs should have a maximum value of log(1) = 0 (a minimal sketch of this check follows the list).
  • The CTC loss function can be divided into three stages:
  1. Initialisation of log prob arrays
  2. Looping over the CTC Markov chain process
  3. Extraction of per-sequence loss and CTC reduction ("mean" reduction)
  • Our attention is focused on the CTC Markov chain process:
    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # emit-to-phi epsilon transition, except if the next label is repetition
        prev_phi_orig = prev_phi
        prev_phi = prev_phi.at[:, 1:].set(
            jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))

        logprob_emit, logprob_phi, pad = x

        # phi-to-emit transition
        next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit,
                                  prev_emit + logprob_emit)
        # self-loop transition
        next_phi = prev_phi + logprob_phi
        # emit-to-phi blank transition only when the next label is repetition
        next_phi = next_phi.at[:, 1:].set(
            jnp.logaddexp(next_phi[:, 1:],
                          prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)))

        pad = pad.reshape((batchsize, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)
  • When we bypass the lax backend in the CTC loss function and add some print statements, we see that positive values creep into the log probabilities by the 4th iteration of the CTC loop. They first occur in the 'phi-to-emit transition' in the 'next_emit' log probabilities, and then cascade to all subsequent log probs (a sketch of this debugging approach also follows the list). https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/tests/check_negative_loss_ctc.ipynb
  • Changing the value of the log_epsilon hyperparameter to a more negative value does not alter this behaviour.
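
A minimal, self-contained sketch of the non-positivity expectation above (not code from the repo; shapes and variable names are illustrative only):

    import jax
    import jax.numpy as jnp

    # Illustrative shapes only: (batch, frames, vocab) logits for a CTC head.
    logits = jax.random.normal(jax.random.PRNGKey(0), (2, 50, 32))
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    # Log probabilities of a valid distribution are non-positive, with max log(1) = 0.
    assert jnp.all(log_probs <= 0.0)
    print(float(log_probs.max()))  # <= 0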
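
And a minimal sketch of the debugging approach (a toy recursion standing in for loop_body, not the actual notebook code): replacing the compiled lax.scan with a plain Python loop exposes every intermediate carry, so the first step at which a positive value appears can be located.

    import jax.numpy as jnp
    from jax import lax

    # Toy stand-in for loop_body: accumulate "probability mass" in log space.
    def toy_body(carry, x):
        carry = jnp.logaddexp(carry, x)
        return carry, carry

    xs = jnp.log(jnp.array([0.2, 0.5, 0.4, 0.3]))  # per-step log probabilities
    init = jnp.log(jnp.array(0.1))

    # Compiled path: lax.scan hides all intermediate values.
    final_carry, _ = lax.scan(toy_body, init, xs)

    # Debug path: the same recursion as a Python loop, so each step can be inspected.
    carry = init
    for t, x in enumerate(xs):
        carry, _ = toy_body(carry, x)
        print(f"step {t}: carry = {float(carry):.4f}")
        if carry > 0:
            # Accumulating a total mass greater than 1 pushes the log value above 0.
            print(f"positive value first appears at step {t}")
            break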

Found it! When we define the model:

model = FlaxWav2Vec2ForCTC.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    dtype=dtype,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)

We need to set the config attribute vocab_size to the number of elements in the tokenizer's vocabulary. Otherwise, it defaults to the vocab_size of the Wav2Vec2-large-lv60 checkpoint, which is the vocab size of the default Wav2Vec2 tokenizer built on LibriSpeech ASR. If the actual tokenizer's vocab size is greater than that of the default Wav2Vec2 tokenizer, the logits span only a partial sub-space of the full tokenizer vocabulary. These ill-defined logits then (likely) give rise to an ill-defined CTC loss.
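
A quick way to surface the mismatch (a sketch using the variable names from the snippet above, which are assumed to be in scope):

    # The CTC head emits config.vocab_size logits per frame; the tokenizer defines the label space.
    print(model.config.vocab_size)  # inherited from the pretrained checkpoint unless overridden
    print(len(tokenizer))           # size of the CV9 tokenizer's vocabulary

    if model.config.vocab_size < len(tokenizer):
        print("Logits cover only part of the tokenizer vocabulary -> CTC loss is ill-defined.")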

Great catch! Because of this, the tokenizer maps too many letters to token ids, which surely messes up the CTC loss.
You're exactly right, we should add a vocab_size=len(tokenizer) here.
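
A sketch of that fix, assuming the same config, tokenizer and model_args variables as in the snippet above: set the config's vocab_size from the tokenizer before instantiating the model, so the CTC head's logits cover the full label space.

    config.vocab_size = len(tokenizer)  # match the CTC head to the CV9 tokenizer

    model = FlaxWav2Vec2ForCTC.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=dtype,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )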