google/heir

Rotation Keys & Composite Rotations

AlexanderViand-Intel opened this issue · 6 comments

In RLWE schemes (B/FV, BGV, CKKS), rotations over packed/batched (SIMD) ciphertexts require a key-switching operation so that the resulting ciphertext decrypts correctly under the original (non-permuted) key; this requires a key-switching key (ksk) for each supported rotation. These keys are also commonly called "rotation keys" or "galois keys" and, for most applications, make up the vast majority of the "evaluation keys".

While it would be possible to generate a rotation key for every possible rotation -(N-1), ..., -1, 1, ..., N-1, this would result in a prohibitively large amount of key material. Instead, existing FHE libraries generally default to the strategy of generating keys for rotations +/- 2^k and then assembling all rotations from those. For example, a rotation by 9 would be realized as "8, 1"; a rotation by 15 as "16, -1"; a rotation by 58 as "64, -8, 2". Note that the availability of negative rotation offsets leads to significantly more efficient paths than if we only had the positive rotations available. Clearly, there are multiple ways to get to each number, but (afaik) libraries determine a unique path by computing the Non-Adjacent Form (NAF) of the desired rotation offset. More specifically, it seems like the balanced NAF is the best approach.
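
For concreteness, here is a minimal sketch of the textbook (non-balanced) NAF recurrence (not HEIR code; naf is just an illustrative helper name). It reproduces the decompositions above, e.g., 58 -> [2, -8, 64]:

def naf(k):
    """Return the non-adjacent form of k as a list of signed powers of two."""
    digits = []
    power = 1
    while k != 0:
        if k % 2 == 1:
            d = 2 - (k % 4)  # +1 or -1, chosen so that (k - d) is divisible by 4
            digits.append(d * power)
            k -= d
        k //= 2
        power *= 2
    return digits

print(naf(9))   # [1, 8]
print(naf(15))  # [-1, 16]
print(naf(58))  # [2, -8, 64]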

The +/-2^k approach is a great default solution, yet for many programs (e.g., kernels of a fixed size, as in the SIMD vectorizer examples) we end up needing a much smaller number of unique rotation offsets and can gain a lot by generating only the keys for those exact rotations. This is doubly advantageous, as it both (a) reduces the amount of key material, speeding up keygen and reducing memory load and (b) allows us to realize all rotations with a single "native" rotation (i.e., all rotation paths are length one), speeding up the homomorphic computation.

Targets such as SEAL/OpenFHE hide this complexity behind their "rotate" API and, afaik, are happy as long as the rotation requested can be realized from the keys available. However, for targets such as LLVM/x86, HW accelerators, and anything else that doesn't have this internal logic, the compiler needs to do this translation. In addition, we might want to manually control how rotations are realized even for OpenFHE/etc, for example when our compiler can provide a better solution than the built-in defaults.

Effectively, this requires the following:

  • Add a lowering from arbitrary rotation offsets to a specific subset
    • Add an analysis that collects all required rotation* keys
      *(technically, in order to support multi-key/advanced use cases, the analysis should probably collect all key-switching-keys for a given base key)
    • Add heuristic/decision logic for which approach to use (individual keys vs. the powers-of-two approach) that can be overridden/configured via pass parameter(s); see the sketch after this list
    • Implement a (balanced) NAF generator and associated lowering
    • Tie this into the OpenFHE* context/keygen logic
      *(it might be interesting to create something equivalent that works across various targets)
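
As a rough sketch of the heuristic/decision step (made-up names and threshold, reusing the naf helper above; the actual pass parameters are still to be defined):

def plan_rotation_keys(required_offsets, N, max_exact_keys=None):
    """Decide which rotation keys to generate and how to lower each required offset."""
    offsets = sorted(set(required_offsets))
    if max_exact_keys is None:
        # Stand-in threshold: roughly the size of a power-of-two key set.
        max_exact_keys = N.bit_length()
    if len(offsets) <= max_exact_keys:
        # One key per distinct offset; every rotation is a single native rotation.
        return {"keys": offsets, "paths": {o: [o] for o in offsets}}
    # Otherwise fall back to +/- 2^k keys and lower each offset via its NAF.
    paths = {o: naf(o) for o in offsets}
    keys = sorted({step for path in paths.values() for step in path})
    return {"keys": keys, "paths": paths}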

PS: Clearly, there is the usual circular dependency issue here w.r.t. optimization: given a program, we can easily find the optimal set of keys, and given a constrained set of keys, we can (probably?) find an optimal(-ish) program. However, given a high-level algorithm, it's not clear it's feasible to find the best program+keys for non-toy-sized programs. Given what we do/support now in terms of SIMD, following the program->rotation keys approach seems more reasonable right now.

PPS: This discussion was prompted by #742 (bgv.rotate's rotation index is now an Attribute, i.e., must be statically known). While I agree with the change, there's a bit of a false dichotomy in the description: the opposite of a statically known rotation index isn't necessarily a blind rotate - it could still be a plaintext rotation index, just one that's only available at runtime. However, this case, if ever needed, should probably be handled by code-generating dynamic logic that uses a NAF-based approach anyway, as that's target independent.

libraries determine a unique path by computing the Non-Adjacent Form of the desired rotation offset. More specifically, it seems like the balanced NAF is the best approach.

TIL! Great resources.

it could still be a plaintext rotation index

Agreed about the semantic distinction, but I don't see a real use case for this, do you? The rotation indices are inserted by the compiler so they should be static by definition.

Tie this into the OpenFHE* context/keygen logic

This was done in a basic manner by #696 via

builder.create<openfhe::GenRotKeyOp>(cryptoContext, privateKey, rotIndices);

I'm not sure to what extent that is generalizable. Maybe you mean just extract GenRotKeyOp to a shared dialect?

it could still be a plaintext rotation index

Agreed about the semantic distinction, but I don't see a real use case for this, do you? The rotation indices are inserted by the compiler so they should be static by definition.

I think this is unlikely in "compute"-style applications, but more likely in "protocol"-style ones. Here's a somewhat contrived example: the client puts a 0/1 in each slot i, indicating whether they are willing to sell item i today; the server then adds up (via rotations) the slots corresponding to the items it is interested in buying today (a dynamic input). The client learns how many items the server wants to buy today, but not which ones (not actually a secure protocol, but hopefully it illustrates the idea). But, as I said, in this case we can just emit mlir/std code corresponding to the NAF path construction.
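
To illustrate what that emitted code could boil down to (a target-independent sketch, reusing the naf helper from the issue description; it assumes only +/- 2^k keys are available and rotate(ct, step) stands in for a single native rotation with an existing key):

def rotate_dynamic(ct, k, rotate):
    """Realize a rotation by a runtime-known amount k using only +/- 2^i steps."""
    for step in naf(k):  # NAF of the dynamic index, computed at runtime
        ct = rotate(ct, step)
    return ct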

This was done in a basic manner by #696

Oh, nice! I also didn't realize that there's already

SmallVector<int64_t> findAllRotIndices(func::FuncOp op) {
  std::set<int64_t> distinctRotIndices;
  op.walk([&](openfhe::RotOp rotOp) {
    distinctRotIndices.insert(rotOp.getIndex().getInt());
    return WalkResult::advance();
  });
  SmallVector<int64_t> rotIndicesResult(distinctRotIndices.begin(),
                                        distinctRotIndices.end());
  return rotIndicesResult;
}

So we'd need to pull this out, add the NAF logic, and make the rotIndices configurable in the ConfigureCryptoContext pass!

I'm not sure to what extent that is generalizable. Maybe you mean just extract GenRotKeyOp to a shared dialect?

I was actually thinking we'd want some kind of interface/API/convention for how passes like ConfigureCryptoContext can consume the result of the rotation/keys analysis. I guess that could be as simple as just defining it as an actual Analysis that this and similar passes can use.

Clearly, there are multiple ways to get to each number, but (afaik) libraries determine a unique path by computing the Non-Adjacent Form of the desired rotation offset. More specifically, it seems like the balanced NAF is the best approach.

For minimizing key material I don't think balanced NAF is what we want. Minimizing Hamming weight is closer, but even the algorithms in the Wikipedia link you mentioned are only minimal on average, whereas we have a specific subset of rotations. I think we can cook up an optimization problem that covers exactly the set of rotations we need and, taking into account how often each is used in the program, provides a tradeoff between minimizing key material and minimizing the latency of the rotation operations.

I wrote a very naive (and slow) ILP that does this: https://github.com/j2kun/sdr/blob/main/sdr_ilp.py

Basically, it has an integer variable for each "base" rotation in [-N+1, N-1] (2N variables), one for each coefficient of the decomposition of an input rotation in terms of the base rotations (2N*k variables, where k is the number of input rotations), and an objective that balances minimizing the number of selected keys against the cost of composing multiple rotations.
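
Roughly, the formulation is in this spirit (a sketch, not necessarily exactly what sdr_ilp.py does; it glosses over how the absolute values, the mod-N wraparound, and the coupling between c and b are linearized):

$$
\begin{aligned}
\min_{b,\,c}\quad & w \sum_r b_r \;+\; \sum_i \sum_r |c_{i,r}| \\
\text{s.t.}\quad & \sum_r r \cdot c_{i,r} \equiv t_i \pmod{N} && \text{for every required rotation } t_i \\
& c_{i,r} \neq 0 \;\Rightarrow\; b_r = 1 && \text{(a base rotation that is used needs a key)} \\
& b_r \in \{0,1\},\; c_{i,r} \in \mathbb{Z} && \text{for } r \in [-N+1,\, N-1] \setminus \{0\}
\end{aligned}
$$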

rotations = [3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
N = 16

solution = find_optimal_rotations(rotations, N, key_material_weight=0.5)
print(solution)
# Solution(
#     objective=11.0,
#     rotations=[3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
#     reconstructions=[[3], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]],
#     solve_time_seconds=0.09973740577697754,
# )

solution = find_optimal_rotations(rotations, N, key_material_weight=10)
print(solution)
# Solution(
#     objective=50.49999999999999,
#     rotations=[3, 4, 6, 8],
#     reconstructions=[
#         [3],
#         [6],
#         [3, 4],
#         [8],
#         [3, 6],
#         [4, 6],
#         [3, 8],
#         [4, 8],
#         [3, 4, 6],
#         [6, 8],
#         [3, 4, 8],
#     ],
#     solve_time_seconds=20.313071727752686,
# )

It should also be updated to have a latency cost in terms of how often the rotations are actually used in the input IR (a hard-to-represent rotation used once might not outweigh many easy-to-represent rotations that are used with high multiplicity).

Obviously the solve time of 20 seconds (for N=16) is a deal breaker, but maybe we can make it more efficient by restricting the candidate base rotations to powers of 2, plus some heuristically chosen rotations based on the input set.

Demonstrating the slowness: the box_blur 64x64 test case gives a trivial solution and takes ~850 seconds:

Finished building model, starting solver
Finished solving, extracting solution.
Solution(objective=5.0, rotations=[63, 65, 127, 3968, 4032], reconstructions=[[3968], [4032], [63], [127], [65]], solve_time_seconds=851.5273044109344)