sealPir 其他安全强度支持(4096安全参数支持)
Yinbenxin opened this issue · 4 comments
此问题发生在尝试使用4096安全强度来进行匿踪查询,经过为期1周的排查,问题终于浮出水面并得以解决。
问题代码:
seal_pir.cc中std::vectorseal::Ciphertext SealPirServer::ExpandQuery函数:
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
const seal::Ciphertext &encrypted, std::uint32_t m) {
uint64_t plain_mod = seal_params_->plain_modulus().value();
seal::GaloisKeys &galkey = galois_key_;
// Assume that m is a power of 2. If not, round it to the next power of 2.
uint32_t logm = std::ceil(std::log2(m));
std::vector<int> galois_elts;
auto n = seal_params_->poly_modulus_degree();
YACL_ENFORCE(logm <= std::ceil(std::log2(n)), "m > n is not allowed.");
galois_elts.reserve(std::ceil(std::log2(n)));
for (size_t i = 0; i < std::ceil(std::log2(n)); i++) {
galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
seal::util::exponentiate_uint(2, i));
}
std::vector<seal::Ciphertext> results(1);
results[0] = encrypted;
seal::Plaintext tempPt;
for (size_t j = 0; j < logm - 1; j++) {
std::vector<seal::Ciphertext> results2(1 << (j + 1));
int step = 1 << j;
seal::Plaintext pt0(n);
seal::Plaintext pt1(n);
pt0.set_zero();
pt0[n - step] = plain_mod - 1;
std::cout << "plain_mods:" << plain_mod << std::endl;
int index_raw = (n << 1) - (1 << j); // -2^j
int index = (index_raw * galois_elts[j]) % (n << 1);
pt1.set_zero();
pt1[index] = 1;
std::cout << "pt0:" << pt0.to_string() << std::endl;
std::cout << "pt1:" << pt1.to_string() << std::endl;
// int nstep = -step;
yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
for (int k = begin; k < end; k++) {
seal::Ciphertext c0;
seal::Ciphertext c1;
seal::Ciphertext t0;
seal::Ciphertext t1;
c0 = results[k];
// SPDLOG_INFO("apply_galois j:{} k:{}", j, k);
evaluator_->apply_galois(c0, galois_elts[j], galkey,
t0); // t0 = Sub(c0,N/(2^i)+1)
evaluator_->add(c0, t0, results2[k]); // c0 + Sub(c0,N/(2^i)+1)
// multiply_power_of_X(c0, c1, index_raw);
evaluator_->multiply_plain(c0, pt0, c1); // c1 = c0*(-x)^(-2j)
evaluator_->multiply_plain(t0, pt1, t1);
// Sub(c0,N/(2^i)+1) * x^(-2j*(N+2^i)/(2^i))=Sub(c1,N/2^j+1)
evaluator_->add(c1, t1, results2[k + step]);
}
});
results = results2;
}
// Last step of the loop
std::vector<seal::Ciphertext> results2(results.size() << 1);
seal::Plaintext two("2");
seal::Plaintext pt0(n);
seal::Plaintext pt1(n);
pt0.set_zero();
pt0[n - results.size()] = plain_mod - 1;
int index_raw = (n << 1) - (1 << (logm - 1));
int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
pt1.set_zero();
pt1[index] = 1;
for (uint32_t k = 0; k < results.size(); k++) {
if (k >= (m - (1 << (logm - 1)))) { // corner case.
evaluator_->multiply_plain(results[k], two,
results2[k]); // plain multiplication by 2.
} else {
seal::Ciphertext c0;
seal::Ciphertext c1;
seal::Ciphertext t0;
seal::Ciphertext t1;
c0 = results[k];
evaluator_->apply_galois(c0, galois_elts[logm - 1], galkey, t0);
evaluator_->add(c0, t0, results2[k]);
// multiply_power_of_X(c0, c1, index_raw);
evaluator_->multiply_plain(c0, pt0, c1);
evaluator_->multiply_plain(t0, pt1, t1);
evaluator_->add(c1, t1, results2[k + results.size()]);
}
}
auto first = results2.begin();
auto last = results2.begin() + m;
std::vector<seal::Ciphertext> new_vec(first, last);
return new_vec;
}
建议修改为:
std::vector<seal::Ciphertext> SealPirServer::ExpandQuery(
const seal::Ciphertext &encrypted, std::uint32_t m) {
seal::GaloisKeys &galkey = galois_key_;
// Assume that m is a power of 2. If not, round it to the next power of 2.
uint32_t logm = std::ceil(std::log2(m));
std::vector<int> galois_elts;
auto n = seal_params_->poly_modulus_degree();
YACL_ENFORCE(logm <= std::ceil(std::log2(n)), "m > n is not allowed.");
galois_elts.reserve(std::ceil(std::log2(n)));
for (size_t i = 0; i < std::ceil(std::log2(n)); i++) {
galois_elts.push_back((n + seal::util::exponentiate_uint(2, i)) /
seal::util::exponentiate_uint(2, i));
}
std::vector<seal::Ciphertext> results(1);
results[0] = encrypted;
seal::Plaintext tempPt;
for (size_t j = 0; j < logm - 1; j++) {
std::vector<seal::Ciphertext> results2(1 << (j + 1));
int step = 1 << j;
int index_raw = (n << 1) - (1 << j);
int index = (index_raw * galois_elts[j]) % (n << 1);
// int nstep = -step;
yacl::parallel_for(0, step, [&](int64_t begin, int64_t end) {
for (int k = begin; k < end; k++) {
seal::Ciphertext c0;
seal::Ciphertext c1;
seal::Ciphertext t0;
seal::Ciphertext t1;
c0 = results[k];
// SPDLOG_INFO("apply_galois j:{} k:{}", j, k);
evaluator_->apply_galois(c0, galois_elts[j], galkey,
t0);
evaluator_->add(c0, t0, results2[k]);
multiply_power_of_X(c0, c1, index_raw);
multiply_power_of_X(t0, t1, index);
evaluator_->add(c1, t1, results2[k + step]);
}
});
results = results2;
}
// Last step of the loop
std::vector<seal::Ciphertext> results2(results.size() << 1);
seal::Plaintext two("2");
seal::Plaintext pt0(n);
seal::Plaintext pt1(n);
int index_raw = (n << 1) - (1 << (logm - 1));
int index = (index_raw * galois_elts[logm - 1]) % (n << 1);
for (uint32_t k = 0; k < results.size(); k++) {
if (k >= (m - (1 << (logm - 1)))) { // corner case.
evaluator_->multiply_plain(results[k], two,
results2[k]); // plain multiplication by 2.
} else {
seal::Ciphertext c0;
seal::Ciphertext c1;
seal::Ciphertext t0;
seal::Ciphertext t1;
c0 = results[k];
evaluator_->apply_galois(c0, galois_elts[logm - 1], galkey, t0);
evaluator_->add(c0, t0, results2[k]);
multiply_power_of_X(c0, c1, index_raw);
multiply_power_of_X(t0, t1, index);
evaluator_->add(c1, t1, results2[k + results.size()]);
}
}
auto first = results2.begin();
auto last = results2.begin() + m;
std::vector<seal::Ciphertext> new_vec(first, last);
return new_vec;
}
void SealPirServer::multiply_power_of_X(const seal::Ciphertext &encrypted,
seal::Ciphertext &destination,
uint32_t index) {
auto coeff_mod_count = seal_params_->coeff_modulus().size() - 1;
auto coeff_count = seal_params_->poly_modulus_degree();
auto encrypted_count = encrypted.size();
destination = encrypted;
for (size_t i = 0; i < encrypted_count; i++) {
for (size_t j = 0; j < coeff_mod_count; j++) {
seal::util::negacyclic_shift_poly_coeffmod(
encrypted.data(i) + (j * coeff_count), coeff_count, index,
seal_params_->coeff_modulus()[j],
destination.data(i) + (j * coeff_count));
}
}
}
主要原因是,multiply_plain会严重损耗seal密态计算的噪音,但是negacyclic_shift_poly_coeffmod不会导致噪音增大,并且在乘x^n时该函数具有更快的计算速度。
为了说明这个问题,可以用以下例子进行说明:
#include <iostream>
#include "seal/seal.h"
#include "seal/util/polyarithsmallmod.h"
using namespace std;
using namespace seal;
using namespace seal::util;
inline void multiply_power_of_X(const Ciphertext &encrypted,EncryptionParameters enc_params_,
Ciphertext &destination,
uint32_t index) {
auto coeff_mod_count = enc_params_.coeff_modulus().size() - 1;
auto coeff_count = enc_params_.poly_modulus_degree();
auto encrypted_count = encrypted.size();
destination = encrypted;
for (int i = 0; i < encrypted_count; i++) {
for (int j = 0; j < coeff_mod_count; j++) {
negacyclic_shift_poly_coeffmod(encrypted.data(i) + (j * coeff_count),
coeff_count, index,
enc_params_.coeff_modulus()[j],
destination.data(i) + (j * coeff_count));
}
}
}
int main() {
// 初始化 SEAL 库
int N = 4096;
EncryptionParameters parms(scheme_type::bfv);
parms.set_poly_modulus_degree(N);
parms.set_coeff_modulus(CoeffModulus::BFVDefault(N));
parms.set_plain_modulus(PlainModulus::Batching(N, 20));
auto context = SEALContext(parms);
uint64_t plain_mod = parms.plain_modulus().value();
// 生成密钥
seal::PublicKey public_key;
seal::SecretKey secret_key;
KeyGenerator keygen(context);
keygen.create_public_key(public_key);
secret_key = keygen.secret_key();
// 创建加密器
Encryptor encryptor(context, public_key);
// 创建一个多项式
Plaintext plain_coefficients(N);
plain_coefficients.set_zero();
plain_coefficients[1] = 10;
// 加密多项式
Ciphertext ciphertext;
encryptor.encrypt(plain_coefficients, ciphertext);
// 创建一个 x^10 的明文
Plaintext plain_power(N);
int step = 1 << 4;
plain_coefficients.set_zero();
int index_raw = (N << 1) - step;
plain_power[N - step] = plain_mod - 1;
Evaluator evaluator(context);
Decryptor decryptor(context, secret_key);
Ciphertext mpfx= ciphertext;
Ciphertext mp= ciphertext;
for (int i = 0; i < 4; ++i) {
Ciphertext mpfx_result;
Ciphertext mp_result ;
Plaintext mpfx_plaint;
Plaintext mp_plaint;
multiply_power_of_X(mpfx, parms, mpfx_result,index_raw);
evaluator.multiply_plain(mp, plain_power,mp_result);
decryptor.decrypt(mpfx_result, mpfx_plaint);
decryptor.decrypt(mp_result, mp_plaint);
cout << "multiply_power_of_X result: " << mpfx_plaint.to_string().substr(0,50) << endl;
cout << "multiply_plain result: " << mp_plaint.to_string().substr(0,50) << endl;
cout << "multiply_power_of_X 剩余可用噪音: " << decryptor.invariant_noise_budget(mpfx_result) << endl;
cout << "multiply_plain 剩余可用噪音: " << decryptor.invariant_noise_budget(mp_result) << endl;
mpfx= mpfx_result;
mp= mp_result;
}
return 0;
}
multiply_power_of_X result: FBFF7x^4081
multiply_plain result: FBFF7x^4081
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 25
multiply_power_of_X result: FBFF7x^4065
multiply_plain result: FBFF7x^4065
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 5
multiply_power_of_X result: FBFF7x^4049
multiply_plain result: FAD3Ax^4095 + 1E1x^4094 + F0x^4093 + FBF11x^4092 +
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0
multiply_power_of_X result: FBFF7x^4033
multiply_plain result: 5DCD2x^4095 + 4065Ex^4094 + 4065Ex^4093 + 9E32Fx^4
multiply_power_of_X 剩余可用噪音: 45
multiply_plain 剩余可用噪音: 0
可以看到噪音会迅速降低,从而导致计算错误。
之前仅仅支持8192是因为查询的总量较小,噪音并未消耗完毕,在数据量较大时会出现噪音不够所导致的计算错误问题。
@qxzhou1010 Would you mind to take a look at this?
@Yinbenxin 非常感谢您提出这个issue,并给出了优化的实现。这里我们是想在密文下计算 c1 = c0*(-x)^(-2j),由于 BFV 中多项式模采用了非常特殊的负循环多项式(x^N+1),因此这里的乘法运算本质上就是对 c0 的负循环移位操作。所以我们可以使用 negacyclic_shift_poly_coeffmod 来加速这个运算,并且这个过程对噪声消耗是零的,因为只涉及到对密文多项式一些简单的移位操作,所以并不会增加密文中所包含的噪声。
multiply_plain 是因为涉及到密文*明文,因此结果密文中的噪声项会被放大,所以每一次操作都会导致对噪声预算的消耗。
实际上,在 SealPIR 官方仓库中正是采用的这个实现。可以参考:https://github.com/microsoft/SealPIR/blob/ee1a5a3922fc9250f9bb4e2416ff5d02bfef7e52/src/pir_server.cpp#L415。
我们后续将会对这个点的实现进行优化,再次感谢您提出的问题和进行的验证。