CTC tokenizer returns <unk> tokens with `do_lower_case=True`
sanchit-gandhi opened this issue · 1 comment
This issue explores the behaviour of the example CV9 tokenizer created by the `get_ctc_tokenizer.py` file and saved at https://huggingface.co/patrickvonplaten/wav2vec2_ctc_cv9_tokenizer. With the `do_lower_case` argument set to `True`, we observe that the tokenizer returns the `<unk>` token for every alphabetic character in the input:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=True)
input_str = "The cat sat on the mat."
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)
print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)
print("<unk> token id: ", tokenizer.unk_token_id)
print("Type: ", type(tokenizer))
Output:
Input str: The cat sat on the mat.
Input ids: [3, 3, 3, 25, 3, 3, 3, 25, 3, 3, 3, 25, 3, 3, 25, 3, 3, 3, 25, 3, 3, 3, 5]
Decoded str: <unk> <unk> <unk> <unk> <unk> <unk>.
<unk> token id: 3
Type: <class 'transformers.models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizer'>
Every alphabetic input character is incorrectly mapped to the unknown token. What's interesting here is that the punctuation (word-spaces and the full-stop) is correctly tokenized and decoded: these are the only two inputs unaffected by the `do_lower_case` operation.
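A quick way to see why only these two survive: spaces and the full-stop are caseless, so any case-mapping leaves them untouched (a one-line check, not part of the original repro):

print([c for c in "The cat sat on the mat." if c.upper() == c.lower()])
# [' ', ' ', ' ', ' ', ' ', '.']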
If we now re-run this with the `do_lower_case` argument set to `False`, we observe that the decoded word string is correct (barring the first capital `T`):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=False)
input_str = "The cat sat on the mat."
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)
print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)
Output:
Input str: The cat sat on the mat.
Input ids: [3, 9, 7, 25, 17, 35, 18, 25, 8, 35, 18, 25, 16, 6, 25, 18, 9, 7, 25, 37, 35, 18, 5]
Decoded str: <unk>he cat sat on the mat.
Thus, the issue must lie with the `do_lower_case` attribute of the `Wav2Vec2CTCTokenizer`. Indeed, if we inspect the modelling code for this tokenizer, we observe that setting `do_lower_case` to `True` first runs the `upper` method on the input string, converting it to upper-case:
https://github.com/huggingface/transformers/blob/6d80c92c77593dc674052b5a46431902e6adfe88/src/transformers/models/wav2vec2/tokenization_wav2vec2.py#L238
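For illustration, here is a minimal paraphrase of that pre-processing step (a sketch of the behaviour described above, not the exact transformers source; the function name is ours):

def prepare_text(text: str, do_lower_case: bool) -> str:
    # Despite its name, the flag upper-cases the input, so a tokenizer
    # built on a lower-case vocabulary sees only OOV characters.
    if do_lower_case:
        return text.upper()
    return text

print(prepare_text("The cat sat on the mat.", do_lower_case=True))
# THE CAT SAT ON THE MAT.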
Since the tokenizer was built on a lower-case vocabulary, each of these upper-case characters is OOV, and so is assigned to the `<unk>` token.
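To make the mapping concrete, here is a toy character-to-id lookup (the ids below are read off the example outputs above; the real vocab lives in the linked repo):

# character-level vocab fragment, ids taken from the outputs above
vocab = {"<unk>": 3, "t": 18, "h": 9, "e": 7, " ": 25, ".": 5}

def char_to_id(ch: str) -> int:
    # any character outside the vocab falls back to <unk>
    return vocab.get(ch, vocab["<unk>"])

print([char_to_id(c) for c in "The"])  # [3, 9, 7] -- "T" is OOV
print([char_to_id(c) for c in "THE"])  # [3, 3, 3] -- every upper-case char is OOV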
What we should do instead is set the `do_lower_case` argument to `False`, so that the upper-casing operation is not performed on the input string, and simply lower-case the input string ourselves before tokenizing:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=False)
input_str = "The cat sat on the mat.".lower()
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)
print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)
Output:
Input str: the cat sat on the mat.
Input ids: [18, 9, 7, 25, 17, 35, 18, 25, 8, 35, 18, 25, 16, 6, 25, 18, 9, 7, 25, 37, 35, 18, 5]
Decoded str: the cat sat on the mat.
This simply required modifying the lines that instantiate the tokenizer from pretrained in the CTC training script:
seq2seq-speech/run_flax_speech_recognition_ctc.py, lines 809 to 816 at de22a44
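The linked lines are not reproduced here, but the change would look roughly like this (a sketch only; the preprocessing function name and batch keys are assumptions, not the script's actual code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "patrickvonplaten/wav2vec2_ctc_cv9_tokenizer",
    do_lower_case=False,  # avoid the upper() call inside the tokenizer
)

def prepare_labels(batch):
    # hypothetical preprocessing step: lower-case the transcription
    # ourselves before encoding it to label ids
    batch["labels"] = tokenizer(batch["text"].lower()).input_ids
    return batch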
You're 100% right! The naming is a bit misleading here: `do_lower_case` should only be set to `True` for the tokenizer if the tokenizer has upper-case letters in the vocab.
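One quick way to check this for a given tokenizer (a sketch, using the standard `get_vocab` method):

# True only if the vocabulary contains upper-case tokens, i.e. only
# then does do_lower_case=True make sense for this tokenizer
has_upper_case = any(token.isupper() for token in tokenizer.get_vocab())
print("Vocab contains upper-case letters:", has_upper_case)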