Significant performance penalty w.r.t. whisper.cpp
wdoppenberg opened this issue · 10 comments
First of all, thank you for your work creating a safe wrapper around whisper.cpp.

As mentioned in #73, the performance of whisper-rs is quite poor compared to the reference implementation. I'll attempt to demonstrate this below.
Setup
I'm using an M2 Max MacBook Pro with 64GB of (shared) memory. My goal is to run a web server with CoreML enabled, but if necessary I can run the tests CPU-only later. I'll attach output generated by flamegraph.
Rust script
// src/main.rs
mod utils;

use clap::Parser;
use anyhow::Result;
use log::info;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};

use crate::utils::{decode, forward_pass, load_audio};

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct WhisperCli {
    #[arg(short, long, default_value = "sample_data/jfk.wav")]
    input_file: String,

    #[arg(short, long, default_value = "models/ggml-medium.bin")]
    model_path: String,
}

fn main() -> Result<()> {
    env_logger::init();
    let args = WhisperCli::parse();

    let mut params = FullParams::new(
        SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 }
    );
    params.set_n_threads(8);
    params.set_translate(false);

    info!("Loading audio file {}", args.input_file);
    let raw_audio = load_audio(&args.input_file)?;

    info!("Decoding audio file");
    let decoded_audio = decode(raw_audio)?;

    info!("Loading model from {}", args.model_path);
    let whisper_context = WhisperContext::new(&args.model_path)?;
    let mut whisper_state = whisper_context.create_state()?;

    info!("Running forward pass");
    let result = forward_pass(&decoded_audio, params, &mut whisper_state);
    println!("{}", result);
    Ok(())
}
// src/utils.rs
use std::fs::File;
use std::io::{Cursor, Read};

use anyhow::Result;
use log::{debug, trace};
use rodio::source::UniformSourceIterator;
use rodio::{Decoder, Source};
use whisper_rs::{FullParams, WhisperState};

const SAMPLE_RATE: u32 = 16000;
const CHANNELS: u16 = 1;
const LOW_PASS: u32 = 3000;
const HIGH_PASS: u32 = 200;

/// Load an audio file from the given path
/// and return it as a vector of u8.
///
/// # Arguments
/// * `path` - The path to the audio file
pub fn load_audio(path: &str) -> Result<Vec<u8>> {
    let mut file = File::open(path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    Ok(buffer)
}

/// Decode the audio file and return it as a vector of f32.
pub fn decode(bytes: Vec<u8>) -> Result<Vec<f32>> {
    // Decode the audio file
    let input = Cursor::new(bytes);
    let source = Decoder::new(input)?;

    // Resample to the target sample rate and channel count
    let resample = UniformSourceIterator::new(source, CHANNELS, SAMPLE_RATE);

    // High- and low-pass filters to enhance the audio
    let pass_filter = resample
        .low_pass(LOW_PASS)
        .high_pass(HIGH_PASS)
        .convert_samples();

    Ok(whisper_rs::convert_integer_to_float_audio(
        &pass_filter.collect::<Vec<i16>>(),
    ))
}

pub fn forward_pass(
    decoded_audio: &[f32],
    params: FullParams,
    whisper_state: &mut WhisperState,
) -> String {
    debug!("Starting forward pass");
    let _ = whisper_state
        .full(params, decoded_audio)
        .expect("Failed to run model");

    let mut result = String::new();
    debug!("Decoding results");

    // Fetch the results
    let num_segments = whisper_state
        .full_n_segments()
        .expect("failed to get number of segments");
    for i in 0..num_segments {
        let segment = whisper_state
            .full_get_segment_text(i)
            .expect("failed to get segment");
        let start_timestamp = whisper_state
            .full_get_segment_t0(i)
            .expect("failed to get segment start timestamp");
        let end_timestamp = whisper_state
            .full_get_segment_t1(i)
            .expect("failed to get segment end timestamp");
        trace!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
        result.push_str(&format!(
            "[{} - {}]: {}\n",
            start_timestamp, end_timestamp, segment
        ));
    }
    result
}
# Cargo.toml
[package]
name = "whisper-rs-cli"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.71"
whisper-rs = { version = "0.8.0", features = ["coreml"] }
clap = { version = "4.3.12", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4.19"
rodio = "0.17.1"

[[bin]]
name = "whisper-rs-cli"
path = "src/main.rs"
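As a sanity check that model loading isn't part of what's being measured below, the tail of main can time just the forward pass; a minimal sketch using only std:

// Sketch: replaces the last lines of main() above to time only inference.
let start = std::time::Instant::now();
let result = forward_pass(&decoded_audio, params, &mut whisper_state);
println!("inference took {} ms", start.elapsed().as_millis());
println!("{}", result);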
Data
For testing, I've converted a short JFK speech to WAV. See this link. Conversion to WAV is done using ffmpeg:

ffmpeg -i main.mp4 -acodec pcm_s16le -ac 1 -ar 16000 jfk.wav
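whisper.cpp expects 16 kHz mono input, so it's worth verifying the header of the converted file; a quick sketch using the hound crate (an extra dependency, used only for this check, not part of the project above):

// Sanity-check the converted WAV header; `hound` is only used here.
fn main() -> Result<(), hound::Error> {
    let spec = hound::WavReader::open("jfk.wav")?.spec();
    assert_eq!(spec.sample_rate, 16000);
    assert_eq!(spec.channels, 1);
    println!("{:?}", spec);
    Ok(())
}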
Results
I won't average over multiple iterations since the differences are quite clear. Furthermore, I've run both scripts beforehand to ensure that the CoreML model is properly compiled for my architecture, and I can confirm that the model loading step is not the issue. The chosen model is ggml-medium.bin. Given that you have compiled whisper.cpp and whisper-rs and are in the root of each repository, the commands used are as follows:

sudo time flamegraph -- ./main -m models/ggml-medium.bin -t 8 -f jfk.wav
sudo time flamegraph -- target/release/whisper-rs-cli
CoreML enabled

whisper-rs:
85.27 real 537.47 user 6.41 sys

whisper.cpp:
12.30 real 42.14 user 15.36 sys
Please let me know what you think and where I can help out. Admittedly, I'm a bit inexperienced with Rust, but I'd love to learn, especially by working on an issue like this.
I was able to reproduce this, and am working on it with @tazz4843
At this point, the only difference I can point to is that whisper-rs uses whisper.cpp v1.4.2 by default, while you're likely using upstream git master for these tests. Upgrading whisper-rs to git master seems to speed it up somewhat, but I don't have Apple Silicon to test on, so I can't do much poking around myself.
My test isn't the same, but it runs 2 seconds slower after updating whisper-rs to use whisper.cpp master.
I've also pulled main for whisper.cpp. I had to edit two function signatures in whisper-rs to get it to work, adding *mut whisper_context as an argument to the following:
/// Get the ID of the translate task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_translate()`
pub fn token_translate(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_translate(ctx) }
}

/// Get the ID of the transcribe task token.
///
/// # C++ equivalent
/// `whisper_token whisper_token_transcribe()`
pub fn token_transcribe(ctx: *mut whisper_context) -> WhisperToken {
    unsafe { whisper_rs_sys::whisper_token_transcribe(ctx) }
}
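For reference, a sketch of how a call site changes under the new signatures (`ctx` here is hypothetical shorthand for whatever raw *mut whisper_context the wrapper holds internally):

// Previously the task tokens could be fetched without a context:
// let translate = token_translate();
// On whisper.cpp master they depend on the loaded model, so the raw
// context pointer has to be threaded through:
let translate = token_translate(ctx);
let transcribe = token_transcribe(ctx);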
Only a slight improvement:

77.39 real 354.37 user 118.68 sys
@wdoppenberg do you have any guesses as to why there's such a big performance discrepancy?
Given I don't have Apple Silicon to test on, I can't do much to help with this beyond suggesting you test on x86 instead. Hopefully someone else can figure it out.
I have M1 and M2 and can try it out if someone provides a repo with a test case that can be run.
Right now, for the canonical "your country" JFK test case, whisper-rs and whisper.cpp perform the same for me.
I put a timer around whisper.cpp's whisper_full_with_state: it took 1612.61 ms.
I put a timer around whisper-rs's state.full through state.full_get_segment_text: it took 1648 ms.
C code changes:

     struct whisper_full_params params,
     const float * samples,
     int n_samples) {
-    // clear old results
+
+    struct timeval start_time, end_time;
+    gettimeofday(&start_time, NULL);
+    // clear old results
     auto & result_all = state->result_all;
     result_all.clear();
@@ -4761,7 +4768,16 @@ int whisper_full_with_state(
         }
     }
-    return 0;
+    gettimeofday(&end_time, NULL);
+
+    // Calculate the elapsed time in milliseconds
+    double elapsed_ms = (end_time.tv_sec - start_time.tv_sec) * 1000.0 +
+                        (end_time.tv_usec - start_time.tv_usec) / 1000.0;
+
+    // Print the result to stdout
+    printf("Elapsed time: %.2f ms\n", elapsed_ms);
+
+    return 0;
 }
❯ time ./main -m models/ggml-medium.en.bin -t 8 -f jfk.wav
whisper_init_from_file_no_state: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 1024
whisper_model_load: n_audio_head = 16
whisper_model_load: n_audio_layer = 24
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 1024
whisper_model_load: n_text_head = 16
whisper_model_load: n_text_layer = 24
whisper_model_load: n_mels = 80
whisper_model_load: ftype = 1
whisper_model_load: qntvr = 0
whisper_model_load: type = 4
whisper_model_load: mem required = 1899.00 MB (+ 43.00 MB per decoder)
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 1462.58 MB
whisper_model_load: model size = 1462.12 MB
whisper_init_state: kv self size = 42.00 MB
whisper_init_state: kv cross size = 140.62 MB
whisper_init_state: loading Core ML model from 'models/ggml-medium.en-encoder.mlmodelc'
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
system_info: n_threads = 8 / 8 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | COREML = 1 | OPENVINO = 0 |
main: processing 'jfk.wav' (176000 samples, 11.0 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
Elapsed time: 1612.61 ms
whisper-rs:
whisper_init_state: first run on a device may take a while ...
whisper_init_state: Core ML model loaded
[src/main.rs:66] idx = 1041
1041 [00:00:10.410] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
[src/main.rs:98] elapsed = 1648
65     if let Some(frames) = buf.add(idx, mel) {
66         dbg!(idx);
67         let path = format!("{}/frame_{}.tga", mel_path, idx);
68         let _ = save_tga_8bit(&frames, n_mels, &path);
69
70         let ms = duration_ms_for_n_frames(hop_size, sampling_rate, idx);
71         let time = format_milliseconds(ms as u64);
72
73         let start = std::time::Instant::now();
74         let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
75         params.set_n_threads(6);
76         params.set_single_segment(true);
77         params.set_language(Some("en"));
78         params.set_print_special(false);
79         params.set_print_progress(false);
80         params.set_print_realtime(false);
81         params.set_print_timestamps(false);
82         state.set_mel(&frames).unwrap();
83
84         let empty = vec![];
85         state.full(params, &empty[..]).unwrap();
86
87         let num_segments = state.full_n_segments().unwrap();
88         if num_segments > 0 {
89             if let Ok(text) = state.full_get_segment_text(0) {
90                 let msg = format!("{} [{}] {}", idx, time, text);
91                 println!("{}", msg);
92             } else {
93                 println!("Error retrieving text for segment.");
94             }
95         }
96
97         let elapsed = start.elapsed().as_millis();
98         dbg!(elapsed);
99     }
I'm using the set_mel API, but there's no reason to think PCM audio will be slower in Rust vs. C. If someone can provide a test case that can be run easily, I can try it out.
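For what it's worth, a sketch of the PCM-path test I'd run, using only calls from the OP's script (assumes ctx is a loaded WhisperContext and pcm holds 16 kHz mono f32 samples, e.g. from the decode() helper above):

// PCM-path timing test, mirroring the mel-path snippet above.
let mut state = ctx.create_state().unwrap();
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
params.set_n_threads(6);
params.set_language(Some("en"));

let start = std::time::Instant::now();
state.full(params, &pcm).unwrap();
dbg!(start.elapsed().as_millis());

for i in 0..state.full_n_segments().unwrap() {
    println!("{}", state.full_get_segment_text(i).unwrap());
}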
Still haven't been able to find the issue. When examining the call tree, I find that, as one might expect, almost all calls are part of ggml_compute_forward_mul_mat when running in a single thread. It almost feels like the whisper.cpp lib is not compiled with optimization flags enabled, which of course is not the case.
Think it's your machine?
I have found the issue: in my script, I used SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 }. This is quite heavy. Using SamplingStrategy::Greedy { best_of: 0 }, it runs in roughly 10.11 seconds (with a sync'd whisper.cpp submodule).
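For anyone hitting the same thing: beam search with beam_size: 12 decodes roughly a dozen hypotheses per step, while (as far as I can tell) whisper.cpp's main example defaults to greedy sampling, so the two runs were never doing comparable work. The fix is a one-line change:

// Heavy: explores beam_size parallel hypotheses per decode step.
let slow = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1 });

// Matches whisper.cpp's default greedy decoding.
let fast = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });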