PBS with varying bit-widths
cgouert opened this issue · 11 comments
Is it possible to change the underlying plaintext modulus of a ciphertext with a PBS? For instance, if we have a 2-bit ciphertext, can the output of the PBS be a 4-bit ciphertext? This would effectively be a 2:4 LUT with 2^2 entries. If we want to evaluate 2x + 5, the ideal LUT would be as follows:
0 -> 5
1 -> 7
2 -> 9
3 -> 11
Can this be done without having to resort to a full 4:4 LUT as demonstrated in the following code?
```rust
use tfhe::integer::gen_keys_radix;
use tfhe::integer::wopbs::*;
use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn main() {
    // 2 blocks x 2 message bits = 4-bit radix ciphertexts
    let nb_block = 2;
    let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_block);
    let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS);

    let ct = cks.encrypt(2u64);
    // Switch to the WoP-PBS parameters, evaluate the LUT, then switch back
    let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct);
    // The LUT is tabulated over the whole 4-bit input space (a 4:4 LUT)
    let lut = wopbs_key.generate_lut_radix(&ct, |x| 5 + 2 * x);
    let ct_res = wopbs_key.wopbs(&ct, &lut);
    let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);

    let res: u64 = cks.decrypt(&ct_res);
    assert_eq!(res, 9);
}
```
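To see what the 4:4 LUT costs compared with the requested 2:4 LUT, here is a plaintext model in plain Rust (no tfhe-rs calls, no encryption). It assumes, as the code above suggests, that `generate_lut_radix` on a 2-block ciphertext with the 2_2 parameters tabulates the function over all 16 values of the 4-bit message and wraps outputs modulo 16. The helper `tabulate` is hypothetical.

```rust
/// Message bits per radix block with the 2_2 parameters.
const BITS_PER_BLOCK: u32 = 2;

/// Tabulate `f` over every value of an `input_bits`-bit input, reducing each
/// output modulo 2^output_bits (plaintext model only, no encryption).
fn tabulate(input_bits: u32, output_bits: u32, f: impl Fn(u64) -> u64) -> Vec<u64> {
    let out_modulus = 1u64 << output_bits;
    (0..(1u64 << input_bits)).map(|x| f(x) % out_modulus).collect()
}

fn main() {
    // Captures nothing, so the closure is Copy and can be passed twice.
    let f = |x: u64| 5 + 2 * x;

    // The LUT the question asks for: 2-bit input, 4-bit output, 2^2 entries.
    let small = tabulate(BITS_PER_BLOCK, 2 * BITS_PER_BLOCK, f);
    // What a 2-block radix LUT covers: 4-bit input, 4-bit output, 2^4 entries.
    let full = tabulate(2 * BITS_PER_BLOCK, 2 * BITS_PER_BLOCK, f);

    println!("2:4 LUT ({} entries): {:?}", small.len(), small);
    println!("4:4 LUT ({} entries): {:?}", full.len(), full);

    // If the input is known to fit in 2 bits, only the first 4 entries are ever used.
    assert_eq!(&full[..small.len()], &small[..]);
}
```

The 12 extra entries of the full table are the overhead the question is trying to avoid.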
hello @cgouert
the 2_2 parameters have 2 bits of message and 2 bits of carry available, so you should be able to generate a lookup table that fills both the carries and the message, i.e. 4 bits in total.
Here, did you try to use decrypt_message_and_carry? It should contain the high bits with whatever you put in the LUT.
Ah, those are integer primitives :/
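The message/carry suggestion can be modelled in plaintext for a single 2_2 block: 2 message bits plus 2 carry bits give a 16-value space, so a LUT applied to a 2-bit input can return up to 4 bits. Reading only the message bits keeps the low 2 bits, while reading message and carry together (what `decrypt_message_and_carry` is meant for) recovers the whole value. This is a sketch in plain Rust with hypothetical helper names, not the tfhe-rs API.

```rust
const MESSAGE_MODULUS: u64 = 4; // 2 bits of message
const CARRY_MODULUS: u64 = 4; // 2 bits of carry

/// Plaintext stand-in for a LUT on one 2_2 block: the input is a 2-bit message,
/// the output may use the whole message + carry space (4 bits).
fn apply_lut_on_block(x: u64, f: impl Fn(u64) -> u64) -> u64 {
    assert!(x < MESSAGE_MODULUS, "input must fit in the message bits");
    f(x) % (MESSAGE_MODULUS * CARRY_MODULUS)
}

/// What reading only the message bits of a block would return.
fn message_only(block: u64) -> u64 {
    block % MESSAGE_MODULUS
}

fn main() {
    for x in 0..MESSAGE_MODULUS {
        let block = apply_lut_on_block(x, |v| 2 * v + 5);
        println!(
            "x = {}: message+carry = {}, message only = {}",
            x,
            block,
            message_only(block)
        );
        // 2x + 5 <= 11 < 16, so the full 4-bit result survives in message + carry.
        assert_eq!(block, 2 * x + 5);
    }
}
```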
I'm running your unmodified code and it seems to work, @cgouert? Am I missing something?
Hello @IceTDrinker, thanks for looking at this so quickly! Yes, the code does run, but it uses 4-bit radix ciphertexts (i.e. 2 blocks, each with 2 data bits and 2 carry bits) for both the input and the output. I was wondering if there is some way to evaluate a PBS over a single input block (i.e. just a shortint ctxt with 2 data bits) and get back a radix ciphertext with 2 blocks. Or, if this is not possible with radix ciphertexts, would one of the other APIs support it?
It might not be doable at the moment without hacking things up. On the other hand, integer is built on top of shortint, so there should be a way to build something yourself if you really want to, but I understand the code base is not easy to get into.
Could you describe your use case? Maybe there is a better approach for it.
If you only manipulate 4 bits, staying with shortint might be better, and then your use case is easy as pie :)
Thanks for the suggestion, we want to use a PBS that evaluates a non-linear function, but the bit sizes of the inputs and outputs of the function are not the same. For instance, we want to truncate a 16-bit value and use the top 8 bits to index a LUT where the outputs can grow to 2^16-1. We were wondering if there was a way to avoid encoding the actual PBS input as a full-fledged 16-bit encrypted radix ctxt and use a LUT with only 2^8 entries as opposed to 2^16.
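The requested data flow can be written down in plaintext: the top 8 bits of a 16-bit value index a table with 2^8 entries, and each entry may use the full 16-bit output range, which with 2-bit blocks spans 8 radix blocks. The sketch below uses squaring as a hypothetical example function (255² = 65025 fits in 16 bits). It models the computation only, performs no encryption, and its helper names are hypothetical.

```rust
const BITS_PER_BLOCK: u32 = 2;

/// Build a 2^8-entry table indexed by the top 8 bits; entries are 16-bit values.
fn build_table(f: impl Fn(u64) -> u64) -> Vec<u16> {
    (0..256u64)
        .map(|i| u16::try_from(f(i)).expect("LUT output must fit in 16 bits"))
        .collect()
}

/// Truncate a 16-bit value to its top 8 bits and use them as the table index.
fn lookup_top_byte(table: &[u16], v: u16) -> u16 {
    table[(v >> 8) as usize]
}

/// Split a value into little-endian 2-bit blocks, like a radix ciphertext's blocks.
fn to_blocks(mut v: u64, nb_blocks: usize) -> Vec<u64> {
    let modulus = 1u64 << BITS_PER_BLOCK;
    (0..nb_blocks)
        .map(|_| {
            let block = v % modulus;
            v /= modulus;
            block
        })
        .collect()
}

fn main() {
    let table = build_table(|i| i * i); // hypothetical non-linear function
    let v: u16 = 0xAB12;
    let out = lookup_top_byte(&table, v);
    assert_eq!(out as u64, 0xAB * 0xAB);
    println!(
        "table entries: {}, result: {}, output blocks: {:?}",
        table.len(),
        out,
        to_blocks(out as u64, 8)
    );
}
```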
Padding with trivial zeros may give you the functionality you need, with better performance, using the current code base, since trivial zeros are optimized at both the shortint level and the core crypto PBS level.
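A plaintext model of the padding idea: appending zero-valued blocks on the most significant side of a radix decomposition leaves the encoded value unchanged, so a 1-block (2-bit) input can be widened to the number of blocks the output needs before the LUT is applied. In tfhe-rs the padding blocks would be trivial encryptions of zero; here they are plain zeros, and the helper names are hypothetical.

```rust
const MESSAGE_MODULUS: u64 = 4; // 2 message bits per block

/// Recompose little-endian radix blocks into an integer.
fn from_blocks(blocks: &[u64]) -> u64 {
    blocks.iter().rev().fold(0, |acc, &b| acc * MESSAGE_MODULUS + b)
}

/// Pad with zero blocks on the most significant side up to `nb_blocks` blocks.
fn pad_msb_with_zeros(mut blocks: Vec<u64>, nb_blocks: usize) -> Vec<u64> {
    if blocks.len() < nb_blocks {
        blocks.resize(nb_blocks, 0);
    }
    blocks
}

fn main() {
    let input = vec![3]; // one 2-bit block holding the value 3
    let padded = pad_msb_with_zeros(input.clone(), 2);
    // The value is unchanged; only the number of blocks (and the LUT domain) grows.
    assert_eq!(from_blocks(&input), from_blocks(&padded));
    println!("{:?} -> {:?}, value {}", input, padded, from_blocks(&padded));
}
```

Note the trade-off: once padded to 2 blocks, a radix LUT is again tabulated over the larger domain, so the gain has to come from the cheaper handling of the trivial blocks rather than from a smaller table.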
@cgouert we'll have people who know the WoP-PBS better suggest an alternative that might interest you :)
Thanks very much for your help and insights!
Hello @cgouert,
It seems that the simplest way to execute a Wop-PBS with more output blocks than input blocks is to perform two Wop-PBS operations: one to evaluate the least significant bits (LSB) and another for the most significant bits (MSB).
At the end, the LSB and MSB blocks can be recombined into a single radix ciphertext.
```rust
use tfhe::integer::{gen_keys_radix, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext};
use tfhe::integer::wopbs::*;
use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

fn main() {
    // 2 blocks x 2 message bits: 4-bit input, while x^2 needs up to 8 bits
    let nb_block = 2;
    let msg: u64 = 14;
    let expected = msg.pow(2); // 196

    let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_block);
    let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS);

    let ct = cks.encrypt(msg);
    let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct);

    // First Wop-PBS: low 4 bits of x^2
    let lut_lsb = wopbs_key.generate_lut_radix(&ct, |x: u64| x.pow(2) % (1 << 4));
    let ct_res_lsb = wopbs_key.wopbs(&ct, &lut_lsb);
    // Second Wop-PBS on the same input: high 4 bits of x^2
    let lut_msb = wopbs_key.generate_lut_radix(&ct, |x: u64| (x.pow(2) >> 4) % (1 << 4));
    let ct_res_msb = wopbs_key.wopbs(&ct, &lut_msb);

    // Back to the regular PBS parameters
    let ct_res_msb = wopbs_key.keyswitch_to_pbs_params(&ct_res_msb);
    let ct_res_lsb = wopbs_key.keyswitch_to_pbs_params(&ct_res_lsb);

    // Concatenate the blocks, least significant first: 2 + 2 = 4 blocks (8 bits)
    let mut lsb_blocks = ct_res_lsb.clone().into_blocks();
    let msb_blocks = ct_res_msb.clone().into_blocks();
    lsb_blocks.extend(msb_blocks);
    let ct_res = RadixCiphertext::from_blocks(lsb_blocks);

    let res_lsb: u64 = cks.decrypt(&ct_res_lsb);
    let res_msb: u64 = cks.decrypt(&ct_res_msb);
    let res: u64 = cks.decrypt(&ct_res);

    assert_eq!(res_lsb, expected % (1 << 4));
    assert_eq!(res_msb, (expected >> 4) % (1 << 4));
    assert_eq!((res_msb << 4) + res_lsb, expected);
    assert_eq!(res, expected);
}
```
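The LSB/MSB split can also be checked exhaustively in plaintext: for every 4-bit input, the low and high nibbles of x² are each produced by a separate 16-entry table (mirroring `lut_lsb` and `lut_msb`), and concatenating the two 2-block results with the LSB blocks first (as `lsb_blocks.extend(msb_blocks)` does) recovers x². This models the arithmetic only, with hypothetical helper names; it does not call tfhe-rs.

```rust
const BLOCK_MODULUS: u64 = 4; // 2 message bits per block
const NB_BLOCKS: usize = 2; // blocks per half (4 bits)

/// Split a value into little-endian base-4 blocks.
fn to_blocks(mut v: u64, nb_blocks: usize) -> Vec<u64> {
    (0..nb_blocks)
        .map(|_| {
            let block = v % BLOCK_MODULUS;
            v /= BLOCK_MODULUS;
            block
        })
        .collect()
}

/// Recompose little-endian base-4 blocks into an integer.
fn from_blocks(blocks: &[u64]) -> u64 {
    blocks.iter().rev().fold(0, |acc, &b| acc * BLOCK_MODULUS + b)
}

/// Plaintext counterpart of the two Wop-PBS calls: low and high nibble of x^2.
fn split_square(x: u64) -> (Vec<u64>, Vec<u64>) {
    let sq = x * x;
    let lsb = sq % (1 << 4);
    let msb = (sq >> 4) % (1 << 4);
    (to_blocks(lsb, NB_BLOCKS), to_blocks(msb, NB_BLOCKS))
}

fn main() {
    for x in 0..16u64 {
        let (mut blocks, msb_blocks) = split_square(x);
        blocks.extend(msb_blocks); // LSB blocks first, then MSB blocks
        assert_eq!(from_blocks(&blocks), x * x);
    }
    println!("LSB/MSB recombination matches x^2 for every 4-bit x");
}
```

Since 15² = 225 < 256, two 4-bit halves always suffice for a 4-bit input; a wider output would need one extra Wop-PBS per additional group of output blocks.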
Thank you very much for your help @Loris-B and @IceTDrinker; this was very helpful!