futo-org/whisper-acft

Suppress tokens during training

pprobst opened this issue · 4 comments

Hello. First of all, thank you for your work.

I came across this repo while trying to improve transcription speed in whispercpp by using a lower audio_ctx.
However, when fine-tuning with this code, tokens do not seem to be suppressed. I adapted the code slightly to set suppress_tokens and to use my Spanish dataset, but left everything else of importance unchanged.

When fine-tuning normally, at least with WhisperForConditionalGeneration, setting model.config.suppress_tokens = suppress works, but I'm not sure it's working here. Furthermore, my training dataset (about 6,100 audio files) does not use any punctuation marks; for example, "," is written out as "comma" rather than the actual character. So I would at least expect the model to learn from the data itself not to produce punctuation, even without explicitly suppressing those tokens at inference, but that is not what happens here, even after 15 epochs, the same number of epochs I use in my "normal" pipeline without dynamic audio context.

Also, in the code below, token suppression only works during inference if I pass the tokens directly to model.generate:

model.model = model_train.eval().cuda()
predicted_ids_train = model.generate(
    input_features,
    suppress_tokens=suppress,
    forced_decoder_ids=processor.get_decoder_prompt_ids(
        language=language, task="transcribe"
    ),
)

This made me suspect that tokens are not being suppressed during training, but I don't know how to verify this.
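As far as I can tell, suppress_tokens is only applied by generate() via its logits processors, so a plain forward pass (which is what the training loop uses) leaves the logits untouched. A quick sketch to check this (not code from this repo; model size and token choice are just examples):

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Hypothetical check: does a "suppressed" token still get a finite logit in a
# plain forward pass?
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
processor = WhisperProcessor.from_pretrained("openai/whisper-base")

comma_id = processor.tokenizer.encode(",", add_special_tokens=False)[0]
model.config.suppress_tokens = [comma_id]

dummy_features = torch.zeros(1, model.config.num_mel_bins, 3000)
decoder_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    logits = model(input_features=dummy_features, decoder_input_ids=decoder_ids).logits

# If suppression were applied in forward(), this would be -inf; in practice it is
# a normal finite value, because suppress_tokens is only consumed by generate().
print(logits[0, -1, comma_id].item())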

Full code (except loading the dataset, but nothing unusual there):

#!/usr/bin/env python3

import shutil
import torch

from datasets import load_dataset
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch import nn
from transformers import (
    WhisperModel,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from dataset import CSVDataset
from sys import argv

EPOCHS = 3
SEED = 42
SUPPRESS_TOKENS_TRAIN = "0123456789@#%&*+=_$€:-.,?¿;!¡"

languages = {
    "spanish": "es",
    "portuguese": "pt",
    "es": "es",
    "pt": "pt",
}

MODEL = argv[1]
language = languages[argv[2].lower()]
data_path = Path(argv[3])


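# Collect every token id (below eos) whose decoded text, after stripping a
# leading space, consists only of characters from the given set.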
def get_suppress_tokens(tokenizer: WhisperTokenizer, tokens: str):
    suppressed_tokens = [
        i
        for i in range(tokenizer.eos_token_id)
        if all(c in tokens for c in tokenizer.decode([i]).removeprefix(" "))
    ]
    return suppressed_tokens


processor = WhisperProcessor.from_pretrained(
    f"openai/whisper-{MODEL}", task="transcribe", language=language
)


model_train = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")
model_base = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}")

suppress = [-1] + get_suppress_tokens(processor.tokenizer, SUPPRESS_TOKENS_TRAIN)
model_train.config.suppress_tokens = suppress
model_base.config.suppress_tokens = suppress
model_train.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model_base.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
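# NOTE: suppress_tokens / forced_decoder_ids set on the config are only consumed
# by generate(); the raw encoder/decoder forward passes in the training loop
# below ignore them, so they have no effect on the distillation loss.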

model_train = model_train.train().cuda()
model_base = model_base.eval().cuda()


ds_full = CSVDataset(
    data_path=data_path,
    processor=processor,
    add_silence=False,
    ignore_datasets=[],
    audio_augs=True,
)

ds = ds_full.dataset["train"]

# ds = load_dataset("google/fleurs", "en_us", split="train")


def get_sample(example):
    waveform = example["audio"]["array"]
    sampling_rate = example["audio"]["sampling_rate"]

    # Use the model and processor to transcribe the audio:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features

    return {
        "length": len(waveform) / sampling_rate,
        "input_features": input_features,
        "input_ids": processor.tokenizer.encode(example["transcript"].lower()),
    }


# if not (".en" in MODEL):
#    print(processor.get_decoder_prompt_ids(language="english", task="transcribe"))
# [processor.tokenizer.decode(i) for i in get_sample(ds[1])["input_ids"]]


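# Run the Whisper encoder with a reduced audio context: pad/crop the mel input to
# 2 * n_audio_ctx frames and use only the first n_audio_ctx positional embeddings
# (the stock model always uses the full 1500).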
def compute_partially_encoder(model, data, n_audio_ctx):
    diffy = 2 * n_audio_ctx - data.shape[2]

    if diffy > 0:
        data = nn.functional.pad(data, [0, diffy, 0, 0, 0, 0], "constant", 0.0)
    elif diffy < 0:
        data = data[:, :, :diffy]

    if n_audio_ctx == 1500:
        return model.encoder(data).last_hidden_state

    input_embeds = nn.functional.gelu(model.encoder.conv1(data))
    input_embeds = nn.functional.gelu(model.encoder.conv2(input_embeds))
    input_embeds = input_embeds.permute(0, 2, 1)

    embed_pos = model.encoder.embed_positions.weight[:n_audio_ctx]

    hidden_states = input_embeds + embed_pos
    hidden_states = nn.functional.dropout(
        hidden_states, p=model.encoder.dropout, training=model.encoder.training
    )

    for idx, encoder_layer in enumerate(model.encoder.layers):
        to_drop = False
        if model.encoder.training:
            dropout_probability = torch.rand([])
            if dropout_probability < model.encoder.layerdrop:
                to_drop = True

        if to_drop:
            layer_outputs = (None, None)
        else:
            if model.encoder.gradient_checkpointing and model.encoder.training:
                layer_outputs = model.encoder._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,
                    None,
                    False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=None,
                    output_attentions=False,
                )

            hidden_states = layer_outputs[0]

    hidden_states = model.encoder.layer_norm(hidden_states)
    return hidden_states


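# Distillation-style objective: match the decoder hidden states produced with a
# truncated audio context against those of the frozen base model at the full
# 1500-frame context (MSE over all decoder hidden states, no label loss).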
def compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example):
    optimizer.zero_grad()

    n_ctx = int(round((1500.0 / 30.0) * example["length"]))

    extra_ctx = torch.randint(-min(64, n_ctx // 3), min(64, n_ctx // 3), (1,)).item()
    n_ctx += extra_ctx

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    encoder_hidden_states_partial = compute_partially_encoder(
        model_train, input_features, n_ctx
    )
    output_partial = model_train.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states_partial,
        output_hidden_states=True,
    )

    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(
            model_base, input_features, 1500
        )
        output_full = model_base.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states_full,
            output_hidden_states=True,
        )
        # print(output_full)

    loss = criterion(
        # output_partial.hidden_states[-1],
        # output_full.hidden_states[-1]
        torch.cat(output_partial.hidden_states, 0),
        torch.cat(output_full.hidden_states, 0),
    )

    loss.backward()
    optimizer.step()

    return loss


criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_train.parameters(), lr=1e-6)


writer = SummaryWriter()
writer.add_text("name", f"{MODEL} v3")

num_length = 0
step = 0
for epoch in range(EPOCHS):
    pbar = tqdm(ds.shuffle(seed=SEED))
    for example in pbar:
        example = get_sample(example)
        if example["length"] > 29.0:
            continue

        loss = compute_hidden_state_loss(
            model_train, model_base, optimizer, criterion, example
        )
        step += 1
        num_length += example["length"]

        writer.add_scalar("loss/train", loss.item(), step)
        writer.add_scalar("length/train", num_length, step)
        writer.add_scalar("epoch/train", epoch, step)

        pbar.set_description(f"Epoch {epoch}, Loss: {loss.item()}")


# Select an audio file and read it:
# ds_eval = load_dataset(
#    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
# )
ds_eval = ds_full.dataset["val"]

# Load the Whisper model in Hugging Face format:
model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")

model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model = model.eval().cuda()

for i in range(len(ds_eval)):
    audio_sample = ds_eval[i]["audio"]
    waveform = audio_sample["array"]
    sampling_rate = audio_sample["sampling_rate"]
    # print(ds_eval[i])
    # input_features = ds_eval[i]["input_features"].cuda()

    # Use the model and processor to transcribe the audio:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.cuda()

    model.model = model_base.eval().cuda()
    # Token suppression only works here, not when set in model.config. Why?
    predicted_ids_base = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )
    model.model = model_train.eval().cuda()
    predicted_ids_train = model.generate(
        input_features,
        suppress_tokens=suppress,
        forced_decoder_ids=processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        ),
    )

    # Decode token ids to text
    transcription = processor.batch_decode(
        [predicted_ids_base[0], predicted_ids_train[0]], skip_special_tokens=True
    )

    # Use processor.tokenizer._basic_normalize(pred).strip() to normalize the transcriptions
    transcription = [
        processor.tokenizer._basic_normalize(pred).strip() for pred in transcription
    ]

    print(
        f"\n\nGrndTr: {ds_eval[i]['transcript'].lower()}\nModelB: {transcription[0]}\nModelT: {transcription[1]}"
    )

model = (
    WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}")
    .eval()
    .cpu()
)
model.config.suppress_tokens = suppress
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language=language, task="transcribe"
)
model.model = model_train.eval().cpu()

model.save_pretrained(f"model_train-{MODEL}3")
processor.tokenizer.save_pretrained(f"model_train-{MODEL}3")

shutil.make_archive(f"model_train-{MODEL}3", "zip", f"model_train-{MODEL}3")

Finally, when running inference on a test set using whispercpp and the fine-tuned model, I get a WER of 39% instead of the usual 7% that I get when trained "normally".

I at least expect the model to learn not to use punctuation from the data itself, even if I do not suppress them explicitly during inference, but it's not what is happening here

During training with this method, the model is not trained to predict the given labels exactly, but rather it's trained to have the same hidden states as the original model. Meaning, if the original model thinks the word "comma" in "hello comma how are you" is unlikely, the finetuned model will also think it's unlikely. This is in contrast to regular finetuning, where the finetuned model gets nudged to believe the training data is probable.

It's useful to think of this finetuning process more as a distillation process, where the training data just so happens to be a reference for how the models behave, rather than a source of truth that the model is trained to match. If you want your model to have specific behavior/domain knowledge and still be able to use a reduced audio_ctx, you'll have to finetune it normally first, and then use this method to make it more robust to a lower audio_ctx.
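To make the contrast concrete, here is a rough sketch of the two objectives side by side (the function names are just illustrative, not from this repo):

import torch

mse = torch.nn.MSELoss()
ce = torch.nn.CrossEntropyLoss()

def acft_style_loss(student_out, teacher_out):
    # whisper-acft objective: match the frozen teacher's decoder hidden states.
    # The transcript only serves as a shared decoder input, not as a target.
    return mse(
        torch.cat(student_out.hidden_states, 0),
        torch.cat(teacher_out.hidden_states, 0),
    )

def regular_finetune_loss(logits, labels):
    # Ordinary finetuning: push the model toward the reference transcript itself,
    # which is where "no punctuation in the data" would actually be learned.
    return ce(logits.view(-1, logits.size(-1)), labels.view(-1))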

I'm not entirely sure about the suppress_tokens issue you mention, though.

Oh, this changes everything, thanks. I believe I misunderstood, then. I suppose that if I use my already fine-tuned model and "distil" it using your code, then things will work. I'll do this and post the results here for history.

Hm. I've not been so successful. I ran whispercpp's main with varying audio context on a benchmark dataset, and these are the results:

ac1500 (default model)
eval_spanish_base-20240328_6k16b_1500/log_2024-04-12T12:59.txt
2024-04-12 13:02:24 __main__     INFO     WER: 0.06779981114258735 (6.78%)

ac768 (default model)
eval_spanish_base-20240328_6k16b_768/log_2024-04-12T13:03.txt
2024-04-12 13:05:53 __main__     INFO     WER: 0.07034938621340887 (7.03%)

# Below are whisper-acft models (fine-tuned from my "default" model for 8 epochs):

ac1500
actf_spanish_base-20240328_6k16b_1500/log_2024-04-12T12:35.txt
2024-04-12 12:38:43 __main__     INFO     WER: 0.15401322001888573 (15.40%)

ac768
actf_spanish_base-20240328_6k16b_768/log_2024-04-12T12:43.txt
2024-04-12 12:46:00 __main__     INFO     WER: 0.1474032105760151 (14.74%)

ac512
actf_spanish_base-20240328_6k16b_512/log_2024-04-12T12:46.txt
2024-04-12 12:48:34 __main__     INFO     WER: 0.1457979225684608 (14.58%)

ac512 (greedy)
actf_spanish_base-20240328_6k16b_512_greedy/log_2024-04-12T12:54.txt
2024-04-12 12:56:32 __main__     INFO     WER: 0.15920679886685551 (15.92%)

ac500
actf_spanish_base-20240328_6k16b_500/log_2024-04-12T12:51.txt
2024-04-12 12:54:08 __main__     INFO     WER: 0.14513692162417374 (14.51%)

ac256
actf_spanish_base-20240328_6k16b_256/log_2024-04-12T12:49.txt
2024-04-12 12:51:28 __main__     INFO     WER: 0.42360717658168084 (42.36%)

The WER more than doubles, even at the default audio context of 1500.

Just to confirm, you made sure to replace both model_train and model_base with copies of your finetuned model? What if you try finetuning with the default dataset instead of the custom one?
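For reference, the swap would look something like this ("path/to/finetuned-es" is a placeholder for the normally finetuned checkpoint), so that both the student and the frozen reference start from the domain model rather than the stock openai/whisper checkpoint:

from transformers import WhisperForConditionalGeneration

# Placeholder path; load the already-finetuned model and reuse its underlying
# WhisperModel for both the trainable copy and the frozen reference.
model_train = WhisperForConditionalGeneration.from_pretrained("path/to/finetuned-es").model
model_base = WhisperForConditionalGeneration.from_pretrained("path/to/finetuned-es").model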