sanchit-gandhi/seq2seq-speech

CTC tokenizer returns <unk> tokens with `do_lower_case=True`

sanchit-gandhi opened this issue · 1 comment

This issue explores the behaviour of the example CV9 tokenizer created by the get_ctc_tokenizer.py script and saved at https://huggingface.co/patrickvonplaten/wav2vec2_ctc_cv9_tokenizer.

With the do_lower_case argument set to True, we observe that the tokenizer returns the <unk> token for every alphabetic input character:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=True)

input_str = "The cat sat on the mat."
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)

print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)
print("<unk> token id: ", tokenizer.unk_token_id)
print("Type: ", type(tokenizer))

Output:

Input str:  The cat sat on the mat.
Input ids:  [3, 3, 3, 25, 3, 3, 3, 25, 3, 3, 3, 25, 3, 3, 25, 3, 3, 3, 25, 3, 3, 3, 5]
Decoded str:  <unk> <unk> <unk> <unk> <unk> <unk>.
<unk> token id:  3
Type:  <class 'transformers.models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizer'>

Every alphabetic input character is incorrectly mapped to the unknown token. Interestingly, however, the punctuation (word-spaces and the full-stop) is correctly tokenized and decoded: these are the only inputs unaffected by the do_lower_case operation.
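As a quick sanity check, we can inspect the tokenizer's vocabulary via get_vocab() to confirm that it contains the lower-case characters but not their upper-case counterparts, while punctuation such as the full-stop is present either way:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer")
vocab = tokenizer.get_vocab()

# lower-case letters should be present, their upper-case counterparts absent,
# and case-insensitive punctuation such as "." present
print("'a' in vocab:", "a" in vocab)
print("'A' in vocab:", "A" in vocab)
print("'.' in vocab:", "." in vocab)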

If we now re-run the tokenization with the do_lower_case argument set to False, we observe that the decoded string is correct (barring the first capital T):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=False)

input_str = "The cat sat on the mat."
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)

print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)

Output:

Input str:  The cat sat on the mat.
Input ids:  [3, 9, 7, 25, 17, 35, 18, 25, 8, 35, 18, 25, 16, 6, 25, 18, 9, 7, 25, 37, 35, 18, 5]
Decoded str:  <unk>he cat sat on the mat.

Thus, the issue must lie with the do_lower_case attribute of the Wav2Vec2CTCTokenizer. Indeed, if we inspect the tokenization code, we see that setting do_lower_case to True first runs the upper method on the input string, converting it to upper-case:
https://github.com/huggingface/transformers/blob/6d80c92c77593dc674052b5a46431902e6adfe88/src/transformers/models/wav2vec2/tokenization_wav2vec2.py#L238
Since the tokenizer was built on a lower-case vocabulary, each of these upper-case characters is out-of-vocabulary (OOV) and is therefore assigned to the <unk> token.
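In sketch form, the linked logic is roughly the following (paraphrased, not the verbatim transformers source):

# despite its name, do_lower_case UPPER-cases the input before look-up
def _tokenize(text, do_lower_case=True, word_delimiter_token="|"):
    if do_lower_case:
        text = text.upper()  # "the cat" -> "THE CAT"
    # spaces become the word delimiter token; every other character is then
    # looked up in the vocab, and upper-case characters are missing -> <unk>
    return list(text.replace(" ", word_delimiter_token))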

The fix is to set the do_lower_case argument to False, so that the upper-casing operation is not applied to the input string, and to simply lower-case the input string ourselves before tokenizing:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2_ctc_cv9_tokenizer", do_lower_case=False)

input_str = "The cat sat on the mat.".lower()
input_ids = tokenizer(input_str).input_ids
decoded_str = tokenizer.decode(input_ids)

print("Input str: ", input_str)
print("Input ids: ", input_ids)
print("Decoded str: ", decoded_str)
Input str:  the cat sat on the mat.
Input ids:  [18, 9, 7, 25, 17, 35, 18, 25, 8, 35, 18, 25, 16, 6, 25, 18, 9, 7, 25, 37, 35, 18, 5]
Decoded str:  the cat sat on the mat.

This simply required modifying the lines in the CTC training script that instantiate the tokenizer from pretrained:

do_lower_case = data_args.do_lower_case
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
    do_lower_case=do_lower_case,
)
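The lower-casing itself then happens in the data preprocessing rather than inside the tokenizer. A minimal sketch of what that looks like, using illustrative names (prepare_dataset, the "sentence" column) rather than the exact training-script code:

def prepare_dataset(batch):
    # lower-case the transcription ourselves instead of relying on the
    # tokenizer's (misleadingly named) do_lower_case flag
    text = batch["sentence"].lower() if do_lower_case else batch["sentence"]
    batch["labels"] = tokenizer(text).input_ids
    return batch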

You're 100% right! The naming is a bit misleading here: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper-case letters in its vocab.
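A quick way to check whether that condition holds for a given tokenizer (a hypothetical helper, not part of the repo or of transformers):

def vocab_has_upper_case(tokenizer):
    # do_lower_case upper-cases the input before look-up, so it only makes
    # sense when the vocabulary itself contains upper-case characters
    return any(any(char.isupper() for char in token) for token in tokenizer.get_vocab())

print(vocab_has_upper_case(tokenizer))  # expected to be False for the CV9 tokenizer above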