tuneinsight/lattigo

Polynomial Evaluation on BGV

Closed this issue · 0 comments

// mulRelinRescaleBGV returns the relinearized and rescaled product a*b as a
// new ciphertext, panicking on any evaluator error.
func mulRelinRescaleBGV(eval *bgv.Evaluator, a, b *rlwe.Ciphertext) *rlwe.Ciphertext {
	prod, err := eval.MulRelinNew(a, b)
	if err != nil {
		panic(err)
	}
	if err = eval.Rescale(prod, prod); err != nil {
		panic(err)
	}
	return prod
}

// productOfBinaryPowersBGV multiplies together the entries base[j+1] for every
// bit j (0 <= j <= bits) that is set in exponent. base[j+1] is expected to hold
// the 2^j-th power of some ciphertext, so the result is that ciphertext raised
// to exponent. When exponent is a power of two the base entry itself is
// returned (no copy); when no bit is set the result is nil.
func productOfBinaryPowersBGV(eval *bgv.Evaluator, base []*rlwe.Ciphertext, exponent, bits int) *rlwe.Ciphertext {
	var acc *rlwe.Ciphertext
	for j := 0; j <= bits; j++ {
		if exponent&1 == 1 {
			if acc == nil {
				acc = base[j+1]
			} else {
				acc = mulRelinRescaleBGV(eval, base[j+1], acc)
			}
		}
		exponent >>= 1
	}
	return acc
}

// ObliviousPolynomialEvaluationBGV evaluates a polynomial of the given degree
// on ct with a baby-step/giant-step (Paterson-Stockmeyer style) schedule and
// returns the sum of all blocks.
//
// With s = ceil(sqrt(degree)) the function computes
//   - the small powers x^0 .. x^s,
//   - the big powers   (x^s)^0 .. (x^s)^(s-1), where x^s is pre_power when it
//     is non-nil and the last small power otherwise,
//   - block_i = (x^s)^i * sum_j c[i*s+j] * x^j, and finally sum_i block_i.
//
// The coefficients are generated inside the function (coefficient k is the
// constant vector k+1); coeffs_pt and encryptors are not used by this body.
// evaluators is indexed with 0, i-2 (2 <= i <= s) and i (0 <= i < s), so it
// must hold at least s entries; one evaluator is used per goroutine.
// decryptor and encoder are only used to print debug values.
//
// The function panics if degree < 2 or if any evaluator call fails.
func ObliviousPolynomialEvaluationBGV(params bgv.Parameters, degree int, ct *rlwe.Ciphertext, pre_power *rlwe.Ciphertext, coeffs_pt []*rlwe.Plaintext, evaluators []*bgv.Evaluator, encryptors []*rlwe.Encryptor, decryptor *rlwe.Decryptor, encoder *bgv.Encoder, targetLevel int) *rlwe.Ciphertext {
	// For degree < 2 the sizes derived below are degenerate (log2(0), a
	// one-element big_power indexed at 1), so fail early with a clear message.
	if degree < 2 {
		panic(fmt.Sprintf("ObliviousPolynomialEvaluationBGV: degree must be >= 2, got %d", degree))
	}

	var err error
	squareRoot := int(math.Ceil(math.Sqrt(float64(degree))))
	zeroCnt := 0
	blocks := make([]*rlwe.Ciphertext, squareRoot)
	for i := range blocks {
		blocks[i] = getItem(zero_pool)
		zeroCnt++
	}
	logger.Printf("=========== Evaluate Poly of degree %d ===========\n\n", degree)
	start := time.Now()

	// Repeated squaring of x: entries are 1, x, x^2, x^4, x^8, ...
	smallBaseNum := int(math.Ceil(math.Log2(float64(squareRoot))))
	smallBase := make([]*rlwe.Ciphertext, smallBaseNum+2)
	smallBase[0] = getItem(one_pool) // one
	smallBase[1] = ct                // x**1
	for i := 2; i < smallBaseNum+2; i++ {
		smallBase[i] = mulRelinRescaleBGV(evaluators[0], smallBase[i-1], smallBase[i-1])
	}

	// Small powers x^0 .. x^s, each assembled from the binary expansion of
	// its exponent. Every goroutine owns its evaluator and its output slot.
	smallPowerNum := squareRoot
	smallPower := make([]*rlwe.Ciphertext, smallPowerNum+1)
	smallPower[0] = smallBase[0] // one
	smallPower[1] = smallBase[1] // x**1
	wg.Add(smallPowerNum - 1)
	for i := 2; i < smallPowerNum+1; i++ {
		go func(i int) {
			defer wg.Done()
			smallPower[i] = productOfBinaryPowersBGV(evaluators[i-2], smallBase, i, smallBaseNum)
		}(i)
	}
	wg.Wait()

	// Debug output is produced serially after the workers finished because
	// decryptor and encoder are shared between all iterations.
	// NOTE(review): presumably neither is safe for concurrent use without a
	// per-goroutine copy — confirm against the lattigo documentation.
	for i := 2; i < smallPowerNum+1; i++ {
		resMat := decryptor.DecryptNew(smallPower[i])
		vals := make([]uint64, params.MaxSlots())
		encoder.Decode(resMat, vals)
		fmt.Println("small pow", i, "level", smallPower[i].Level(), vals[:10])
	}

	// Repeated squaring of the giant step y = x^s: 1, y, y^2, y^4, ...
	bigBaseNum := int(math.Ceil(math.Log2(math.Ceil(math.Sqrt(float64(degree)) - 1))))
	bigBase := make([]*rlwe.Ciphertext, bigBaseNum+2)
	bigBase[0] = smallBase[0] // x**0
	if pre_power != nil {
		bigBase[1] = pre_power // precomputed giant step
	} else {
		bigBase[1] = smallPower[smallPowerNum] // the last small power
	}
	for i := 2; i < bigBaseNum+2; i++ {
		bigBase[i] = mulRelinRescaleBGV(evaluators[0], bigBase[i-1], bigBase[i-1])
	}

	// Big powers y^0 .. y^(s-1).
	bigPowerNum := squareRoot
	bigPower := make([]*rlwe.Ciphertext, bigPowerNum)
	bigPower[0] = bigBase[0] // x**0
	bigPower[1] = bigBase[1] // y
	wg.Add(bigPowerNum - 2)
	for i := 2; i < bigPowerNum; i++ {
		go func(i int) {
			defer wg.Done()
			bigPower[i] = productOfBinaryPowersBGV(evaluators[i-2], bigBase, i, bigBaseNum)
		}(i)
	}
	wg.Wait()

	for i := 2; i < bigPowerNum; i++ {
		resMat := decryptor.DecryptNew(bigPower[i])
		vals := make([]uint64, params.MaxSlots())
		encoder.Decode(resMat, vals)
		fmt.Println("big pow", i*squareRoot, "level", bigPower[i].Level(), vals[:10])
	}

	// Coefficient k is the constant vector k+1.
	coeffs := make([][]uint64, degree)
	for i := 0; i < degree; i++ {
		t := make([]uint64, params.N())
		for j := range t {
			t[j] = uint64(i + 1)
		}
		coeffs[i] = t
	}

	// Compute every block in its own goroutine.
	wg.Add(bigPowerNum)
	for i := 0; i < bigPowerNum; i++ {
		// Counted here, in the spawning goroutine, so the counter is not
		// written concurrently by the workers.
		zeroCnt++
		go func(i int) {
			defer wg.Done()
			level := targetLevel + 1
			subRes := rlwe.NewCiphertext(params, 1, level)
			subRes.Scale = params.DefaultScale().Div(bigPower[i].Scale).Mul(rlwe.NewScale(params.RingQ().AtLevel(level).SubRings[level].Modulus))
			for j := 0; j < smallPowerNum; j++ {
				k := i*smallPowerNum + j
				// >= rather than ==: s*s can exceed degree by more than one
				// block, in which case a trailing block has no coefficients.
				if k >= degree {
					break
				}
				if err := evaluators[i].MulThenAdd(smallPower[j], coeffs[k], subRes); err != nil {
					panic(err)
				}
			}
			fmt.Println(subRes.Level(), bigPower[i].Level())

			// Errors stay local to the goroutine; only blocks[i] is shared,
			// and each worker writes a distinct index.
			block, err := evaluators[i].MulRelinNew(subRes, bigPower[i])
			if err != nil {
				panic(err)
			}
			if err = evaluators[i].Rescale(block, block); err != nil {
				panic(err)
			}
			blocks[i] = block

			fmt.Println("level:", blocks[i].Level(), "block scale:", blocks[i].Scale.Uint64())
		}(i)
	}
	wg.Wait()

	// Sum block_0 + block_1 + ... + block_{s-1}. The index must start at 1:
	// ranging over blocks[1:] yields indices starting at 0, which would add
	// blocks[0] twice and drop the last block.
	res := blocks[0]
	for i := 1; i < len(blocks); i++ {
		res, err = evaluators[0].AddNew(res, blocks[i])
		if err != nil {
			panic(err)
		}
		resMat := decryptor.DecryptNew(res)
		resPt := make([]uint64, params.MaxSlots())
		encoder.Decode(resMat, resPt)
		fmt.Println("blocks", i, resPt[:10], res.Scale.Uint64())
	}

	dt := time.Since(start)
	logger.Printf("> Eval Poly time %v \n\n", dt)
	logger.Printf("> Zero pop count %v \n\n", zeroCnt)

	fmt.Println("Output Scale:", res.Scale.Uint64())
	return res
}

I implemented the Paterson Stockmeyer algorithm and get the right result.
For example, to compute a polynomial $f(x)$ of degree 255, I first compute the small steps $x^0, x^1, \ldots, x^{15}$ and the big steps $x^{16}, x^{32}, x^{48}, \ldots, x^{240}$, and then compute the blocks.

Block 0 is $a_0 x^0 + a_1 x^1 + \cdots + a_{15} x^{15}$, where the $a_i$ are the coefficients;
block 1 is $x^{16} (a_{16} + a_{17} x^1 + \cdots + a_{31} x^{15})$, and every other $block_i$ is built the same way.
Finally, I add these blocks together to get $f(x) = block_0 + block_1 + \cdots + block_{15}$.
The result of f(x) is right and it costs 9 multiplications to compute f(x).
I choose param

PQ := bgv.ParametersLiteral{
		LogN:             15,
		LogQ:             []int{40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40},
		LogP:             []int{40, 40, 40},
		PlaintextModulus: 65537,
	}

With these parameters I can compute at most 18 multiplications.
Based on $f(x)$ I should be able to perform 9 multiplications, but I found that I can only do 7.

Will my algorithm significantly reduce the noise budget?