Suppress tokens during training
pprobst opened this issue · 4 comments
Hello. First of all, thank you for your work.
I came across this repo while trying to improve transcription speed in whispercpp by using a lower audio_ctx.
However, while fine-tuning with this code, the model does not seem to suppress tokens. I adapted the code slightly to set suppress_tokens and to use my Spanish dataset, but everything else of importance remained unchanged.
When fine-tuning normally, at least with WhisperForConditionalGeneration, setting model.config.suppress_tokens = suppress works, but I'm not sure it has any effect here. Furthermore, my training dataset (about 6100 audio files) does not use any punctuation marks; that is, "," appears as the word "comma" rather than the actual character. So I at least expect the model to learn not to use punctuation from the data itself, even if I do not suppress it explicitly during inference, but that is not what happens here, even after training for 15 epochs (the same number of epochs I use in my "normal" pipeline, which does not use dynamic audio context).
Also, in the code below, suppressing tokens during inference only works if I pass them to the model.generate call:
model.model = model_train.eval().cuda()
predicted_ids_train = model.generate(
    input_features,
    suppress_tokens=suppress,
    forced_decoder_ids=processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    ),
)
This made me suspect that tokens are not being suppressed during training, but I don't know how to verify this.
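One check I thought of (assuming a recent transformers version, where generate() reads generation parameters from model.generation_config rather than model.config) is to compare the two suppress lists; as far as I understand, suppress_tokens is only a generation-time logits filter, so it would not change the forward pass used for the training loss anyway:

# Sanity check: does generate() actually see the suppress list set on the config?
print(model.config.suppress_tokens)             # what I set above
print(model.generation_config.suppress_tokens)  # what generate() consults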
Full code (except loading the dataset, but nothing unusual there):
#!/usr/bin/env python3
import shutil
import torch
from datasets import load_dataset
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch import nn
from transformers import (
WhisperModel,
WhisperTokenizer,
WhisperProcessor,
WhisperForConditionalGeneration,
)
from dataset import CSVDataset
from sys import argv
EPOCHS = 3
SEED = 42
SUPPRESS_TOKENS_TRAIN = "0123456789@#%&*+=_$€:-.,?¿;!¡"
languages = {
    "spanish": "es",
    "portuguese": "pt",
    "es": "es",
    "pt": "pt",
}
MODEL = argv[1]
language = languages[argv[2].lower()]
data_path = Path(argv[3])
def get_suppress_tokens(tokenizer: WhisperTokenizer, tokens: str):
    suppressed_tokens = [
        i
        for i in range(tokenizer.eos_token_id)
        if all(c in tokens for c in tokenizer.decode([i]).removeprefix(" "))
    ]
    return suppressed_tokens


processor = WhisperProcessor.from_pretrained(
    f"openai/whisper-{MODEL}", task="transcribe", language=language
)
model_train = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")
model_base = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")
suppress = []
suppress = [-1] + get_suppress_tokens(processor.tokenizer, SUPPRESS_TOKENS_TRAIN)
model_train.config.suppress_tokens = suppress
model_base.config.suppress_tokens = suppress
model_train.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model_base.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model_train = model_train.train().cuda()
model_base = model_base.eval().cuda()
ds_full = CSVDataset(
    data_path=data_path,
    processor=processor,
    add_silence=False,
    ignore_datasets=[],
    audio_augs=True,
)
ds = ds_full.dataset["train"]
# ds = load_dataset("google/fleurs", "en_us", split="train")
def get_sample(example):
    waveform = example["audio"]["array"]
    sampling_rate = example["audio"]["sampling_rate"]
    # Use the model and processor to transcribe the audio:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features
    return {
        "length": len(waveform) / sampling_rate,
        "input_features": input_features,
        "input_ids": processor.tokenizer.encode(example["transcript"].lower()),
    }
# if not (".en" in MODEL):
# print(processor.get_decoder_prompt_ids(language="english", task="transcribe"))
# [processor.tokenizer.decode(i) for i in get_sample(ds[1])["input_ids"]]
def compute_partially_encoder(model, data, n_audio_ctx):
    # Pad or trim the mel features so the encoder produces exactly n_audio_ctx frames.
    diffy = 2 * n_audio_ctx - data.shape[2]
    if diffy > 0:
        data = nn.functional.pad(data, [0, diffy, 0, 0, 0, 0], "constant", 0.0)
    elif diffy < 0:
        data = data[:, :, :diffy]

    if n_audio_ctx == 1500:
        return model.encoder(data).last_hidden_state

    # Re-run the encoder forward pass manually with truncated positional embeddings.
    input_embeds = nn.functional.gelu(model.encoder.conv1(data))
    input_embeds = nn.functional.gelu(model.encoder.conv2(input_embeds))
    input_embeds = input_embeds.permute(0, 2, 1)
    embed_pos = model.encoder.embed_positions.weight[:n_audio_ctx]
    hidden_states = input_embeds + embed_pos
    hidden_states = nn.functional.dropout(
        hidden_states, p=model.encoder.dropout, training=model.encoder.training
    )
    for idx, encoder_layer in enumerate(model.encoder.layers):
        to_drop = False
        if model.encoder.training:
            dropout_probability = torch.rand([])
            if dropout_probability < model.encoder.layerdrop:
                to_drop = True
        if to_drop:
            layer_outputs = (None, None)
        else:
            if model.encoder.gradient_checkpointing and model.encoder.training:
                layer_outputs = model.encoder._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,
                    None,
                    False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=None,
                    output_attentions=False,
                )
            hidden_states = layer_outputs[0]
    hidden_states = model.encoder.layer_norm(hidden_states)
    return hidden_states
def compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example):
    optimizer.zero_grad()

    # Audio context proportional to audio length (1500 frames ~ 30 s), plus random jitter.
    n_ctx = int(round((1500.0 / 30.0) * example["length"]))
    extra_ctx = torch.randint(-min(64, n_ctx // 3), min(64, n_ctx // 3), (1,)).item()
    n_ctx += extra_ctx

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    encoder_hidden_states_partial = compute_partially_encoder(
        model_train, input_features, n_ctx
    )
    output_partial = model_train.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states_partial,
        output_hidden_states=True,
    )

    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(
            model_base, input_features, 1500
        )
        output_full = model_base.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states_full,
            output_hidden_states=True,
        )

    # print(output_full)
    loss = criterion(
        # output_partial.hidden_states[-1],
        # output_full.hidden_states[-1]
        torch.cat(output_partial.hidden_states, 0),
        torch.cat(output_full.hidden_states, 0),
    )
    loss.backward()
    optimizer.step()
    return loss
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_train.parameters(), lr=1e-6)
writer = SummaryWriter()
writer.add_text("name", f"{MODEL} v3")
num_length = 0
step = 0
for epoch in range(EPOCHS):
    pbar = tqdm(ds.shuffle(seed=SEED))
    for example in pbar:
        example = get_sample(example)
        if example["length"] > 29.0:
            continue
        loss = compute_hidden_state_loss(
            model_train, model_base, optimizer, criterion, example
        )
        step += 1
        num_length += example["length"]
        writer.add_scalar("loss/train", loss.item(), step)
        writer.add_scalar("length/train", num_length, step)
        writer.add_scalar("epoch/train", epoch, step)
        pbar.set_description(f"Epoch {epoch}, Loss: {loss.item()}")
# Select an audio file and read it:
# ds_eval = load_dataset(
# "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
# )
ds_eval = ds_full.dataset["val"]
# Load the Whisper model in Hugging Face format:
model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")
model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model = model.eval().cuda()
for i in range(len(ds_eval)):
    audio_sample = ds_eval[i]["audio"]
    waveform = audio_sample["array"]
    sampling_rate = audio_sample["sampling_rate"]
    # print(ds_eval[i])
    # input_features = ds_eval[i]["input_features"].cuda()
    # Use the model and processor to transcribe the audio:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.cuda()

    model.model = model_base.eval().cuda()
    # Suppress tokens only work here, not when setting in the model.config. Why?
    predicted_ids_base = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )
    model.model = model_train.eval().cuda()
    predicted_ids_train = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )
    # Decode token ids to text
    transcription = processor.batch_decode(
        [predicted_ids_base[0], predicted_ids_train[0]], skip_special_tokens=True
    )
    # Use self.tokenizer._basic_normalize(pred).strip() to normalize the transcriptions
    transcription = [
        processor.tokenizer._basic_normalize(pred).strip() for pred in transcription
    ]
    print(
        f"\n\nGrndTr: {ds_eval[i]['transcript'].lower()}\nModelB: {transcription[0]}\nModelT: {transcription[1]}"
    )
model = (
    WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")
    .eval()
    .cpu()
)
model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model.model = model_train.eval().cpu()
model.save_pretrained(f"model_train-{MODEL}3")
processor.tokenizer.save_pretrained(f"model_train-{MODEL}3")
shutil.make_archive(f"model_train-{MODEL}3", "zip", f"model_train-{MODEL}3")
Finally, when running inference on a test set using whispercpp and the fine-tuned model, I get a WER of 39% instead of the usual 7% that I get when trained "normally".
I at least expect the model to learn not to use punctuation from the data itself, even if I do not suppress them explicitly during inference, but it's not what is happening here
During training with this method, the model is not trained to predict the given labels exactly, but rather it's trained to have the same hidden states as the original model. Meaning, if the original model thinks the word "comma" in "hello comma how are you" is unlikely, the finetuned model will also think it's unlikely. This is in contrast to regular finetuning, where the finetuned model gets nudged to believe the training data is probable.
It's useful to think of this finetuning process more as a distillation process, where the training data just so happens to be a reference for how the models behave, rather than a source of truth that the model is trained to match. If you want your model to have specific behavior/domain-knowledge and be able to use audio_ctx, you'll have to finetune them normally first, and then use this method to make it more robust with audio_ctx.
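Roughly, the difference between the two objectives looks like this (a self-contained sketch with placeholder tensors and shapes, not code from this repo):

import torch
import torch.nn.functional as F

# Placeholder shapes: batch=1, seq_len=8, vocab=51865, d_model=512.
logits = torch.randn(1, 8, 51865)         # decoder logits (regular finetuning)
labels = torch.randint(0, 51865, (1, 8))  # ground-truth token ids
hidden_train = torch.randn(1, 8, 512)     # reduced-context decoder hidden states
hidden_base = torch.randn(1, 8, 512)      # full-context (frozen) decoder hidden states

# Regular finetuning: cross-entropy against the labels, so the model is
# nudged to consider the training transcripts probable.
ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

# This method: MSE between decoder hidden states, so the finetuned model only
# learns to reproduce whatever the base model already predicts.
mse_loss = F.mse_loss(hidden_train, hidden_base)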
I'm not entirely sure about the suppress_tokens issue you mention though
Oh, this changes everything, thanks. I believe I misunderstood, then. I suppose that if I take my already fine-tuned model and "distil" it using your code, things will work. I'll do this and post the results here for posterity.
Hm, I haven't been very successful. I ran whispercpp's main with varying audio context on a benchmark dataset, and these are the results:
ac1500 (default model)
eval_spanish_base-20240328_6k16b_1500/log_2024-04-12T12:59.txt
2024-04-12 13:02:24 __main__ INFO WER: 0.06779981114258735 (6.78%)
ac768 (default model)
eval_spanish_base-20240328_6k16b_768/log_2024-04-12T13:03.txt
2024-04-12 13:05:53 __main__ INFO WER: 0.07034938621340887 (7.03%)
# Below are whisper-actf models (fine-tuned over my "default" model over 8 epochs):
ac1500
actf_spanish_base-20240328_6k16b_1500/log_2024-04-12T12:35.txt
2024-04-12 12:38:43 __main__ INFO WER: 0.15401322001888573 (15.40%)
ac768
actf_spanish_base-20240328_6k16b_768/log_2024-04-12T12:43.txt
2024-04-12 12:46:00 __main__ INFO WER: 0.1474032105760151 (14.74%)
ac512
actf_spanish_base-20240328_6k16b_512/log_2024-04-12T12:46.txt
2024-04-12 12:48:34 __main__ INFO WER: 0.1457979225684608 (14.58%)
ac512 (greedy)
actf_spanish_base-20240328_6k16b_512_greedy/log_2024-04-12T12:54.txt
2024-04-12 12:56:32 __main__ INFO WER: 0.15920679886685551 (15.92%)
ac500
actf_spanish_base-20240328_6k16b_500/log_2024-04-12T12:51.txt
2024-04-12 12:54:08 __main__ INFO WER: 0.14513692162417374 (14.51%)
ac256
actf_spanish_base-20240328_6k16b_256/log_2024-04-12T12:49.txt
2024-04-12 12:51:28 __main__ INFO WER: 0.42360717658168084 (42.36%)
WER essentially doubles, even if I use the default audio context (ac).
Just to confirm, you made sure to replace both model_train and model_base with copies of your finetuned model? What if you try finetuning with the default dataset instead of the custom one?
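In other words, something like this at the top of the training script, instead of loading openai/whisper-{MODEL} (the path below is just a placeholder for wherever your finetuned checkpoint was saved with save_pretrained):

from transformers import WhisperModel

# Hypothetical path to the already finetuned Spanish checkpoint.
finetuned_path = "path/to/my-finetuned-spanish-whisper"

model_train = WhisperModel.from_pretrained(finetuned_path)  # student adapted to lower audio_ctx
model_base = WhisperModel.from_pretrained(finetuned_path)   # frozen full-context teacher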