Anush008/fastembed-rs

Slow inference compared to the Python version

Daniel-Kelvich opened this issue · 11 comments

Hey, I'm new to Rust and may be doing something wrong, but I'm getting worse performance in Rust than in Python. I'm using a Mac with an M1 Pro.
Rust version:

init: 138.47ms
passage_embed: 579.84ms
query_embed: 147.26ms

Python version:

init: 106.26ms
passage_embed: 31.68ms
query_embed: 2.82ms

Here are my scripts. The Rust version:

use fastembed::{EmbeddingBase, EmbeddingModel, FlagEmbedding, InitOptions};
use std::time::Instant;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let now = Instant::now();

    // With custom InitOptions
    let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::BGESmallEN,
        show_download_message: true,
        max_length: 512,
        ..Default::default()
    })?;

    let elapsed = now.elapsed();
    println!("init: {:.2?}", elapsed);

    let now = Instant::now();
    let documents = vec![
        "Maharana Pratap was a Rajput warrior king from Mewar",
        "He fought against the Mughal Empire led by Akbar",
        "The Battle of Haldighati in 1576 was his most famous battle",
        "He refused to submit to Akbar and continued guerrilla warfare",
        "His capital was Chittorgarh, which he lost to the Mughals",
        "He died in 1597 at the age of 57",
        "Maharana Pratap is considered a symbol of Rajput resistance against foreign rule",
        "His legacy is celebrated in Rajasthan through festivals and monuments",
        "He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar",
        "His life has been depicted in various films, TV shows, and books",
    ];
    let embeddings = model.passage_embed(documents, Some(1))?;
    let elapsed = now.elapsed();
    println!("passage_embed: {:.2?}", elapsed);

    let now = Instant::now();
    let query = "Who was Maharana Pratap?";
    let query_embed = model.query_embed(query)?; 
    let elapsed = now.elapsed();
    println!("query_embed: {:.2?}", elapsed);

    Ok(())
}

The Python version:

from typing import List
import numpy as np
from fastembed.embedding import FlagEmbedding as Embedding
import time

t1=time.time()
embedding_model = Embedding(model_name="BAAI/bge-small-en", max_length=512)
print(f'init: {(time.time()-t1)*1000:.02f}ms')

t1=time.time()
documents: List[str] = [
    "Maharana Pratap was a Rajput warrior king from Mewar",
    "He fought against the Mughal Empire led by Akbar",
    "The Battle of Haldighati in 1576 was his most famous battle",
    "He refused to submit to Akbar and continued guerrilla warfare",
    "His capital was Chittorgarh, which he lost to the Mughals",
    "He died in 1597 at the age of 57",
    "Maharana Pratap is considered a symbol of Rajput resistance against foreign rule",
    "His legacy is celebrated in Rajasthan through festivals and monuments",
    "He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar",
    "His life has been depicted in various films, TV shows, and books",
]

embeddings: List[np.ndarray] = list(
    embedding_model.passage_embed(documents)
)  # notice that we are casting the generator to a list
print(f'passage_embed: {(time.time()-t1)*1000:.02f}ms')

t1=time.time()
query = "Who was Maharana Pratap?"
query_embedding = list(embedding_model.query_embed(query))[0]
print(f'query_embed: {(time.time()-t1)*1000:.02f}ms')

Hi. Can you try running your Rust program in release mode, if you aren't already?

Also, 2.82 ms seems too fast to be accurate.
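
A quick way to rule out an accidental debug build is a `cfg!(debug_assertions)` check at the top of the benchmark; this is just a sketch, normally you'd simply run with `cargo run --release`:

// cfg!(debug_assertions) is true in debug builds, so this warns if the benchmark
// was compiled without --release.
if cfg!(debug_assertions) {
    eprintln!("warning: debug build; timings will not reflect release performance");
}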

Thanks for the quick reply.
This is with --release:

init: 106.64ms
passage_embed: 557.83ms
query_embed: 126.10ms

I've also checked the Python code, and 3 ms seems to be correct. There is no CUDA involved, so the code should be synchronous.

Also, from htop it seems like the Rust version is using considerably more CPU.

This is definitely unexpected. Maybe we can try profiling the code to see which parts are taking the most time.
I'm also considering moving to ort v2.
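
As a rough first pass, timing each passage individually could show whether the cost is per inference call rather than in batching. This is only a sketch, not an existing fastembed-rs helper; it assumes `model` and a `documents: Vec<&str>` like those in the script above are in scope (e.g. inside its main function):

// Hypothetical per-passage timing, reusing `model` and a `documents: Vec<&str>`
// like those in the Rust script above. If each single-passage call already takes
// tens of milliseconds, the overhead is in the ONNX call itself rather than in
// Rayon or batching.
use std::time::Instant;

for (i, doc) in documents.iter().enumerate() {
    let now = Instant::now();
    let _embedding = model.passage_embed(vec![*doc], Some(1))?;
    println!("passage {i}: {:.2?}", now.elapsed());
}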

@Daniel-Kelvich, could you try a much larger list of documents, say 1000 times the current one? Rayon should perform significantly better in that case.

If the results improve, we can infer that we should use Rayon only for larger inputs and plain iteration for smaller ones, because of the time Rayon takes to provision its workers.
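
Something along these lines should do, as a sketch only. It assumes the same `passage_embed` signature as the script above; the name `base_documents` (the original 10 passages), the 1000x repetition, and the `None` batch size are arbitrary choices for the test:

// Hypothetical scale test: repeat the 10 passages 1000x and time a single
// passage_embed call, so that Rayon's worker provisioning cost is amortized.
use std::time::Instant;

let documents: Vec<&str> = base_documents
    .iter()
    .cycle()
    .take(base_documents.len() * 1000)
    .cloned()
    .collect();

let now = Instant::now();
let embeddings = model.passage_embed(documents, None)?; // None = default batch size
println!(
    "passage_embed x1000: {:.2?} for {} embeddings",
    now.elapsed(),
    embeddings.len()
);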

I've tried 1000 sentences, and it took 50 seconds, so it seems to scale linearly.
I'll try the same code on a Linux machine.

Inference on Linux took even longer, but that processor is weaker than the M1 Pro.

Thanks @Daniel-Kelvich for raising this issue.
I'll try to look into it when I find time.

I've begun migrating the code to the new version of ORT.

Resolved in 7e0526a.