spiraldb/fsst

Miri is incredibly slow

Closed this issue · 3 comments

a10y commented

I've tried running Miri with MIRI_LOG=info. It generates several GBs of logs, but if you filter for miri::machine you get some simple enter/exit spans with timings. They're still a bit hard to read, but at first glance some major contributors are:

  • The codes_twobyte Vec in the compressor. This 65,536-element vector takes many seconds to build, and it gets rebuilt for every new Compressor that is built during training
  • The implementation of optimize, which is where a lot of the time in Miri is being spent

I'll upload a Miri trace once I compress it small enough to attach.

I've generally found Miri to be slow if you're allocating a lot of memory. Personally I wouldn't worry too much, since that's the nature of an interpreter focused on correctness. Maybe there's a compile-time constant we can tweak to make allocations under Miri smaller?
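One way that could look is a constant gated on `cfg(miri)`, which Miri sets when it interprets the crate. This is only a sketch: `SAMPLE_BUF_LEN`, its two values, and `make_sample_buf` are hypothetical names and numbers, not anything from this crate.

```rust
/// Hypothetical sizing constant: small when interpreted by Miri,
/// full-size in normal builds. The values are placeholders.
#[cfg(miri)]
pub const SAMPLE_BUF_LEN: usize = 1 << 10;
#[cfg(not(miri))]
pub const SAMPLE_BUF_LEN: usize = 1 << 14;

/// Allocates a zero-filled scratch buffer of the configured size.
/// (Hypothetical helper, used here only to show the constant in action.)
pub fn make_sample_buf() -> Vec<u8> {
    vec![0u8; SAMPLE_BUF_LEN]
}
```

The trade-off is that tests under Miri then exercise a smaller configuration than production, so anything that depends on the exact size would still need a normal test run.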

a10y commented

OK, so I've actually narrowed it down a bit: allocations seem fine, and building large Vecs seems pretty quick now as well (I tried to use zeroed Vecs where possible to keep the Vec internals from calling clone).
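To illustrate the zeroed-Vec point, here is a minimal sketch. The `u16` element type, `TABLE_LEN`, and both function names are assumptions for illustration, not the crate's actual definitions. As far as I know, `vec![0; n]` on a primitive integer type lets the standard library request already-zeroed memory instead of writing each element, which leaves an interpreter with far fewer steps to execute.

```rust
/// Hypothetical table size: one slot per possible two-byte prefix.
const TABLE_LEN: usize = 1 << 16;

/// Builds the table with `vec![0; n]`, which can be served by a single
/// zeroed allocation rather than a per-element write.
fn zeroed_table() -> Vec<u16> {
    vec![0u16; TABLE_LEN]
}

/// Builds the same table one element at a time. The result is identical,
/// but an interpreter has to step through all 65,536 pushes.
fn pushed_table() -> Vec<u16> {
    let mut table = Vec::with_capacity(TABLE_LEN);
    for _ in 0..TABLE_LEN {
        table.push(0u16);
    }
    table
}
```

Both functions return the same contents, so switching between them only changes how much work the interpreter does.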

What is really slow is the optimize method on Compressor. This is the method that looks at the occurrences of codes in the compressed sample and loops anywhere from 256^2 to 511^2 times to build a new symbol table.

Here's a version of the optimize method that has been augmented with some logging:

    /// Using a set of counters and the existing set of symbols, build a new
    /// set of symbols/codes that optimizes the gain over the distribution in `counter`.
    pub fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self {
        println!("[optimize]");
        let mut res = Compressor::default();
        let mut pqueue = BinaryHeap::with_capacity(65_536);
        println!("{:?} | built pq", std::time::Instant::now());
        println!("{:?} | looping...", std::time::Instant::now());
        for code1 in 0u16..(256u16 + self.n_symbols as u16) {
            if code1 % 10 == 0 {
                println!("{:?} | outer loop {code1}", std::time::Instant::now());
            }
            let symbol1 = self.symbols[code1 as usize];
            let mut gain = counters.count1(code1) * symbol1.len();
            // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
            // This helps to reduce exception counts.
            if code1 < 256 {
                gain *= 8;
            }
            if gain > 0 {
                println!("{:?} | pushing", std::time::Instant::now());
                pqueue.push(Candidate {
                    symbol: symbol1,
                    gain,
                });
                println!("{:?} | done pushing", std::time::Instant::now());
            }

            for code2 in 0u16..(256u16 + self.n_symbols as u16) {
                // println!("inner loop {code1} {code2}");
                let symbol2 = &self.symbols[code2 as usize];
                // If merging would yield a symbol of length greater than 8, skip.
                if symbol1.len() + symbol2.len() > 8 {
                    continue;
                }
                let new_symbol = symbol1.concat(symbol2);
                let gain = counters.count2(code1, code2) * new_symbol.len();
                if gain > 0 {
                    println!("{:?} | inner pushing", std::time::Instant::now());
                    pqueue.push(Candidate {
                        symbol: new_symbol,
                        gain,
                    });
                    println!("{:?} | inner done pushing", std::time::Instant::now());
                }
            }
        }

        // Pop the 255 best symbols.
        println!("{:?} | popping time", std::time::Instant::now());
        let mut n_symbols = 0;
        while !pqueue.is_empty() && n_symbols < 255 {
            let candidate = pqueue.pop().unwrap();
            if res.insert(candidate.symbol) {
                n_symbols += 1;
            }
        }

        // If there are leftover slots, fill them with ASCII chars.
        // This helps reduce the number of escapes.
        //
        // Note that because of the lossy hash table, we won't accidentally
        // save the same ASCII character twice into the table.
        if include_ascii {
            for character in
                " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[](){}:?/<>".bytes()
            {
                if n_symbols == 255 {
                    break;
                }

                if res.insert(Symbol::from_u8(character)) {
                    n_symbols += 1
                }
            }
        }

        res
    }

And a corresponding trace:

running 1 test
test test_compressor ... [optimize]
Instant { tv_sec: 1, tv_nsec: 55790000 } | built pq
Instant { tv_sec: 1, tv_nsec: 101305000 } | looping...
Instant { tv_sec: 1, tv_nsec: 149705000 } | outer loop 0
Instant { tv_sec: 3, tv_nsec: 543360000 } | outer loop 10
Instant { tv_sec: 5, tv_nsec: 935480000 } | outer loop 20
Instant { tv_sec: 8, tv_nsec: 330060000 } | outer loop 30
Instant { tv_sec: 10, tv_nsec: 724025000 } | outer loop 40
Instant { tv_sec: 13, tv_nsec: 118915000 } | outer loop 50
Instant { tv_sec: 15, tv_nsec: 512575000 } | outer loop 60
Instant { tv_sec: 17, tv_nsec: 906235000 } | outer loop 70
Instant { tv_sec: 20, tv_nsec: 299895000 } | outer loop 80
Instant { tv_sec: 22, tv_nsec: 693555000 } | outer loop 90
Instant { tv_sec: 25, tv_nsec: 87215000 } | outer loop 100
Instant { tv_sec: 27, tv_nsec: 480865000 } | outer loop 110
Instant { tv_sec: 29, tv_nsec: 875420000 } | outer loop 120
Instant { tv_sec: 32, tv_nsec: 268745000 } | outer loop 130
Instant { tv_sec: 34, tv_nsec: 662685000 } | outer loop 140
Instant { tv_sec: 37, tv_nsec: 59085000 } | outer loop 150
Instant { tv_sec: 39, tv_nsec: 453965000 } | outer loop 160
Instant { tv_sec: 41, tv_nsec: 846675000 } | outer loop 170
Instant { tv_sec: 44, tv_nsec: 240000000 } | outer loop 180
Instant { tv_sec: 46, tv_nsec: 633940000 } | outer loop 190
Instant { tv_sec: 49, tv_nsec: 29110000 } | outer loop 200
Instant { tv_sec: 51, tv_nsec: 423375000 } | outer loop 210
Instant { tv_sec: 53, tv_nsec: 817930000 } | outer loop 220
Instant { tv_sec: 56, tv_nsec: 211870000 } | outer loop 230
Instant { tv_sec: 58, tv_nsec: 605810000 } | outer loop 240
Instant { tv_sec: 60, tv_nsec: 999135000 } | outer loop 250
Instant { tv_sec: 62, tv_nsec: 458225000 } | popping time
ok

In this sample run, I feed it an all-zero set of counters, so nothing is actually pushed onto the binary heap or popped off of it. It's just pure loop overhead.

This method is called 5 times when building a symbol table, and the time it takes doesn't vary with the size of the sample, so every call to Compressor::train() takes Miri somewhere between 4 and 20 minutes to execute.
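As a rough sanity check on that range, the sketch below scales the loop time measured in the trace above (from the "looping..." line to the "popping time" line) to the 256-code and 511-code cases. It rests on two assumptions: that the traced run had n_symbols == 0, which is inferred from the outer loop stopping just past 250, and that cost grows linearly with the number of inner-loop iterations. The function names are made up for this example.

```rust
/// Timestamps from the trace, in nanoseconds:
/// the "looping..." line and the "popping time" line.
const LOOP_START_NS: u64 = 1_101_305_000;
const LOOP_END_NS: u64 = 62_458_225_000;

/// Total inner-loop iterations for a table with `n_symbols` learned symbols:
/// both loops run over 256 + n_symbols codes.
fn inner_iterations(n_symbols: u64) -> u64 {
    let codes = 256 + n_symbols;
    codes * codes
}

/// Estimated seconds for one `optimize` call, assuming the traced run had
/// n_symbols == 0 and that time scales linearly with iteration count.
fn estimated_secs(n_symbols: u64) -> f64 {
    let measured_ns = (LOOP_END_NS - LOOP_START_NS) as f64;
    let ns_per_iter = measured_ns / inner_iterations(0) as f64;
    ns_per_iter * inner_iterations(n_symbols) as f64 / 1e9
}

/// Estimated minutes for one training run, which calls `optimize` 5 times.
fn estimated_train_minutes(n_symbols: u64) -> f64 {
    5.0 * estimated_secs(n_symbols) / 60.0
}
```

Under those assumptions, `estimated_train_minutes(0)` comes out at roughly 5 minutes and `estimated_train_minutes(255)` at roughly 20 minutes, which is in line with the range quoted above.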

a10y commented

Did some work in #13 and #16 to speed things up under Miri enough to turn it back on. It's still very slow, but roughly within expected bounds.