pisa-engine/pisa

Bit vector does shady stuff

elshize opened this issue · 12 comments

The following function has issues:

return *(reinterpret_cast<uint64_t const*>(ptr + pos / 8)) >> (pos % 8);

  1. We cast an arbitrary (possibly misaligned) byte pointer to a uint64_t pointer, which is undefined behavior.
  2. We assume the bit vector has enough bytes left to actually dereference a full 8-byte word.

I tried replacing this with std::memcpy, and once I did, the address sanitizer caught a read past the end of the buffer. What fixes it is reading not 8 but min(8, n) bytes, where n is the number of bytes left until the end of the buffer.
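To make that concrete, here is a minimal sketch of the fix I have in mind; the name and signature are made up, not the actual PISA function, and it assumes a little-endian target just like the cast-based version:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Hypothetical sketch: read at most min(8, bytes left) bytes starting at
    // byte pos / 8, leave the rest of the word zero, then shift away the
    // sub-byte offset. Precondition: pos / 8 < data_bytes.
    inline uint64_t get_word56_bounded(uint8_t const* data, std::size_t data_bytes, std::size_t pos) {
        std::size_t const start = pos / 8;
        std::size_t const left = data_bytes - start;            // bytes left until the end
        std::size_t const n = std::min<std::size_t>(8, left);   // never read past the buffer
        uint64_t word = 0;
        std::memcpy(&word, data + start, n);
        return word >> (pos % 8);
    }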

Now, according to the comment this is an unsafe function. I'm not sure what that was supposed to mean: we know it can fail? should we use it only under certain assumptions? does it refer to the fact that we're only guaranteeing 56 bits? Not sure; does anyone know what the intention was?

I realize this was supposed to be some kind of optimization, but we can't have an optimization that might sometimes fail. We should either patch it to check bounds, or carefully document it and make sure its assumptions hold at every call site.

The bitvector cleanup PR reminded me of this. Perhaps @ot could chime in if he has a minute; I think he wrote this piece of code in the first place and knows when/when not to use it.

I think the most troubling thing is that the address sanitizer actually found problems once the cast was replaced with memcpy. I'm not sure why it doesn't catch anything with the cast, but since the cast is UB, all bets are off, possibly including the sanitizer's behavior.

Getting more insight into what the intention was and what the function's unsafety was supposed to mean would help us a lot. Otherwise, it might be quite an onerous task to debug and analyze.

@amallia, please also share your insight.

@JMMackenzie I did a very quick test, I added this assert:

assert(pos / 8 < (m_bits.size() * 8) - 7);

And it fails. So if I understand this right, there aren't enough bytes left in the buffer to read. So even disregarding the misalignment UB, this is a separate problem.

In fact, even this fails:

assert(pos / 8 < (m_bits.size() * 8) - 2);

Meaning, here we have pos / 8 >= (m_bits.size() * 8) - 2, which means that the starting byte is within 2 bytes of the end of the buffer. So we can read at most 2 bytes, which is 16 bits, which is less than the promised 56 bits!

So not only is the implementation incorrect, but there must also be a bug at the call site, because we call a function that is supposed to return at least 7 bytes' worth of bits while only 2 bytes are left.

Does that make sense or am I making a mistake in my calculations somewhere?

For completeness, this passes the tests I executed (I haven't run all of them).

assert(pos / 8 < (m_bits.size() * 8) - 1);

m_bits stores uint64_t words, so m_bits.size() * 8 is the buffer size in bytes.
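To spell out the arithmetic behind these asserts, here is a small sketch; the helper and its names are hypothetical, only pos and the word count come from the code above:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper restating the invariant: an unconditional 8-byte read
    // starting at byte pos / 8 stays inside the buffer only if at least 8 bytes
    // remain, i.e. pos / 8 < total_bytes - 7, which is the first assert above.
    inline void assert_full_word_in_bounds(std::size_t pos, std::size_t num_words) {
        std::size_t const total_bytes = num_words * sizeof(uint64_t);  // m_bits.size() * 8
        std::size_t const start_byte = pos / 8;
        assert(start_byte + 8 <= total_bytes);
    }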

So: assert(pos / 8 < (m_bits.size() * 8) - 7); is saying: "Find the byte corresponding to pos and make sure it is at least 7 bytes before the end of m_bits" right?

If so then I agree.

Maybe the "trick" here is that even though it knows it will run over the end of the buffer, it will still access the correct portion of the data via the >> (pos % 8); part? I'm surprised it doesn't blow up though...

So: assert(pos / 8 < (m_bits.size() * 8) - 7); is saying: "Find the byte corresponding to pos and make sure it is at least 7 bytes before the end of m_bits" right?

I believe so.

Maybe the "trick" here is that even though it knows it will run over the end of the buffer, it will still access the correct portion of the data via the >> (pos % 8); part? I'm surprised it doesn't blow up though...

pos % 8 will shift by at most 7 bits, so I don't think that's it.

It's not that surprising that reading past the buffer doesn't blow up. It will read garbage in this corner case, though.

We might not even see a read past the buffer happen in a real use case, because we'd need to read at the very end; it might only happen in tests because our test index is small enough. But I don't know this for a fact.

It could also be that the encodings used are not affected by the trailing garbage, but I have no idea. For example, if you read a unary-coded value like 1110XXXX, the decoder stops at the terminating 0 and maybe doesn't care what the X bits are.
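Something like this illustrative sketch (not actual PISA code; it reads the unary run from the low bit of the word):

    #include <cstdint>

    // Illustrative only: count a unary run of 1-bits terminated by a 0,
    // starting from the low bit. Bits after the terminating 0, including any
    // garbage read past the logical end of the bitvector, never affect the result.
    inline unsigned decode_unary(uint64_t word) {
        unsigned run = 0;
        while ((word & 1U) != 0U) {  // stop at the first 0 bit
            ++run;
            word >>= 1;
        }
        return run;
    }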

All that said, this needs to be fixed; it's a potential ticking bomb. But as I said, it's possible that we have more than one bug, as it seems like some call sites also use it incorrectly (I think).

It could also be that the encodings used are not affected by the trailing garbage, but I have no idea. For example, if you read a unary-coded value like 1110XXXX, the decoder stops at the terminating 0 and maybe doesn't care what the X bits are.

Yeah, this is what I was getting at. But agreed, in general it looks scary. It isn't called in many places, mostly just the bitvector and elias_fano stuff. Perhaps, like you said, we never actually hit the unsafe case in practice, but I am in favour of fixing it. We might also want to do some benchmarking to see how much of a hit we take doing things safely.

I was under the impression that the constructor zeroes out the padding bits at the end, like this:

        if (size != 0U) {
            m_cur_word = &m_bits.back();
            // clear padding bits
            if (init && ((size % 64) != 0U)) {
                *m_cur_word >>= 64 - (size % 64);
            }
        }

@mpetri right, but it seems like regardless of that, we have a situation in which we call a function that is supposed to return at least 56 bits, but there are only 2 bytes left in the buffer from that position. So something is not right there.

If we read beyond the buffer, no amount of padding bits inside the last word will help us, because we have no guarantee about the contents of the memory we just read. In fact, it's undefined behavior, which means that from that point forward the program's behavior is, technically speaking, undefined, and there's no guarantee that, say, changing a compiler flag or compiler version won't blow it up for some reason.

So yeah, this is my concern: get_word56() seems to promise to return at least 7 bytes worth of information, but there are fewer bytes available to actually read.

Also, to be clear, even if we're not reading past the buffer, our implementation is UB anyway, because it casts a potentially misaligned char pointer to an 8-byte integer pointer, which is UB on its own. This alone should be easily fixable with memcpy, but when I tried that I noticed the other problem...

ot commented

A few points:

  • The comment is a bit misleading; it should read "retrieves at least 56 bits, limited to the available bits in the bitvector". What this function does is: if you assume that the bitvector has at least k bits following pos, and k <= 56, then the returned word will contain those k bits, and maybe garbage after them. Of course it cannot return more bits than there are in the bitvector. It is the responsibility of the caller to know how many bits are available, but that is usually guaranteed by some data structure invariant (no attempt at graceful error handling is made if the data is corrupted).

  • The out-of-bounds read relies on an implicit assumption that the buffer that contains the bitvector has at least 7 bytes following the bitvector. I believe this was documented somewhere in the original version, but the comment may have been lost at some point. This introduces a somewhat unnatural requirement, but avoiding the branch to handle the tail special case is absolutely worth it (even if it's well predicted). In practical applications, the bitvector is always part of a larger structure, possibly a collection of posting lists all mmapped together, so satisfying the requirement just needs a constant 7 bytes of padding at the end of the whole index.

  • The UB on the cast is real (the alignment is incorrect), but I haven't seen any compiler miscompile this. That said, it should be fixed. I would suggest using this trick from folly (see the sketch after these points), which produces a single-instruction read even without optimizations (while IIRC memcpy is optimized out only with optimizations enabled):
    https://github.com/facebook/folly/blob/main/folly/lang/Bits.h#L328-L360
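A rough sketch of that folly-style trick, hedged: this is not folly's exact code, it relies on GCC/Clang attributes, and it still requires the 7 bytes of padding described above to stay in bounds:

    #include <cstdint>

    // GCC/Clang-specific sketch: a packed (alignment-1) wrapper lets the
    // compiler emit the unaligned load itself, as a single instruction even
    // without optimizations; may_alias sidesteps strict-aliasing concerns.
    struct __attribute__((packed, may_alias)) unaligned_u64 {
        uint64_t value;
    };

    inline uint64_t get_word56_padded(uint8_t const* data, uint64_t pos) {
        // Relies on at least 7 bytes of padding following the bitvector.
        uint64_t const word = reinterpret_cast<unaligned_u64 const*>(data + pos / 8)->value;
        return word >> (pos % 8);  // at least 56 meaningful bits remain
    }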

@ot Thanks a lot for your comments, the explanation really helps, we'll make sure to document it.

Just to clarify: in the encoding algorithms where this is used, when the garbage bytes are read, they won't affect the algorithm, right? Meaning, the algorithm that calls get_word56() one way or another knows not to read beyond the meaningful bits and discards the trailing garbage?

The out-of-bounds read relies on an implicit assumption that the buffer that contains the bitvector has at least 7 bytes following the bitvector

Would it be possible to simply pad the vector itself? Just always append 7 zero bytes at the end when building? I agree that in practice it's unlikely to cause problems, but it would help me sleep better at night 😅
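Something along these lines, as a hypothetical sketch (not the actual builder code); one extra zeroed word gives us 8 >= 7 bytes of tail padding:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical: keep one extra zeroed 64-bit word after the logical data,
    // so an unconditional 8-byte read starting at any logical byte stays
    // inside the allocation.
    struct padded_bit_buffer {
        std::vector<uint64_t> words;

        explicit padded_bit_buffer(std::size_t logical_words)
            : words(logical_words + 1, 0) {}  // trailing zero word = tail padding

        // Logical (unpadded) size in bytes; reads may start anywhere below this.
        std::size_t logical_bytes() const { return (words.size() - 1) * sizeof(uint64_t); }
    };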

I would suggest using this trick from folly, which produces a single-instruction read even without optimizations (while IIRC memcpy is optimized out only with optimizations enabled)

Thanks for pointing this out, I'll have a look. Frankly, I don't see much of a problem in memcpy not being optimized when optimizations are disabled, but I'd be interested to see if it can be done. We can probably also check out how bit_cast in C++20 is implemented.

We can probably also check out how bit_cast in C++20 is implemented.

Somewhat anticlimactically, it seems to be implemented with __builtin_bit_cast by both gcc and clang, which means we can at least implement it with the builtin whenever it's available... it should be as fast as it gets, right?

I think the builtins' availability depends on the compiler version rather than the C++ standard version, so we might be able to use them without C++20.
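For illustration, one way the load could look with std::bit_cast (or the builtin it lowers to); this is just a sketch, not necessarily the form we would use:

    #include <array>
    #include <bit>      // std::bit_cast (C++20); __builtin_bit_cast works similarly
    #include <cstdint>
    #include <cstring>

    // Sketch: copy the raw bytes into a char array (no alignment requirement),
    // then bit_cast them to uint64_t. With GCC/Clang and optimizations enabled
    // this compiles down to a single load.
    inline uint64_t load_u64_bitcast(unsigned char const* ptr) {
        std::array<unsigned char, sizeof(uint64_t)> bytes;
        std::memcpy(bytes.data(), ptr, bytes.size());
        return std::bit_cast<uint64_t>(bytes);
    }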

Something I missed initially was that the address sanitizer issue showed up not when reading an index but when testing specific encodings, where there is no buffer padding.

So the solution is to fix the tests (add padding) and document the fact that padding is necessary for it to work properly.