facebookresearch/SONAR

Training on lower precision

King-Rafat opened this issue

Hi, great work done here!
Have you tried training the models or running inference with them at lower precision? How much performance is lost?

Hi @King-Rafat!
Based on my (rather limited) experiments, training SONAR models at half precision (float16) can sometimes be unstable when computing the cross-entropy loss for the decoder. So I would recommend float32 or mixed precision for training.
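To illustrate why float16 loss computation is fragile (an illustrative assumption, not a diagnosis of the experiments above): the largest finite float16 value is 65504, so exp(x) overflows once x exceeds about 11.09. The sketch below emulates float16 rounding in pure Python with `struct`'s `'e'` format; the logits and helper names are hypothetical. A naive softmax cross-entropy rounded to float16 blows up to `inf`, while the same loss computed in full precision with the max-shift (log-sum-exp) trick stays finite. Mixed-precision setups such as PyTorch's `torch.autocast` sidestep this kind of problem by running ops like softmax and cross-entropy in float32.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 binary16 value


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value, mapping overflow to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses finite values outside the float16 range
        return math.copysign(math.inf, x)


def fp16_exp(x: float) -> float:
    """exp(x) with the result rounded to float16."""
    try:
        return to_fp16(math.exp(x))
    except OverflowError:  # math.exp itself overflows for x > ~709
        return math.inf


def naive_cross_entropy_fp16(logits, target):
    """-log softmax(logits)[target], rounding every intermediate to float16."""
    total = 0.0
    for z in logits:
        total = to_fp16(total + fp16_exp(to_fp16(z)))
    return to_fp16(math.log(total) - to_fp16(logits[target]))


def stable_cross_entropy(logits, target):
    """Same loss in full precision, shifting by max(logits) before exponentiating."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


logits = [12.0, 0.0, -1.0]  # hypothetical decoder logits over a 3-token vocabulary
print(naive_cross_entropy_fp16(logits, 0))  # inf: exp(12) > FP16_MAX
print(stable_cross_entropy(logits, 0))      # small, finite loss
```

With small logits (e.g. `[2.0, 0.0, -1.0]`) both versions agree to a few decimal places; the difference only appears once an intermediate leaves the float16 range.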

However, running inference with SONAR text models in float16 seems to be totally fine.
The code snippet below illustrates that SONAR translation quality isn't affected by reducing the precision to float16.

Example

import datasets
import torch
from sacrebleu import BLEU
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline

# loading the models onto the GPU (by default, they are in float32 precision)
device = torch.device("cuda")
t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder",
                                           tokenizer="text_sonar_basic_encoder", device=device)
vec2text_model = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder",
                                              tokenizer="text_sonar_basic_encoder", device=device)

# setting the test dataset
src_lang, tgt_lang = "eng_Latn", "fra_Latn"
lang2flores = {
    lang: datasets.load_dataset("facebook/flores", lang, trust_remote_code=True)
    for lang in [src_lang, tgt_lang]
}
source = lang2flores[src_lang]['dev']['sentence']
target = lang2flores[tgt_lang]['dev']['sentence']

# computing the embeddings in two precisions
embs_32 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)
t2vec_model.half()  # cast the encoder weights to float16
embs_16 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)

# translating each embedding matrix into French with the decoder in two precisions
# (pred_AxB: A = encoder precision, B = decoder precision)
pred_32x32 = vec2text_model.predict(embs_32, target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x32 = vec2text_model.predict(embs_16.to(torch.float32), target_lang=tgt_lang, batch_size=32, progress_bar=True)
vec2text_model.half()  # cast the decoder weights to float16
pred_32x16 = vec2text_model.predict(embs_32.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x16 = vec2text_model.predict(embs_16.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)

# evaluating the quality (higher BLEU <=> better)
bleu_calc = BLEU()

print(bleu_calc.corpus_score(pred_32x32, [target]).score)  # 45.35502456250957
print(bleu_calc.corpus_score(pred_16x32, [target]).score)  # 45.41064939316419
print(bleu_calc.corpus_score(pred_32x16, [target]).score)  # 45.385803594567314
print(bleu_calc.corpus_score(pred_16x16, [target]).score)  # 45.42170584536023
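The four BLEU scores above differ by less than 0.07 points. One way to see why float16 costs so little at inference: float16 keeps 11 significant bits, so rounding a normal value changes it by at most 2**-11 (about 0.05%) in relative terms. The pure-Python sketch below assumes a random Gaussian vector as a stand-in for a real 1024-dimensional SONAR embedding, rounds every component to float16 via `struct`'s `'e'` format, and checks the cosine similarity with the original. It only measures storage rounding; errors accumulated inside the encoder and decoder layers are not modeled.

```python
import math
import random
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 binary16 (values here stay well inside its range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


rng = random.Random(0)
emb = [rng.gauss(0.0, 1.0) for _ in range(1024)]  # hypothetical stand-in for a sentence embedding
emb_fp16 = [to_fp16(x) for x in emb]

# only consider normal float16 values (|x| >= 2**-14) for the relative-error bound
max_rel_err = max(abs(a - b) / abs(a) for a, b in zip(emb, emb_fp16) if abs(a) > 1e-4)
print(f"cosine(fp32, fp16) = {cosine(emb, emb_fp16):.8f}")
print(f"max relative error = {max_rel_err:.2e}")  # at most 2**-11 for normal values
```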

Also, there is evidence that NLLB models can be quantized (e.g. with ctranslate2) even to int8 without serious performance degradation. Since SONAR text models are essentially fine-tuned NLLB models (but with a fixed-size representation bottleneck), I would expect them to be quantizable to int8 as well.
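For intuition about int8, here is a minimal sketch of symmetric per-tensor (absmax) int8 quantization in pure Python. It is a generic scheme assumed for illustration; it is not necessarily what ctranslate2 implements, and the weights are randomly generated stand-ins. Each weight is divided by a single scale (max |w| / 127), rounded to the nearest integer in [-127, 127], and multiplied back by the scale on dequantization, so the reconstruction error per weight is at most half the scale.

```python
import random


def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization with absmax scaling.

    Returns the integer codes and the float scale needed to dequantize them.
    """
    absmax = max(abs(w) for w in weights)
    scale = absmax / 127 if absmax > 0 else 1.0
    # round to nearest, then clamp to the symmetric range [-127, 127]
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize_int8(codes, scale):
    """Map int8 codes back to approximate float weights."""
    return [c * scale for c in codes]


rng = random.Random(0)
weights = [rng.gauss(0.0, 0.02) for _ in range(4096)]  # hypothetical weight row
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(f"scale={scale:.6f}  max abs error={max_err:.6f}  (bound scale/2={scale / 2:.6f})")
```

A single scale per tensor is the simplest choice; using one scale per row (per output channel) usually follows the weight distribution more closely, at the cost of storing more scales.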

Hi @avidale, thank you for your feedback!