rustformers/llm

Parallel loading of the model tensors

philpax opened this issue · 5 comments

People have reported faster model loading upstream when the tensors are loaded in parallel: ggerganov/llama.cpp#85

This should be pretty easy to do in Rust if we convert loading to an iterator and then use rayon's par_iter instead. You'd expect loading to be I/O bound, but perhaps the actual loading process has computational overhead?
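
For illustration, a minimal sketch of that swap with rayon (TensorInfo and load_tensor here are hypothetical stand-ins for the real per-tensor metadata and read, not the actual loader API):

use rayon::prelude::*;

// Hypothetical stand-in for whatever per-tensor metadata the loader tracks.
struct TensorInfo {
    name: String,
    n_bytes: usize,
}

// Placeholder for the real per-tensor read; in practice each worker would
// need its own file handle (or a shared mmap) to pull the tensor's bytes.
fn load_tensor(info: &TensorInfo) -> Vec<u8> {
    vec![0; info.n_bytes]
}

fn load_all(infos: Vec<TensorInfo>) -> Vec<(String, Vec<u8>)> {
    infos
        .into_par_iter() // the only change from a serial into_iter()
        .map(|info| {
            let data = load_tensor(&info);
            (info.name, data)
        })
        .collect()
}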

Somewhat related to speeding up loading: I've been messing around with rewriting the loader to use an mmap-based approach and nom. I don't know if it's really on the right track.

This is what just loading the header and vocabulary looks like:

pub mod mmap_loader {
    use mmap_rs::{MmapFlags, MmapOptions};
    #[allow(unused_imports)]
    use nom::{
        branch::alt,
        bytes::complete as nby,
        combinator as ncom,
        error::ParseError,
        multi as nm,
        number::complete::{self as nnum, le_f32, le_i32, le_u32},
        sequence as nseq, IResult, Parser, Slice,
    };
    use std::{collections::HashMap, fs::File, path::Path};

    use super::*;

    /// Namespace struct for the experimental mmap + nom loader functions.
    pub struct Flib;

    #[derive(Debug)]
    struct Header {
        /// True for unversioned files, which store no token scores.
        legacy: bool,
        hyper: Hyperparameters,
    }

    impl Flib {
        /// Parse the file magic (an unversioned magic marks a legacy file)
        /// followed by the hyperparameters.
        fn parse_header(i: &[u8]) -> IResult<&[u8], Header> {
            let (i, magic) = le_i32(i)?;
            let legacy = match magic {
                ggml::FILE_MAGIC => false,
                ggml::FILE_MAGIC_UNVERSIONED => true,
                _ => return nom::error::context("ohno", ncom::fail)(i),
            };
            ncom::map(Flib::parse_hyperparameters, move |hyper| Header {
                legacy,
                hyper,
            })(i)
        }

        /// Parse the seven little-endian i32 hyperparameter fields that
        /// follow the magic; n_ctx is not stored in the file.
        fn parse_hyperparameters(i: &[u8]) -> IResult<&[u8], Hyperparameters> {
            ncom::map(
                nseq::tuple((le_i32, le_i32, le_i32, le_i32, le_i32, le_i32, le_i32)),
                |(n_vocab, n_embd, n_mult, n_head, n_layer, n_rot, f16_)| Hyperparameters {
                    n_vocab,
                    n_ctx: 0,
                    n_embd,
                    n_mult,
                    n_head,
                    n_layer,
                    n_rot,
                    f16_,
                },
            )(i)
        }

        /// Parse n_vocab length-prefixed token strings (each with a score in
        /// non-legacy files), building both directions of the token map.
        fn parse_vocabulary<'a>(i: &'a [u8], hdr: &Header) -> IResult<&'a [u8], Vocabulary> {
            // Substituted for tokens that are not valid UTF-8.
            const TOKEN_PLACEHOLDER: &str = "�";
            let n_vocab = hdr.hyper.n_vocab as usize;
            let legacy = hdr.legacy;
            let mut id_to_token = Vec::with_capacity(n_vocab);
            let mut id_to_token_score = Vec::with_capacity(n_vocab);
            let mut token_to_id = HashMap::with_capacity(n_vocab);
            // One vocabulary entry: a u32-length-prefixed byte string,
            // followed by an f32 score in non-legacy files.
            let vocabitem_parser = |i| {
                nseq::tuple((nm::length_data(le_u32), ncom::cond(!legacy, le_f32)))(i)
                    .map(|(i, (sbytes, score))| (i, (sbytes, score.unwrap_or_default())))
            };
            // Fold step: push the token (or the placeholder for invalid
            // UTF-8) and its score, tracking the maximum token length.
            let folf = |mut mtl: usize, (sbytes, score)| {
                let tid = id_to_token.len();
                let (ok, token) = std::str::from_utf8(sbytes).map_or_else(
                    |_| (false, TOKEN_PLACEHOLDER.to_string()),
                    |s| (true, s.to_string()),
                );
                if ok {
                    mtl = mtl.max(token.len());
                    token_to_id.insert(token.clone(), tid as TokenId);
                }
                id_to_token.push(token);
                id_to_token_score.push(score);
                mtl
            };
            let (i, max_token_length) =
                nm::fold_many_m_n(n_vocab, n_vocab, vocabitem_parser, || 0, folf)(i)?;
            Ok((
                i,
                Vocabulary {
                    id_to_token,
                    id_to_token_score,
                    token_to_id,
                    max_token_length,
                },
            ))
        }

        /// mmap the whole file and parse the header and vocabulary in place.
        pub fn load(path: impl AsRef<Path>) -> Result<(), LoadError> {
            let path = path.as_ref();
            let fp = File::open(path).map_err(|e| LoadError::OpenFileFailed {
                source: e,
                path: path.to_owned(),
            })?;
            let flen = fp.metadata()?.len();
            let m = unsafe {
                MmapOptions::new(flen as usize).and_then(|mo| {
                    mo.with_file(fp, 0)
                        .with_flags(MmapFlags::NO_CORE_DUMP)
                        .map()
                })
            }
            .map_err(|e| LoadError::MmapFailed { source: e })?;
            let mb = m.as_slice();
            // NOTE: unwraps are fine for a sketch; real code would map the
            // nom errors into LoadError.
            let (i, hdr) = Self::parse_header(mb).unwrap();
            println!("Got: {hdr:?}");
            let (_i, vocab) = Self::parse_vocabulary(i, &hdr).unwrap();
            println!(
                "Got: {} - {} - {}",
                vocab.max_token_length,
                vocab.id_to_token.len(),
                vocab.token_to_id.len()
            );
            Ok(())
        }
    }
}

I honestly don't love writing parsers in Rust (it's so much nicer in Haskell), but I guess this is more readable than the current code. A long time ago, I experimented with trying to combine nom and monadic do-style notation, but it wasn't really practical: https://github.com/KerfuffleV2/mdoexperiments

Along the lines of declarative binary parsing, it might also be interesting to explore the use of https://github.com/jam1garner/binrw.
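
As a rough idea of the declarative style, here's a sketch using binrw's derive API; the field names mirror the Hyperparameters struct above, and the magic/legacy handling from parse_header is left out (binrw's #[br(magic = ...)] attribute could cover it):

use binrw::{io::Cursor, BinRead};

#[derive(BinRead, Debug)]
#[br(little)]
struct GgmlHyperparameters {
    n_vocab: i32,
    n_embd: i32,
    n_mult: i32,
    n_head: i32,
    n_layer: i32,
    n_rot: i32,
    f16_: i32,
}

fn parse_hyperparameters(bytes: &[u8]) -> binrw::BinResult<GgmlHyperparameters> {
    // binrw reads from any Read + Seek source, so a byte slice goes
    // through a Cursor.
    GgmlHyperparameters::read(&mut Cursor::new(bytes))
}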

Not sure how that would impact parallel loading or #93, though.

Interesting. Weirdly enough, binrw actually only has limited support for non-stream sources (i.e. mmap). I don't know whether handling the GGML format would require its seek features, but if so, that would mean mmapping was impossible.

Don't really need mmap. smol + nuclei + 2 fds should be enough.

With mmap support in place, I'm not sure how relevant this is now; the loader doesn't do much actual work when setting up the tensors.