Training on lower precision
King-Rafat opened this issue · 2 comments
Hi, great work done here!
Have you tried training or inferring the models at a lower precision? What is the performance loss for that?
Hi @King-Rafat!
Based on my (rather limited) experiments, training SONAR models at half precision (float16) can sometimes be unstable when computing the cross-entropy loss for the decoder. So I would probably recommend float32 or some mixed precision for training.
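For illustration, here is a minimal sketch of what such mixed-precision training could look like with PyTorch AMP. It uses a hypothetical toy model in place of the real SONAR decoder (this is not the actual training code): the forward pass runs in float16 under autocast, and the logits are cast back to float32 before the cross-entropy loss is computed.
import torch
import torch.nn as nn
import torch.nn.functional as F

# toy stand-in for the decoder: an embedding followed by a projection to the vocabulary
vocab_size, hidden = 1000, 64
model = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, vocab_size)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

# fake batch of token ids and labels, just to make the sketch runnable
tokens = torch.randint(0, vocab_size, (8, 32), device="cuda")
labels = torch.randint(0, vocab_size, (8, 32), device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(tokens)  # forward pass runs in float16
# cast back to float32 before the loss, which is where half precision tends to misbehave
loss = F.cross_entropy(logits.float().flatten(0, 1), labels.flatten())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()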
However, inference with the SONAR text models in float16 seems to be totally fine. The code snippet below illustrates that SONAR translation quality isn't affected by switching to half precision.
Example
import datasets
import torch
from sacrebleu import BLEU
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline
# loading the models to GPU (by default, they are in float32 precision)
device = torch.device("cuda")
t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder",
                                           tokenizer="text_sonar_basic_encoder", device=device)
vec2text_model = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder",
                                              tokenizer="text_sonar_basic_encoder", device=device)
# setting the test dataset
src_lang, tgt_lang = "eng_Latn", "fra_Latn"
lang2flores = {
    lang: datasets.load_dataset("facebook/flores", lang, trust_remote_code=True)
    for lang in [src_lang, tgt_lang]
}
source = lang2flores[src_lang]['dev']['sentence']
target = lang2flores[tgt_lang]['dev']['sentence']
# computing the embeddings in two precisions
embs_32 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)
t2vec_model.half();
embs_16 = t2vec_model.predict(source, source_lang=src_lang, batch_size=32, progress_bar=True)
# translating each embedding matrix into French in two precisions
pred_32x32 = vec2text_model.predict(embs_32, target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x32 = vec2text_model.predict(embs_16.to(torch.float32), target_lang=tgt_lang, batch_size=32, progress_bar=True)
vec2text_model.half();
pred_32x16 = vec2text_model.predict(embs_32.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)
pred_16x16 = vec2text_model.predict(embs_16.to(torch.float16), target_lang=tgt_lang, batch_size=32, progress_bar=True)
# evaluating the quality (higher BLEU <=> better)
bleu_calc = BLEU()
print(bleu_calc.corpus_score(pred_32x32, [target]).score) # 45.35502456250957
print(bleu_calc.corpus_score(pred_16x32, [target]).score) # 45.41064939316419
print(bleu_calc.corpus_score(pred_32x16, [target]).score) # 45.385803594567314
print(bleu_calc.corpus_score(pred_16x16, [target]).score) # 45.42170584536023
Also, there is evidence that NLLB models can be quantized (e.g. with ctranslate2) even to int8 representations without serious performance degradation. And SONAR text models are essentially a fine-tuned NLLB model (but with a fixed-size representation bottleneck), so I would expect them to be quantizable to int8 as well.
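As a rough sketch of that int8 route (for NLLB itself, not for SONAR, and assuming the standard ctranslate2 recipe with the facebook/nllb-200-distilled-600M checkpoint), the conversion and inference would look something like this:
# convert the checkpoint once (shell command):
#   ct2-transformers-converter --model facebook/nllb-200-distilled-600M \
#       --output_dir nllb-200-600M-int8 --quantization int8
import ctranslate2
import transformers

translator = ctranslate2.Translator("nllb-200-600M-int8", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)

# ctranslate2 expects token strings; the target language is passed as a prefix token
source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world!"))
results = translator.translate_batch([source_tokens], target_prefix=[["fra_Latn"]])
target_tokens = results[0].hypotheses[0][1:]  # drop the target-language prefix token
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens)))
Whether the SONAR encoder/decoder can be exported the same way would depend on converter support for its architecture, so treat this only as evidence from the NLLB side.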
Hi @avidale, thank you for your feedback!