dblalock/bolt

mithral C++ parameter optimization

fjrdev opened this issue · 16 comments

Does the mithral C++ codebase provide functions for parameter optimization? I'm having trouble finding them, and executing run_matmul() in the mithral_amm_task struct returns a wrong output matrix, since the optimization parameters are initialized randomly. Thank you!

From my understanding, the C++ code does not contain parameter optimization; only the Python code called here does.

I have been trying to work out which parameters from the Python code need to be exported to make the C++ code work, but I'm stuck on some of them. For example:

void mithral_encode(
    const float* X, int64_t nrows, int ncols,
    const uint32_t* splitdims, const int8_t* all_splitvals,
    const float* scales, const float* offsets, int ncodebooks, uint8_t* out);

Here, splitdims and all_splitvals can be derived from the MithralEncoder::splits_lists array, but some of the asserts in the C++ code seem to fail (or at least I don't know how to properly extract the values from the Python code).
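
For reference, my extraction currently looks roughly like the sketch below. The .dim/.scaleby/.offset attribute names come from printing splits_lists, and the flattening order (4 consecutive splits per codebook) is a guess on my part:

    import numpy as np

    # Rough sketch of my extraction; splits_lists is assumed to be a list of
    # (ncodebooks) lists of 4 split objects, each with .dim, .scaleby and
    # .offset attributes. Untested against the C++ asserts.
    def extract_split_params(splits_lists):
        splitdims, scales, offsets = [], [], []
        for codebook_splits in splits_lists:   # one list of 4 splits per codebook
            for split in codebook_splits:      # one split per tree level
                splitdims.append(split.dim)
                scales.append(split.scaleby)
                offsets.append(split.offset)
        return (np.array(splitdims, dtype=np.uint32),
                np.array(scales, dtype=np.float32),
                np.array(offsets, dtype=np.float32))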

At the same time, when creating the query lookup tables as shown here, I can't figure out what to pass into the equivalent C++ function:

void mithral_lut_dense(const float* Q, int nrows, int ncols, int ncodebooks,
    const float* centroids, float& out_offset_sum, float& out_scale,
    float*__restrict__ tmp_lut_f32, uint8_t* out);

Overall, while the paper's results look impressive, I don't think it is currently possible to reproduce the speed and accuracy numbers reported in the paper with a single codebase.

The Python code produces the correct effectiveness numbers and tradeoffs, but without proper integration into the C++ bindings it doesn't show the same speed characteristics as what is reported in the paper.

@mpetri I'm also having problems with using the Python version to learn the parameters and the C++ version to do the rest.
Using clusterize.learn_mithral(...) should be the right way to extract the parameters, though. For the C++ part, I'm using this struct, which initially sets the learned values randomly. I'm just overwriting them with the learned values and running it.

But I'm still facing some problems with the calculation process. The last n/2 columns of my output matrix are always zeroed out (this also happens when I don't load the previously learned parameters into the object, i.e. when just running run_mithral()).

First I generate X in Python and run clusterize.learn_mithral(...). Then I extract the needed parameters and overwrite them in the mithral_amm_task object (split_dims, split_vals, centroids, encode_offsets and encode_scales). For Q I'm using the randomly generated matrix from the object.
For the split values, the C++ struct requires a matrix with the dimensions (16, #codebooks * #splits_per_codebook) == (16, #codebooks * 4). Since there are 16 rows, but we only get 15 values per codebook from Python, I have zeroed out the last row in all columns. Every column is then filled with the array of split values from one specific level of the codebook tree (e.g. [171,0,0,...,0] for the first column, [16,255,0,...,0] for the second column and so on), so that one codebook fills 4 columns of the splitvals matrix. This is the point where I'm uncertain whether the mapping procedure is correct; the other values are straightforward.
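In code, the mapping I'm describing looks roughly like this (the names are mine, not from the repo):

    import numpy as np

    # Rough sketch of the mapping described above; vals_per_codebook_list is
    # the nested per-codebook output from Python, e.g.
    # [[171], [16, 255], [...4 vals], [...8 vals]], and the result is the
    # (16, 4 * ncodebooks) splitvals matrix with the unused rows zeroed out.
    def pack_splitvals(vals_per_codebook_list):
        ncodebooks = len(vals_per_codebook_list)
        out = np.zeros((16, 4 * ncodebooks), dtype=np.uint8)
        for c, levels in enumerate(vals_per_codebook_list):
            for s, vals in enumerate(levels):   # levels hold 1, 2, 4, 8 vals
                out[:len(vals), 4 * c + s] = vals
        return out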

I'm also curious how @dumpinfo has managed the py -> c++ parameter migration.

Since there are 16 rows, but we only get 15 values per codebook from Python, I have zeroed out the last row in all columns

I have the same problem. I was printing self.splits_lists from the Python code and it only shows 15 values per codebook, but the C++ code asserts there need to be 16:

0-0 dim = 3
0-0 vals = [131]
0-0 scaleby = 256.0
0-0 offset = -1.1661134958267212
0-1 dim = 6
0-1 vals = [64 58]
0-1 scaleby = 128.0
0-1 offset = -0.4562876389827579
0-2 dim = 2
0-2 vals = [ 96 105  88  81]
0-2 scaleby = 256.0
0-2 offset = 2.1185343861579895
0-3 dim = 12
0-3 vals = [143 124  76  80 162 148  59 112]
0-3 scaleby = 256.0
0-3 offset = -0.5600365102291107

Counting the different vals I only get 15 (the four levels hold 1 + 2 + 4 + 8 values).

Maybe @dblalock could help out here so we can move forward?

Do you have some code you can share?

Yes, padding with an extra zero at the end is probably the best solution. There are logically only 15, but the C++ cares about alignment and so wants blocks of 16 of them.

The best overall solution would be having clean wrappers for the C++ code and having the Python call those instead of the Python implementations. This would basically just solve everything.

But...I just never got to that. The results in the paper just join a table of accuracies and a table of speeds on the matrix shapes + num_codebooks, which is, in the words of the experiments readme, "kind of an abomination."

@fjrdev Do you have some code to share that extracts the Python values and imports them into the C++ codebase?

@mpetri I print the split dims, split values, centroids, scales and offsets to a .txt document and load it again after initializing mithral_amm_task in C++. Then I just overwrite the randomly initialized matrices of the object with the learned parameters from Python. But as I mentioned above, I still get a wrong output matrix (which appears to be quite similar to the one I would get without loading the learned parameters from Python). I'm still on it.
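
The dump itself is nothing fancy, roughly the following (argument names mirror the mithral_amm_task fields I overwrite, and the shapes are from my own extraction code, not the repo):

    import numpy as np

    # Rough sketch of how I dump the learned parameters to text files.
    def dump_params(path, split_dims, split_vals, centroids,
                    encode_scales, encode_offsets):
        np.savetxt(path + "/split_dims.txt", np.ravel(split_dims), fmt="%d")
        np.savetxt(path + "/split_vals.txt", split_vals, fmt="%d")  # (16, 4*C)
        np.savetxt(path + "/centroids.txt",
                   centroids.reshape(centroids.shape[0], -1))
        np.savetxt(path + "/encode_scales.txt", np.ravel(encode_scales))
        np.savetxt(path + "/encode_offsets.txt", np.ravel(encode_offsets))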

What procedure are you following when learning the parameters in python?

@fjrdev I ported the Python code to Rust and it produces equivalent results. However, I'm now trying to incorporate the more efficient C++ code and I'm encountering more and more issues. For example:

  • The C++ encode function(s) require input matrices to be column-major, which is unusual, and transforming the input into that format could incur a substantial speed penalty

  • The main mithral_encode() functions also seem to require additional layouts for the splitvals arrays, while the split_encode_8b_colmajor function (and similar ones) in multisplit.hpp seems to more closely resemble what the Python code does.

  • The C++ code requires split vals to be int8_t (to use i8-specific SIMD instructions), but the Python code uses uint8_t, with everything normalized to [0, 255]:

split.vals = np.clip(split.vals, 0, 255).astype(np.int32)

It is unclear whether scaling this into the int8_t range is safe, and it is never explained anywhere.

any thoughts?

After spending substantial time on this, I'm getting more convinced that without the help of the original author this can't be made to work as intended :(

The int8 vs uint8 shouldn't matter for the purpose of splitting. Any affine transformation is fine as long as it's applied to the data and the split vals the same way.
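
E.g., a quick sanity check (toy numpy, not code from the repo):

    import numpy as np

    # Toy check: applying the same shift to the data and the split vals
    # (here uint8 -> int8 by subtracting 128) preserves every comparison,
    # which is all the splitting logic depends on.
    rng = np.random.default_rng(0)
    x_u8 = rng.integers(0, 256, size=1000, dtype=np.uint8)
    split_u8 = np.uint8(131)

    x_i8 = (x_u8.astype(np.int16) - 128).astype(np.int8)
    split_i8 = np.int8(int(split_u8) - 128)

    assert np.array_equal(x_u8 > split_u8, x_i8 > split_i8)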

If I remember correctly, the split vals are each supposed to be sequences of 16x8bit vectors. It's one such vector per codebook. So the split vals array is a contiguous sequence of C 16B vectors, for a total of 16C bytes. The first split val is at element 0 in each vector, then the next two vals are at indices {1, 2}, etc. The last element is unused and just for alignment.

I think (based on the encoding logic + the fact that it's the minimal representation) that the splitdims are just C sequences of 4 ints, all contiguous.

And wow, I am really regretting not commenting the relevant functions better. I thought I had the main stuff doxygenated like in bolt.hpp, but boy was I wrong. Sorry about that.

Also, kind of a moot point, but I don't think column-major is that weird; it's the default in Eigen, Matlab, and Julia IIRC. But you're totally right that transposing will add overhead if you're starting with a row-major matrix.

"If I remember correctly, the split vals are each supposed to be sequences of 16x8bit vectors. It's one such vector per codebook."

@dblalock I'm a little bit confused by the dimensions of the splitvals matrix. Since there are (#codebooks * 4) columns, I assumed that one codebook fills 4 columns of the splitvals matrix of the mithral_amm_task object. Given the following Python splitvals output for one codebook:
[ [179], [41, 210], [29, 94, 193, 255], [8, 70, 101, 139, 189, 237, 255, 255] ]

The corresponding part of the splitvals matrix would look like this:
[179] [41] [29] [8]
[0] [210] [94] [70]
[0] [0] [193] [101]
[0] [0] [255] [139]
[0] [0] [0] [189]
[0] [0] [0] [237]
[0] [0] [0] [255]
[0] [0] [0] [255]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]
[0] [0] [0] [0]

According to your answer, the correct alignment would look like this:
[179]
[41]
[210]
[29]
[94]
[193]
[255]
[8]
[70]
[101]
[139]
[189]
[237]
[255]
[255]
[0]

I don't quite understand what happens to the 3 remaining columns reserved for this codebook.

I looked at the code some more and I think your first (zero-padded) version is correct. There's no packing into a single vector happening.

    int split_idx = 0;
    for (int c = 0; c < ncodebooks; c++) {
        // compute input and output column starts
        ...
        for (int s = 0; s < nsplits_per_codebook; s++) {
            ...
            // each split owns its own zero-padded 16-byte vector of split
            // vals (vals_per_split is presumably 16); load it and broadcast
            // it into both 128-bit lanes of a 256-bit register
            auto splitvals_ptr = all_splitvals + (vals_per_split * split_idx);
            current_vsplitval_luts[s] = _mm256_broadcastsi128_si256(
                load_si128i((const __m128i*)splitvals_ptr));
        }
        split_idx += nsplits_per_codebook;

Maybe I don't really understand the algorithm correctly, but from my understanding you use the split vals to walk down a binary tree, updating the codes with 2 * code or 2 * code + 1 depending on whether you go left or right in the tree. This is what this function does in the Python version of the mithral_encode function: https://github.com/dblalock/bolt/blob/master/experiments/python/clusterize.py#L1910

However, there is obviously a dependency on having processed the first level of the binary tree before deciding which path in the tree to take next. If we load [179] [41] [29] [8] into one of the registers, then [179] represents the split val of the root node of the tree and [41] and [29] represent left/right. How can we decide to compare against [41] or [29] without first performing the comparison at the root node?

Actually, looking at the code some more, I see we have a for loop over blocks and inside that a for loop over split vals, so that means we can make those decisions sequentially!

Yes, it's not obvious because it's SIMD-ified, but we actually are walking the tree in this loop (which I think is the one you're referring to).

The node for a given input is stored in the codes variable, which describes each element's index into the list of nodes at the current level of the tree. We use this variable to look up the associated splitval for each element. From there we cmp and update the code as code = 2*code + (val > splitval ? 1 : 0).
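
A scalar sketch of that update for one codebook and one (already quantized) input row, ignoring the vectorization and the per-split scale/offset handling:

    # Scalar sketch of the tree walk; splitdims holds the codebook's 4 column
    # indices, and splitvals_by_level the four zero-padded value lists, e.g.
    # [[179], [41, 210], [29, 94, 193, 255], [8, 70, 101, 139, ...]].
    def encode_one(x, splitdims, splitvals_by_level):
        code = 0
        for s in range(4):
            splitval = splitvals_by_level[s][code]  # node `code` at level `s`
            code = 2 * code + (1 if x[splitdims[s]] > splitval else 0)
        return code                                 # 4-bit code in [0, 16)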

VpouL commented

Hello @mpetri, can you share your Python code for extracting the parameters from the result of clusterize.learn_mithral(...)? I tried to get the optimization parameters of mithral_amm through Python, but also got a wrong result.

@VpouL A wrong result at out_mat(N, M)?

VpouL commented

@fjrdev Just like you: splitdims and all_splitvals can be derived from the MithralEncoder::splits_lists array, but I don't know how to properly extract the values from the Python code.
And I am also very confused that the out_mat(N, M) data type is not float.