Wav2Vec2Pretrain (HFTransformersInterface implementation) samples padded values for mask_time_indices and negative_sample_indices
porfirythelaw opened this issue · 3 comments
Describe the bug
I've been using the SpeechBrain Wav2Vec2 pretraining recipe (with HF integration) on my own data, and noticed that I get significantly different metrics with the same model on the validation dataset depending on the amount of padding in the batch. My hypothesis was that padding is somehow not ignored during the index sampling process, and I think this is in fact what is happening.
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 261 to 265 in eba7714
As you can see, no attention mask is provided in this call, so masked indices are drawn from padded values as well.
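To make this concrete, here is a small sketch (batch shapes and padding lengths are invented for the illustration) showing that _compute_mask_indices without an attention mask places masked frames inside the padded region, while a feature-level attention mask keeps them out:

import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

batch_size, sequence_length = 2, 100
# pretend the second utterance is 60% padding: only frames [0, 40) are valid
attn = torch.zeros(batch_size, sequence_length, dtype=torch.long)
attn[0, :] = 1
attn[1, :40] = 1

# Without an attention mask (current recipe behavior): masked spans can
# land anywhere, including the padded tail.
mask = _compute_mask_indices(
    (batch_size, sequence_length), mask_prob=0.65, mask_length=10
)
print(mask[1, 40:].sum())  # almost surely > 0: padded frames get masked

# With the attention mask, masked spans are confined to valid frames.
mask = _compute_mask_indices(
    (batch_size, sequence_length), mask_prob=0.65, mask_length=10,
    attention_mask=attn,
)
print(mask[1, 40:].sum())  # 0: nothing masked in the padded region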
The same goes for the negative sample indices, which are drawn from the whole sequence:
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 278 to 286 in eba7714
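Similarly for the negatives: passing an all-ones mask_time_indices (as the recipe does via full_sentence_indices) makes _sample_negative_indices draw negatives uniformly over the whole padded sequence. A sketch with the same invented shapes as above:

import numpy as np
from transformers.models.wav2vec2.modeling_wav2vec2 import _sample_negative_indices

batch_size, sequence_length = 2, 100  # second utterance valid only in [0, 40)
full_sentence_indices = np.ones((batch_size, sequence_length))
neg = _sample_negative_indices(
    (batch_size, sequence_length),
    num_negatives=100,
    mask_time_indices=full_sentence_indices,
)
# The returned indices are offset into the flattened batch, so take them
# modulo sequence_length to recover per-utterance frame positions.
print((neg[1] % sequence_length >= 40).sum())  # negatives pointing at padding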
You do provide an attention mask in this call to the model:
speechbrain/speechbrain/lobes/models/huggingface_transformers/wav2vec2.py
Lines 289 to 296 in eba7714
However, if you check the Hugging Face source code, the attention mask does not affect the loss calculation; it only affects the encoder self-attention.
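Paraphrasing the relevant part of Wav2Vec2ForPreTraining.forward (the exact code lives in transformers' modeling_wav2vec2.py): the cross-entropy target for the contrastive loss is built from mask_time_indices alone, so the attention mask never enters the loss. A runnable stand-in with random logits:

import torch

batch_size, seq_len, num_candidates = 2, 100, 101  # 1 positive + 100 negatives
logits = torch.randn(num_candidates, batch_size, seq_len)  # stand-in contrastive logits
mask_time_indices = torch.zeros(batch_size, seq_len, dtype=torch.bool)
mask_time_indices[1, 90:] = True  # a masked span sitting entirely in padding

# As in transformers: positions outside mask_time_indices get target -100
# (ignored by cross_entropy); masked positions get target 0 (the positive).
flat_logits = logits.transpose(0, 2).reshape(-1, num_candidates)
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
loss = torch.nn.functional.cross_entropy(
    flat_logits.float(), target, reduction="sum"
)
print(loss)  # the padded-only span contributes fully; no attention mask involved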
I'm not sure if this behavior was intended or not.
Expected behaviour
Padded values should not influence the model loss/metrics.
To Reproduce
No response
Environment Details
SpeechBrain v0.5.16
Relevant Log Output
No response
Additional Context
No response
Hey @TParcollet, could you please have a look?
My local fix is something like this (using features_padding_mask):
# Assumed context: this replaces the index-sampling part of
# Wav2Vec2Pretrain.forward(); batch_size, sequence_length, wav and
# wav_lens are already defined earlier in forward().
# Imports (at module level in wav2vec2.py):
import torch
import transformers
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    make_padding_masks,
)

# Build a sample-level padding mask, then reduce it to the feature
# (frame) level so it lines up with the extracted feature sequence.
padding_mask = make_padding_masks(wav, wav_len=wav_lens)
features_padding_mask = self.model._get_feature_vector_attention_mask(
    sequence_length, padding_mask, add_adapter=False
)

# 1. Compute the indices that will be masked, restricted to the
# non-padded frames via the attention mask.
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=self.mask_prob,
    mask_length=self.mask_length,
    attention_mask=features_padding_mask,
)
torch_mask_time_indices = torch.tensor(
    mask_time_indices, device=wav.device, dtype=torch.long,
)

# 2. Sample the negatives from the entire non-padded sequence.
# Fairseq samples only from the masked indices, but that only works for
# long sentences. For more versatility, we sample over the whole valid
# part of the sequence (previously: mask_time_indices=full_sentence_indices,
# an all-ones array that also covered the padding).
negative_sample_indices = torch.tensor(
    transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices(
        (batch_size, sequence_length),
        num_negatives=self.config.num_negatives,
        mask_time_indices=features_padding_mask.detach().cpu().numpy(),
    ),
    device=wav.device,
    dtype=torch.long,
)
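With that in place, a hypothetical sanity check (names as in the snippet above) can confirm that neither the masked frames nor the negatives touch the padded region:

valid = features_padding_mask.bool()
# No masked frame may lie in the padded region.
assert not torch_mask_time_indices.bool()[~valid].any()
# Negatives index into the flattened (batch, frames) sequence, so check
# them against the flattened validity mask.
assert valid.flatten()[negative_sample_indices.flatten()].all()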
That's quite a late answer, but yes, it certainly is true. The reason is that we rely on HF functions here, and back when we wrote this code, I believe there was no alternative. @porfirythelaw, could you propose a PR with this fix? I will test it.
Many thanks.