tuneinsight/lattigo

Bug[rlwe]: unsigned digit decomposition producing additional error

Pro7ech opened this issue · 0 comments

What version of Lattigo are you using?

$ go get github.com/tuneinsight/lattigo/v5@latest

Does this issue persist with the latest release?

Yes

What were you trying to do?

Estimate the noise produced by $\textsf{RLWE}(m_{0}) \otimes (\textsf{RGSW}(m_{1}) \otimes \textsf{RGSW}(m_{2}))$

What were you expecting to happen?

Given a ring degree $N$, a decomposition base $B = 2^{w}$, the gadget vector $\textbf{w} = [B^{0}, B^{1}, \dots, B^{d-1}]$ with $d = \lceil\log(Q)/w\rceil$, and

  • $\textsf{RLWE}(m_{0}) = (-a_{0}s + m_{0} + e_{0}, a_{0})$
  • $\textsf{RGSW}(X^{i}) = ((-a_{1}s + \textbf{w}\cdot X^{i} + e_{1}, a_{1}), (-a_{2}s + e_{2}, a_{2} + \textbf{w}\cdot X^{i}))$

the standard deviation of the noise of $\textsf{RLWE}(m_{0})\otimes\textsf{RGSW}(X^{i})$ was expected to match the following prediction:

$$ \sigma_{\textsf{RLWE}(m_{0})\otimes\textsf{RGSW}(X^{i})}=\sqrt{N\left(d\frac{B^{2}}{12}\left(\sigma_{e_{1}}^{2} + \sigma_{e_{2}}^{2}\right)\right) + \sigma_{e_{0}}^{2}} $$

The standard deviation of $\textsf{RGSW}(X^{i})\otimes\textsf{RGSW}(X^{j})$ follows directly, since an $\textsf{RGSW}\otimes\textsf{RGSW}$ product is a set of $\textsf{RLWE}\otimes\textsf{RGSW}$ products; estimating the next product then only requires replacing $\sigma_{e_{1}}$ and $\sigma_{e_{2}}$ with the resulting noise.

What actually happened?

With $N=2^{11}$, $\log(Q)=50$ and $w=5$ (so $B=32$ and $d=10$), the base-2 logarithm of the standard deviation of the noise of $\textsf{RGSW}(X^{i})\otimes\textsf{RGSW}(X^{j})$ should be $12.5465$ (with $\sigma_{e_{0}} = \sigma_{e_{1}} = \sigma_{e_{2}} = 3.2$) but is $\approx13.2623$ with Lattigo. Similarly, the base-2 logarithm of the standard deviation of the noise of $\textsf{RLWE}(m_{0}) \otimes (\textsf{RGSW}(m_{1}) \otimes \textsf{RGSW}(m_{2}))$ (with $\sigma_{e_{1}} = \sigma_{e_{2}} = 2^{12.54}$) should be $23.4150$ but is $\approx30.9518$. This is a substantial difference, large enough to cause issues during parameterization and when estimating the failure probability of blind rotations.

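This is caused by the unsigned digit decomposition (in both the gadget product and the external product). Unsigned digits lie in $[0, B)$, so they have a nonzero mean and a second moment of about $B^{2}/3$, whereas the $B^{2}/12$ term of the prediction assumes balanced digits in $[-B/2, B/2)$.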

Reproducibility

package main

import(
	"fmt"
	"github.com/tuneinsight/lattigo/v5/core/rlwe"
	"github.com/tuneinsight/lattigo/v5/core/rgsw"
	"github.com/tuneinsight/lattigo/v5/ring"
	"github.com/tuneinsight/lattigo/v5/utils"
)


// main estimates, over repeated trials, the noise of
// RGSW(m0) ⊗ RGSW(m1) (both gadget rows) and of
// RLWE(m2) ⊗ (RGSW(m0) ⊗ RGSW(m1)), and prints the average base-2 logarithm
// of their standard deviations.
func main() {

	params0, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{
		LogN:    11,
		LogQ:    []int{50},
		NTTFlag: true,
	})
	if err != nil {
		panic(err)
	}

	// params1 is a separate (currently identical) parameter set used for the
	// right-hand RGSW operand and the evaluator.
	params1, err := rlwe.NewParametersFromLiteral(rlwe.ParametersLiteral{
		LogN:    11,
		LogQ:    []int{50},
		NTTFlag: true,
	})
	if err != nil {
		panic(err)
	}

	// w is the base-two logarithm of the gadget decomposition base B = 2^w.
	w := 5

	ringQ := params0.RingQ()

	kgen := rlwe.NewKeyGenerator(params1)
	sk := kgen.GenSecretKeyNew()

	// Messages: m0 = X and m1 = X (RGSW operands), m2 = 2^40 (RLWE operand).
	pt0 := newMonomialPlaintext(params0, 1, 1)
	pt1 := newMonomialPlaintext(params0, 1, 1)
	pt2 := newMonomialPlaintext(params0, 0, 1<<40)

	// Expected messages of the products. They are computed once, into fresh
	// polynomials, so pt0/pt2 are never mutated and every trial measures
	// exactly the same experiment.
	// pt01 = m0*m1 is carried by RGSW(m0) ⊗ RGSW(m1).
	pt01 := mulNTT(params0, pt0.Value, pt1.Value)
	// pt012 = m2*m0*m1 is carried by RLWE(m2) ⊗ (RGSW(m0) ⊗ RGSW(m1)).
	pt012 := mulNTT(params0, pt2.Value, pt01)

	ct0 := rgsw.NewCiphertext(params0, params0.MaxLevel(), params0.MaxLevelP(), w)
	ct1 := rgsw.NewCiphertext(params1, params1.MaxLevel(), params1.MaxLevelP(), w)

	// None of these depend on the trial, so they are built once.
	rgswEnc0 := rgsw.NewEncryptor(params0, sk)
	rgswEnc1 := rgsw.NewEncryptor(params1, sk)
	rlweEnc := rlwe.NewEncryptor(params0, sk)
	dec := rlwe.NewDecryptor(params0, sk)
	eval := rgsw.NewEvaluator(params1, nil)

	toStandard := func(p ring.Poly) { ringQ.IMForm(p, p) }
	toMontgomery := func(p ring.Poly) { ringQ.MForm(p, p) }

	const trials = 1024
	var stdRow0, stdRow1, stdExt float64
	for trial := 0; trial < trials; trial++ {
		if err := rgswEnc0.Encrypt(pt0, ct0); err != nil {
			panic(err)
		}

		// ct0 is taken out of Montgomery form while its rows are used as RLWE
		// ciphertexts below, and put back before it is used as an RGSW operand.
		applyToRGSWPolys(ct0, toStandard)

		if err := rgswEnc1.Encrypt(pt1, ct1); err != nil {
			panic(err)
		}

		// RGSW(m0) ⊗ RGSW(m1): every row of ct0 is multiplied by ct1 in place;
		// the flattened ciphertexts share their polynomials with ct0.
		cts := FlattenRGSWToRLWE(ct0)
		for i := range cts {
			eval.ExternalProduct(&cts[i], ct1, &cts[i])
		}

		row0, row1 := NoiseRGSWCiphertext(ct0, pt01, sk, params0)
		stdRow0 += row0
		stdRow1 += row1

		applyToRGSWPolys(ct0, toMontgomery)

		// RLWE(m2) ⊗ (RGSW(m0) ⊗ RGSW(m1)).
		ct2, err := rlweEnc.EncryptNew(pt2)
		if err != nil {
			panic(err)
		}
		eval.ExternalProduct(ct2, ct0, ct2)

		// Remove the expected message so that decryption yields only the error.
		ringQ.Sub(ct2.Value[0], pt012, ct2.Value[0])
		ptErr := dec.DecryptNew(ct2)
		if ptErr.IsNTT {
			ringQ.INTT(ptErr.Value, ptErr.Value)
		}
		stdExt += ringQ.Log2OfStandardDeviation(ptErr.Value)
	}

	stdRow0 /= trials
	stdRow1 /= trials
	stdExt /= trials

	fmt.Println(stdRow0, stdRow1)
	fmt.Println(stdExt)
}

// newMonomialPlaintext returns a plaintext at params.MaxLevel() whose
// coefficient of X^degree is set to coeff in every RNS limb, moved to the NTT
// domain when params.NTTFlag() is set.
func newMonomialPlaintext(params rlwe.Parameters, degree int, coeff uint64) *rlwe.Plaintext {
	pt := rlwe.NewPlaintext(params, params.MaxLevel())
	for i := range pt.Value.Coeffs {
		pt.Value.Coeffs[i][degree] = coeff
	}
	if params.NTTFlag() {
		params.RingQ().NTT(pt.Value, pt.Value)
	}
	return pt
}

// mulNTT returns a*b in a new polynomial, for a and b in the NTT domain and in
// standard form. a is first copied into Montgomery form so that the following
// Montgomery product yields the plain product; a and b are left untouched.
func mulNTT(params rlwe.Parameters, a, b ring.Poly) ring.Poly {
	ringQ := params.RingQ()
	out := ringQ.NewPoly()
	ringQ.MForm(a, out)
	ringQ.MulCoeffsMontgomery(out, b, out)
	return out
}

// applyToRGSWPolys calls f on the Q component of both polynomials of every
// gadget entry of ct. The polynomials are passed by value but share their
// coefficient slices with ct, so an in-place f (e.g. ringQ.IMForm(p, p))
// updates ct.
func applyToRGSWPolys(ct *rgsw.Ciphertext, f func(ring.Poly)) {
	for i := range ct.Value {
		for j := range ct.Value[i].Value {
			for k := range ct.Value[i].Value[j] {
				f(ct.Value[i].Value[j][k][0].Q)
				f(ct.Value[i].Value[j][k][1].Q)
			}
		}
	}
}

// FlattenRGSWToRLWE returns every gadget entry of ct as a standalone
// rlwe.Ciphertext made of the Q components of the entry's two polynomials.
// Entries are emitted row by row, then RNS limb by limb, then digit by digit.
//
// The polynomials are copied by value and so share their coefficient slices
// with ct: writing into a returned ciphertext in place also updates ct.
// NOTE(review): MetaData is left unset on the returned ciphertexts; confirm
// that consumers (e.g. the RGSW evaluator) do not dereference it.
func FlattenRGSWToRLWE(ct *rgsw.Ciphertext) (cts []rlwe.Ciphertext) {
	for row := range ct.Value {
		gadget := &ct.Value[row]
		for _, limb := range gadget.Value {
			for _, entry := range limb {
				var flat rlwe.Ciphertext
				flat.Value = []ring.Poly{entry[0].Q, entry[1].Q}
				cts = append(cts, flat)
			}
		}
	}
	return cts
}

// NoiseRGSWCiphertext measures the noise of both gadget rows of the RGSW
// ciphertext ct, whose message is pt, and returns the result of
// NoiseGadgetCiphertext for the first row (expected message pt) and for the
// second row (expected message pt*sk).
func NoiseRGSWCiphertext(ct *rgsw.Ciphertext, pt ring.Poly, sk *rlwe.SecretKey, params rlwe.Parameters) (float64, float64) {
	ringQ := params.RingQ().AtLevel(ct.LevelQ())

	// The second row carries m*s. A Montgomery product of pt (standard form)
	// with sk yields pt*sk only if sk is stored in Montgomery form — assumed
	// here, as elsewhere in this file; confirm.
	ptTimesSk := *pt.CopyNew()
	ringQ.MulCoeffsMontgomery(ptTimesSk, sk.Value.Q, ptTimesSk)

	noiseFirstRow := NoiseGadgetCiphertext(&ct.Value[0], pt, sk, params)
	noiseSecondRow := NoiseGadgetCiphertext(&ct.Value[1], ptTimesSk, sk, params)
	return noiseFirstRow, noiseSecondRow
}

// NoiseGadgetCiphertext decrypts every entry of the gadget ciphertext gct with
// sk and removes the expected message. It returns the largest base-2 logarithm
// of the standard deviation of the remaining error over the base-two
// decomposition digits.
//
// pt is the message carried by gct. When gct has a P modulus, the message is
// scaled by P before it is subtracted. For digit j the message is further
// scaled by 2^(j*gct.BaseTwoDecomposition). Only the digits present in every
// RNS limb are inspected. gct and pt are copied, so the caller's values are
// not modified.
//
// The decrypted entries are moved out of the NTT domain unconditionally before
// measuring. NOTE(review): this assumes gct (and pt) are in the NTT domain —
// confirm for parameters with NTTFlag() == false.
//
// The returned value starts from 0, so errors whose log2 standard deviation is
// negative are reported as 0.
func NoiseGadgetCiphertext(gct *rlwe.GadgetCiphertext, pt ring.Poly, sk *rlwe.SecretKey, params rlwe.Parameters) float64 {

	gct = gct.CopyNew()
	pt = *pt.CopyNew()
	levelQ, levelP := gct.LevelQ(), gct.LevelP()
	ringQP := params.RingQP().AtLevel(levelQ, levelP)

	ringQ, ringP := ringQP.RingQ, ringQP.RingP

	// Number of base-two digits present in every RNS limb.
	vectorSizes := gct.BaseTwoDecompositionVectorSize()
	digits := vectorSizes[0]
	for _, d := range vectorSizes {
		if d < digits {
			digits = d
		}
	}

	// Decrypt in place:
	// [-a*sIn + w*P*sOut + e, a] -> -a*sIn + w*P*sOut + e + a*sIn.
	for i := range gct.Value {
		for j := range gct.Value[i] {
			ringQP.MulCoeffsMontgomeryThenAdd(gct.Value[i][j][1], sk.Value, gct.Value[i][j][0])
		}
	}

	// Sum all RNS limbs into limb 0, which is equivalent to multiplying by the
	// CRT decomposition of 1: sum([1]_i * [RNS_i*PW2*P*sOut + e]) = PW2*P*sOut + sum(e).
	for i := 1; i < len(gct.Value); i++ {
		for j := range gct.Value[i] {
			ringQP.Add(gct.Value[0][j][0], gct.Value[i][j][0], gct.Value[0][j][0])
		}
	}

	if levelP != -1 {
		// The message is scaled by P: sOut * P.
		ringQ.MulScalarBigint(pt, ringP.ModulusAtLevel[levelP], pt)
	}

	var maxLog2Std float64

	for j := 0; j < digits; j++ {

		// P*sOut*2^(j*w) + sum(e) - P*sOut*2^(j*w) = sum(e).
		ringQ.Sub(gct.Value[0][j][0].Q, pt, gct.Value[0][j][0].Q)

		// Measure the residual error in the coefficient domain.
		ringQP.INTT(gct.Value[0][j][0], gct.Value[0][j][0])

		maxLog2Std = utils.Max(maxLog2Std, ringQP.Log2OfStandardDeviation(gct.Value[0][j][0]))

		// The next digit carries the message scaled by 2^BaseTwoDecomposition.
		ringQ.MulScalar(pt, 1<<gct.BaseTwoDecomposition, pt)
	}

	return maxLog2Std
}