tazz4843/whisper-rs

Significant performance penalty w.r.t. whisper.cpp

wdoppenberg opened this issue · 10 comments

First of all, thank you for your work creating a safe wrapper around whisper.cpp.

As mentioned in #73, the performance of whisper-rs is quite poor compared to the reference implementation. I'll attempt to demonstrate below.

Setup

I'm using an M2 Max MacBook Pro with 64GB of (shared) memory. My goal is to run a web server with CoreML enabled, but if necessary I can run the tests with CPU only later. I'll attach output generated by flamegraph.

Rust script

// src/main.rs
mod utils;

use clap::Parser;
use anyhow::Result;
use log::info;

use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
use crate::utils::{decode, forward_pass, load_audio};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct WhisperCli {
    #[arg(short, long, default_value = "sample_data/jfk.wav")]
    input_file: String,
    #[arg(short, long, default_value = "models/ggml-medium.bin")]
    model_path: String,
}


fn main() -> Result<()> {
    env_logger::init();
    let args = WhisperCli::parse();

    let mut params = FullParams::new(
        SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 }
    );

    params.set_n_threads(8);
    params.set_translate(false);


    info!("Loading audio file {}", args.input_file);
    let raw_audio = load_audio(&args.input_file)?;

    info!("Decoding audio file");
    let decoded_audio = decode(raw_audio)?;

    info!("Loading model from {}", args.model_path);
    let whisper_context = WhisperContext::new(&args.model_path)?;
    let mut whisper_state = whisper_context.create_state()?;

    info!("Running forward pass");
    let result = forward_pass(&decoded_audio, params, &mut whisper_state);

    println!("{}", result);
    Ok(())
}
// src/utils.rs
use anyhow::Result;
use std::fs::File;
use std::io::Read;

use std::io::Cursor;
use log::{debug, trace};
use rodio::{Decoder, Source};
use rodio::source::UniformSourceIterator;
use whisper_rs::{FullParams, WhisperState};

const SAMPLE_RATE: u32 = 16000;
const CHANNELS: u16 = 1;

const LOW_PASS: u32 = 3000;
const HIGH_PASS: u32 = 200;

/// Load an audio file from the given path
/// and return as a vector of u8
///
/// # Arguments
/// * `path` - The path to the audio file
pub fn load_audio(path: &str) -> Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = Vec::new();

    file.read_to_end(&mut buffer)?;

    Ok(buffer)
}


/// Decode the audio file and return as a vector of f32
pub fn decode(bytes: Vec<u8>) -> Result<Vec<f32>> {
    // Decode the audio file
    let input = Cursor::new(bytes);
    let source = Decoder::new(input)?;

    // Resample to output sample rate and channels
    let resample = UniformSourceIterator::new(
        source, CHANNELS, SAMPLE_RATE,
    );
    // High and low pass filters to enhance the audio
    let pass_filter = resample
        .low_pass(LOW_PASS)
        .high_pass(HIGH_PASS)
        .convert_samples();

    Ok(whisper_rs::convert_integer_to_float_audio(&pass_filter.collect::<Vec<i16>>()))
}


pub fn forward_pass(decoded_audio: &[f32], params: FullParams, whisper_state: &mut WhisperState) -> String {
    debug!("Starting forward pass");
    let _ = whisper_state
        .full(params, decoded_audio)
        .expect("Failed to run model");
    let mut result = String::new();

    debug!("Decoding results");
    // fetch the results
    let num_segments = whisper_state
        .full_n_segments()
        .expect("failed to get number of segments");

    for i in 0..num_segments {
        let segment = whisper_state
            .full_get_segment_text(i)
            .expect("failed to get segment");
        let start_timestamp = whisper_state
            .full_get_segment_t0(i)
            .expect("failed to get segment start timestamp");
        let end_timestamp = whisper_state
            .full_get_segment_t1(i)
            .expect("failed to get segment end timestamp");
        trace!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
        result.push_str(&format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment));
    }
    result
}
# Cargo.toml
[package]
name = "whisper-rs-cli"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.71"
whisper-rs = { version = "0.8.0", features = ["coreml"] }
clap = { version = "4.3.12", features = ["derive"]}
env_logger = "0.10.0"
log = "0.4.19"
rodio = "0.17.1"


[[bin]]
name = "whisper-rs-cli"
path = "src/main.rs"

Data

For testing, I've converted a short JFK speech to WAV. See this link. Converting to WAV is done using ffmpeg:

ffmpeg -i main.mp4 -acodec pcm_s16le -ac 1 -ar 16000 jfk.wav

Results

I won't average over multiple iterations since the differences are quite clear. Furthermore, I've run both binaries beforehand to ensure that the CoreML model is properly compiled for my architecture, so I can confirm that the model loading step is not the issue. The chosen model is ggml-medium.bin. Assuming you have compiled whisper.cpp and whisper-rs and are in the root of each repository, the commands used are as follows:

sudo time flamegraph -- ./main -m models/ggml-medium.bin -t 8  -f jfk.wav
sudo time flamegraph -- target/release/whisper-rs-cli

CoreML enabled

whisper-rs

85.27 real       537.47 user         6.41 sys

Rust Flamegraph

whisper.cpp

12.30 real        42.14 user        15.36 sys

Cpp Flamegraph

Please let me know what you think and where I can help out. Admittedly I'm a bit inexperienced with Rust, but I'd love to learn, especially by solving an issue like this.

I was able to reproduce this, and am working on it with @tazz4843.

At this point the only difference I can point to is that whisper-rs uses v1.4.2 of whisper.cpp by default, while you're likely using upstream git master for these tests. Upgrading whisper-rs's pinned version to git master seems to speed it up somewhat, but I don't have Apple Silicon to test on myself, so I can't do much in terms of poking around.
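If anyone wants to reproduce the comparison against master, one option is to point Cargo at a local whisper-rs checkout whose whisper.cpp submodule has been bumped. A minimal sketch (the ../whisper-rs path is hypothetical):

# In the CLI project's Cargo.toml: use a local whisper-rs checkout whose
# whisper.cpp submodule tracks upstream master instead of v1.4.2.
[patch.crates-io]
whisper-rs = { path = "../whisper-rs" }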

My test isn't the same, but it runs 2 seconds slower after updating whisper-rs to use whisper.cpp master.

I've also pulled main for whisper.cpp. To get whisper-rs to build against it, I had to edit two function signatures, adding *mut whisper_context as an argument to the following:

/// Get the ID of the translate task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_translate ()`
pub fn token_translate(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_translate(ctx) }
}

/// Get the ID of the transcribe task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_transcribe()`
pub fn token_transcribe(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_transcribe(ctx) }
}
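For reference, call sites of these free functions then need the raw context pointer as well; a minimal sketch, assuming ctx is the *mut whisper_context obtained when the model was loaded:

// Before (v1.4.2 bindings): let id = token_translate();
// After (master bindings): the context must be passed explicitly,
// since the task-token IDs depend on the loaded model's vocabulary.
let translate_id = token_translate(ctx);
let transcribe_id = token_transcribe(ctx);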

Only slight improvements:

77.39 real       354.37 user       118.68 sys

@wdoppenberg do you have any guesses as to why there's such a big performance discrepancy?

Given I don't have Apple Silicon myself to test on, I can't do much to help with this besides suggesting x86 instead. Hopefully someone else can figure it out.

I have M1 and M2 and can try it out if someone provides a repo with a test case that can be run.

Right now, for the canonical "your country" JFK test case, whisper-rs and whisper.cpp perform the same.

I put a timer around whisper.cpp's whisper_full_with_state (see the diff below) - it took 1612.61 ms

I put a timer around state.full -> state.full_get_segment_text - it took 1648 ms
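In Rust terms, that measurement is just an Instant around the same span (a sketch; state, params, and samples are assumed to be set up as in the snippet further below):

use std::time::Instant;

let start = Instant::now();
state.full(params, &samples[..]).expect("whisper_full failed");
if let Ok(text) = state.full_get_segment_text(0) {
    println!("{}", text);
}
println!("elapsed: {} ms", start.elapsed().as_millis());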

C code changes:

     struct whisper_full_params   params,
                    const float * samples,
                            int   n_samples) {
-    // clear old results
+
+    struct timeval start_time, end_time;
+    gettimeofday(&start_time, NULL);
+    // clear old results
     auto & result_all = state->result_all;

     result_all.clear();
@@ -4761,7 +4768,16 @@ int whisper_full_with_state(
         }
     }

-    return 0;
+    gettimeofday(&end_time, NULL);
+
+    // Calculate the elapsed time in milliseconds
+    double elapsed_ms = (end_time.tv_sec - start_time.tv_sec) * 1000.0 +
+                        (end_time.tv_usec - start_time.tv_usec) / 1000.0;
+
+    // Print the result to stdout
+    printf("Elapsed time: %.2f ms\n", elapsed_ms);
+
+    return 0;
 }
❯ time ./main -m models/ggml-medium.en.bin -t 8  -f jfk.wav
whisper_init_from_file_no_state: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab       = 51864
whisper_model_load: n_audio_ctx   = 1500
whisper_model_load: n_audio_state = 1024
whisper_model_load: n_audio_head  = 16
whisper_model_load: n_audio_layer = 24
whisper_model_load: n_text_ctx    = 448
whisper_model_load: n_text_state  = 1024
whisper_model_load: n_text_head   = 16
whisper_model_load: n_text_layer  = 24
whisper_model_load: n_mels        = 80
whisper_model_load: ftype         = 1
whisper_model_load: qntvr         = 0
whisper_model_load: type          = 4
whisper_model_load: mem required  = 1899.00 MB (+   43.00 MB per decoder)
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx     = 1462.58 MB
whisper_model_load: model size    = 1462.12 MB
whisper_init_state: kv self size  =   42.00 MB
whisper_init_state: kv cross size =  140.62 MB
whisper_init_state: loading Core ML model from 'models/ggml-medium.en-encoder.mlmodelc'
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded

system_info: n_threads = 8 / 8 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 | OPENVINO = 0 |

main: processing 'jfk.wav' (176000 samples, 11.0 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...


[00:00:00.000 --> 00:00:11.000]   And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
Elapsed time: 1612.61 ms

whisper-rs:

whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
[src/main.rs:66] idx = 1041
1041 [00:00:10.410]  And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
[src/main.rs:98] elapsed = 1648
 65             if let Some(frames) = buf.add(idx, mel) {
 66                 dbg!(idx);
 67                 let path = format!("{}/frame_{}.tga", mel_path, idx);
 68                 let _ = save_tga_8bit(&frames, n_mels, &path);
 69
 70                 let ms = duration_ms_for_n_frames(hop_size, sampling_rate, idx);
 71                 let time = format_milliseconds(ms as u64);
 72
 73                 let start = std::time::Instant::now();
 74                 let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
 75                 params.set_n_threads(6);
 76                 params.set_single_segment(true);
 77                 params.set_language(Some("en"));
 78                 params.set_print_special(false);
 79                 params.set_print_progress(false);
 80                 params.set_print_realtime(false);
 81                 params.set_print_timestamps(false);
 82                 state.set_mel(&frames).unwrap();
 83
 84                 let empty = vec![];
 85                 state.full(params, &empty[..]).unwrap();
 86
 87                 let num_segments = state.full_n_segments().unwrap();
 88                 if num_segments > 0 {
 89                     if let Ok(text) = state.full_get_segment_text(0) {
 90                         let msg = format!("{} [{}] {}", idx, time, text);
 91                         println!("{}", msg);
 92                     } else {
 93                         println!("Error retrieving text for segment.");
 94                     }
 95                 }
 96
 97                 let elapsed = start.elapsed().as_millis();
 98                 dbg!(elapsed);
 99             }

I'm using the set_mel API, but there's no reason to think PCM audio would be slower in Rust vs. C. If someone can provide a test case that can be run easily, I can try it out.
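For a PCM test case, a minimal harness might look something like this (a sketch; assumes ctx is a loaded WhisperContext and pcm is 16 kHz mono f32 samples):

let mut state = ctx.create_state().expect("failed to create state");
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_n_threads(6);
params.set_language(Some("en"));
state.full(params, &pcm[..]).expect("full() failed");
for i in 0..state.full_n_segments().expect("failed to get segment count") {
    println!("{}", state.full_get_segment_text(i).expect("failed to get segment"));
}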

I still haven't been able to find the issue. Examining the call tree shows that, as one might expect, almost all time is spent in ggml_compute_forward_mul_mat when running with a single thread.

(Screenshot: profiler call tree, 2023-08-01.)
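(The single-thread run above just pins the worker count through the same params setter used earlier:)

// Pin whisper.cpp to one thread so the call tree isn't split across workers.
params.set_n_threads(1);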

It almost feels like the whisper.cpp lib is not compiled with optimization flags enabled, which of course is not the case.

Think it's your machine?

I have found the issue:

In my script, I used SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 }. This is quite heavy...

Using SamplingStrategy::Greedy { best_of: 0 }, it runs in roughly 10.11 seconds (with a synced whisper.cpp submodule).
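For reference, the fix is a one-line change to the params construction. Beam search keeps beam_size decoder hypotheses alive at each step, so beam_size: 12 costs roughly an order of magnitude more decoder work than greedy sampling:

// Heavy: ~12 decoder hypotheses per step.
// let params = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 });

// Light: a single greedy hypothesis, comparable to the whisper.cpp CLI's default.
let params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });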