tazz4843/whisper-rs

Cannot create mutable state for Axum server

Closed this issue · 8 comments

I'm trying to create a whisper-rs server using Axum. In this server I'd like to only have to create the whisper state once at server startup. Hence, I created an AppState struct that could then be passed around using an Arc<Mutex<AppState>>.

However, I seem to be having ownership issues (classic) when implementing ::new for this struct. I suspect it has to do with the underlying implementation of WhisperContext and WhisperState, since they use raw pointers underneath and WhisperState carries a lifetime through a PhantomData<&'a WhisperContext> field.
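The pattern described here can be sketched without whisper-rs. The Context and State types below are hypothetical stand-ins, not the real whisper-rs types. The state holds only a raw pointer, and a PhantomData<&'a Context> field is what tells the borrow checker that it borrows the context, so a State can never outlive the Context it came from.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for a context that owns the loaded model.
struct Context {
    model: Box<u32>, // pretend this is model memory owned by native code
}

// Hypothetical stand-in for a state. It stores only a raw pointer, so the
// PhantomData field is what tells the borrow checker it borrows a Context.
struct State<'a> {
    model: *const u32,
    _ctx: PhantomData<&'a Context>,
}

impl Context {
    fn new(value: u32) -> Self {
        Context { model: Box::new(value) }
    }

    // The elided lifetime ties the returned State to this borrow of self.
    fn create_state(&self) -> State<'_> {
        State {
            model: &*self.model as *const u32,
            _ctx: PhantomData,
        }
    }
}

impl State<'_> {
    fn read_model(&self) -> u32 {
        // Sound only because the lifetime keeps the Context (and its Box) alive.
        unsafe { *self.model }
    }
}

fn main() {
    let ctx = Context::new(42);
    let state = ctx.create_state();
    println!("model value: {}", state.read_model());
    // Moving ctx while state is alive, e.g. returning both from a function,
    // is rejected: the borrow checker sees that state still borrows ctx.
}
```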

My implementation is as follows:

use log::info;
use whisper_rs::{WhisperContext, WhisperState};

/// Application state
///
/// This struct holds the whisper context and state necessary for transcribing
///
/// # Fields
/// * `whisper_ctx` - The whisper context
/// * `whisper_state` - The whisper state
pub struct AppState<'ctx> {
    pub whisper_ctx: WhisperContext,
    pub whisper_state: WhisperState<'ctx>
}

impl<'ctx> AppState<'ctx> {
    /// Create a new AppState struct
    ///
    /// # Arguments
    /// * `model_path` - The path to the model file, e.g. "models/ggml-medium.bin"
    ///
    /// # Returns
    /// * `AppState` - The AppState struct
    ///
    /// # Example
    /// ```
    /// let model_path = "models/ggml-medium.bin";
    /// let app_state = AppState::new(model_path);
    /// ```
    fn new(model_path: &'_ str) -> Self {
        info!("Loading model from {}", model_path);
        // Extract the file name from the path
        let model_name = model_path.split('/').last().unwrap().to_string();
        info!("Model name: {}", model_name);

        // load a context and model
        let whisper_ctx = WhisperContext::new(model_path).expect("failed to load model");
        let whisper_state = whisper_ctx.create_state().expect("failed to create state");

        // Create AppState struct with owned data
        Self {
            whisper_ctx,
            whisper_state
        }
    }
}

This yields the following compilation error:

   |
38 |           let whisper_state = whisper_ctx.create_state().expect("failed to create state");
   |                               -------------------------- `whisper_ctx` is borrowed here
...
41 | /         Self {
42 | |             whisper_ctx,
43 | |             whisper_state
44 | |         }
   | |_________^ returns a value referencing data owned by the current function

Perhaps this comes down to my understanding of Rust, in which case please tell me so. So far I've tried most of what conventional googling and/or ChatGPT suggest, such as wrapping the struct's field types in Arc, but to no avail.

Otherwise I'm curious to hear your thoughts, any help is greatly appreciated. I've unfortunately been stuck for some time.

I believe what you have here is a self-referential struct: whisper_state borrows from whisper_ctx, which lives in the same struct. The compiler currently can't verify that WhisperState does indeed live only as long as WhisperContext (in general, moving the struct could invalidate a reference into it), so some hacky tricks with raw pointers are required. I use ouroboros for this: https://users.rust-lang.org/t/ouroboros-a-crate-for-making-self-referential-structs/49025/ https://crates.io/crates/ouroboros
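If the context only needs to be created once and kept for the whole life of the server, one crate-free alternative to ouroboros is to leak it with Box::leak. That yields a &'static reference, so the state becomes State<'static> and the struct needs no lifetime parameter. This is a sketch with hypothetical stand-in types; with the real crate the idea would be to leak the WhisperContext and call create_state() on the 'static reference. Whether the resulting state is Send (required to share it through Axum) depends on the whisper-rs version and is not checked here. The leaked memory is never freed, which is only acceptable for a value that lives until the process exits.

```rust
// Hypothetical stand-ins for WhisperContext / WhisperState: the state
// borrows its context through a lifetime, as WhisperState<'a> does.
struct Context {
    name: String,
}

struct State<'a> {
    ctx: &'a Context,
    runs: u32,
}

impl Context {
    fn create_state(&self) -> State<'_> {
        State { ctx: self, runs: 0 }
    }
}

// No lifetime parameter is needed once the context is 'static.
struct AppState {
    state: State<'static>,
}

impl AppState {
    fn new(name: &str) -> Self {
        // Leak the context so it lives until the process exits. Only
        // appropriate for a value created once, e.g. at server startup.
        let ctx: &'static Context = Box::leak(Box::new(Context {
            name: name.to_string(),
        }));
        AppState {
            state: ctx.create_state(),
        }
    }
}

fn main() {
    let mut app = AppState::new("ggml-medium.bin");
    app.state.runs += 1; // mutable access through the owning struct
    println!("{} used {} time(s)", app.state.ctx.name, app.state.runs);
}
```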

I see... I haven't had success with ouroboros yet, but I'll report back when I have something useful.

So this works:

use ouroboros::self_referencing;
use whisper_rs::{WhisperContext, WhisperState};

#[self_referencing(pub_extras)]
#[allow(unused_imports)]
pub struct AppState {
    pub whisper_ctx: WhisperContext,
    #[borrows(whisper_ctx)]
    #[covariant]
    pub whisper_state: WhisperState<'this>
}

Constructing a new instance with:

let model_path = "models/ggml-medium.bin";

let whisper_context = whisper_rs::WhisperContext::new(&model_path)
        .expect("Failed to load model");

let app_state = Arc::new(Mutex::new(
    AppStateAsyncSendBuilder {
        whisper_ctx: whisper_context,
        whisper_state_builder: |whisper_ctx| {
            Box::pin(async move {
                whisper_ctx.create_state().unwrap()
            })
        }
    }.build().await
));

And the endpoint:

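use std::sync::Arc;

use axum::{body::Bytes, extract::State};
use log::debug;
use tokio::sync::Mutex; // assumed: an async mutex, since the handler awaits lock()
use whisper_rs::{FullParams, SamplingStrategy};

async fn transcribe(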
    state: State<Arc<Mutex<AppState>>>,
    audio: Bytes
) -> String {
    debug!("Transcribing audio file of size {}", audio.len());
    let decoded_audio = utils::decode(audio.to_vec()).expect("Failed to decode audio");

    // create a params object
    let params = FullParams::new(
        SamplingStrategy::BeamSearch { beam_size: 12, patience: 0.1}
    );

    // Get a mutable reference to the whisper state
    let mut lock = state.lock().await;

    lock.with_whisper_state_mut(
        |whisper_state| {
            forward_pass(&decoded_audio, params, whisper_state)
        }
    )
}
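The sharing pattern in the endpoint, one state behind Arc<Mutex<_>> cloned into every handler, can be sketched with the standard library alone. Here threads stand in for concurrent Axum requests, and std::sync::Mutex replaces the async mutex the handler awaits; AppState and transcribe are hypothetical simplifications, not the types from the thread.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical stand-in for the whisper state held inside AppState.
struct AppState {
    transcriptions: u32,
}

impl AppState {
    // Hypothetical stand-in for forward_pass: needs &mut access to the state.
    fn transcribe(&mut self, audio: &[u8]) -> String {
        self.transcriptions += 1;
        format!("{} bytes (request #{})", audio.len(), self.transcriptions)
    }
}

fn main() {
    let state = Arc::new(Mutex::new(AppState { transcriptions: 0 }));

    // Each thread plays the role of one concurrent HTTP request.
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let state = Arc::clone(&state);
            thread::spawn(move || {
                let audio = vec![0u8; 1000 * (i + 1)];
                // The lock is held for the whole transcription, so requests
                // run one at a time against the single state.
                let mut guard = state.lock().unwrap();
                guard.transcribe(&audio)
            })
        })
        .collect();

    for h in handles {
        println!("{}", h.join().unwrap());
    }
    println!("total: {}", state.lock().unwrap().transcriptions);
}
```

Because the lock is held for the whole call, transcriptions against a single state run one at a time, so concurrent requests queue up behind each other.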

I am unfortunately getting very poor performance compared to whisper.cpp, even though I'm running on an M2 Max MacBook Pro with a CoreML model. Is this a known issue?

Might as well check the obvious first: do you have the coreml feature enabled? And are you running with --release?

Fair enough, but yes, I am running with --release and have the coreml feature enabled.

How is the poor performance manifesting? I ask because we had #67, where the model was being loaded and cached every time.
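For the "load the model only once" requirement in general, a process-wide std::sync::OnceLock (stable since Rust 1.70) guarantees that the loader runs at most once. This sketch is not the fix that was made for #67; Model and load_model are hypothetical, and a counter records how many loads actually happened.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;

// Counts how often the (hypothetical) expensive load actually runs.
static LOADS: AtomicUsize = AtomicUsize::new(0);

// Hypothetical stand-in for a loaded model.
struct Model {
    path: String,
}

// Hypothetical expensive loader, e.g. reading a ggml file from disk.
fn load_model(path: &str) -> Model {
    LOADS.fetch_add(1, Ordering::SeqCst);
    Model { path: path.to_string() }
}

// Returns the same Model on every call; the closure runs at most once.
fn model() -> &'static Model {
    static MODEL: OnceLock<Model> = OnceLock::new();
    MODEL.get_or_init(|| load_model("models/ggml-medium.bin"))
}

fn main() {
    for _ in 0..3 {
        let m = model();
        println!("using {}", m.path);
    }
    println!("loaded {} time(s)", LOADS.load(Ordering::SeqCst));
}
```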

I'll attempt to create a proper comparison between this and whisper.cpp tomorrow.
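A comparison like this is easier to trust when each side processes the same audio several times with the same model and sampling settings (the endpoint above uses beam search with beam_size: 12), and first-call setup is excluded. Below is a minimal std-only timing harness; the workload closure is a hypothetical placeholder for the real transcription call.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

// Runs `f` `runs` times and returns the wall-clock duration of each run.
fn time_runs<T, F: FnMut() -> T>(runs: usize, mut f: F) -> Vec<Duration> {
    (0..runs)
        .map(|_| {
            let start = Instant::now();
            // black_box keeps the optimizer from discarding the result.
            let _ = black_box(f());
            start.elapsed()
        })
        .collect()
}

// Median of the measured durations (upper median for even counts);
// panics if `times` is empty.
fn median(mut times: Vec<Duration>) -> Duration {
    times.sort();
    times[times.len() / 2]
}

fn main() {
    // Hypothetical workload standing in for one full transcription call.
    let workload = || (0..1_000_000u64).sum::<u64>();
    // Warm up once so first-call costs (e.g. model loading) are excluded.
    workload();
    let times = time_runs(5, workload);
    println!("median: {:?}", median(times));
}
```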

Closing, since my problem has been solved. I've created a new issue (#74) for the performance penalty.