void PAlgebraMod2r::init(unsigned r, unsigned m) { if (m == zmStar.M()) return; // nothign to do if (r<2 || r>=NTL_SP_NBITS) return; // sanity check ((PAlgebraModTwo&)modTwo).init(m); // initialize zmStar and modTwo, if needed if (zmStar.M()==0 || zmStar.NSlots()==0) return; // error in zmStar init long nSlots = zmStar.NSlots(); // Take the factors and their CRT coefficients mod 2 and lift them to mod 2^r zz_p::init(2); // convert the factors of Phi_m(X) from GF2X to zz_pX objects with p=2 vec_zz_pX vzzp; // no direct conversion from GF2X to zz_pX, vec_ZZX vzz; // need to go via ZZX vzzp.SetLength(nSlots); vzz.SetLength(nSlots); for (long i=0; i<nSlots; i++) { vzz[i] = to_ZZX(modTwo.factors[i]); conv(vzzp[i], vzz[i]); } // lift the factors of Phi_m(X) from mod-2 to mod-2^r MultiLift(vzz, vzzp, zmStar.PhimX(), r); // defined in NTL::ZZXFactoring // Compute the zz_pContext object for mod-2^r arithmetic rr = r; unsigned two2r = 1UL << r; // compute 2^r zz_p::init(two2r); mod2rContext.save(); PhimXmod = to_zz_pX(zmStar.PhimX()); // Phi_m(X) mod 2^r factors.SetLength(nSlots); for (long i=0; i<nSlots; i++) // Convert from ZZX to zz_pX conv(factors[i],vzz[i]); /* Debugging sanity-check #1: we should have Ft= GCD(F1(X^t),Phi_m(X)) zz_pXModulus F1(factors[0]); // We choose factors[0] as F1 zz_pXModulus Pm2(PhimXmod); for (long i=1; i<nSlots; i++) { unsigned t = zmStar.ith_rep(i); zz_pX X2t = PowerXMod(t,PhimXmod); // X2t = X^t mod Phi_m(X) zz_pX Ft = GCD(CompMod(F1,X2t,Pm2),Pm2); if (Ft != factors[i]) { cout << "Ft != F1(X^t) mod Phi_m(X), t=" << t << endl; exit(0); } }*******************************************************************/ // Finally compute the CRT coefficients for the factors crtCoeffs.SetLength(nSlots); for (long i=0; i<nSlots; i++) { zz_pX& fct = factors[i]; zz_pX te = PhimXmod / fct; // \prod_{j\ne i} Fj te %= fct; // \prod_{j\ne i} Fj mod Fi InvMod(crtCoeffs[i], te, fct);// \prod_{j\ne i} Fj^{-1} mod Fi } }
// return a degree-d irreducible polynomial mod p ZZX makeIrredPoly(long p, long d) { assert(d >= 1); assert(ProbPrime(p)); if (d == 1) return ZZX(1, 1); // the monomial X zz_pBak bak; bak.save(); zz_p::init(p); return to_ZZX(BuildIrred_zz_pX(d)); }
// Modular composition over the integers: res = g(h) mod f.
// f must be monic (asserted), so that ZZX remainders by f are well defined.
// The value is built by Horner's rule from the top coefficient of g down.
// Every intermediate is reduced mod f to keep degrees below deg(f).
// A zero g (deg(g) == -1) yields res == 0.
void ModComp(ZZX& res, const ZZX& g, const ZZX& h, const ZZX& f)
{
  assert(LeadCoeff(f) == 1);

  const ZZX hModF = h % f;
  ZZX acc; // the zero polynomial

  for (long i = deg(g); i >= 0; --i) {
    acc *= hModF;
    acc += coeff(g, i);
    acc %= f;
  }

  // Assign at the end so that res may alias any of the inputs.
  res = acc;
}
// Encode a message polynomial: sum the precomputed polynomials MxN[i] for
// every index i < size at which mess has coefficient exactly 1. The sum is
// then reduced mod `modulus`, with arithmetic over ZZ_p for p = 2.
// (MxN is presumably a CRT-style basis; confirm against the class definition.)
// NOTE: sets the global ZZ_p modulus to 2 and does not restore it.
ZZX myCRT::EncodeMessageMxN(ZZX &mess){
  ZZ_p::init(to_ZZ("2"));

  ZZ_pX acc; // starts as the zero polynomial
  for (int idx = 0; idx < size; ++idx) {
    if (coeff(mess, idx) != 1) continue; // only coefficients equal to 1 count
    acc += MxN[idx];
  }

  acc = acc % modulus;
  return to_ZZX(acc);
}
// Shift the slots of ctxt by k positions along dimension i (generator i of
// Z_m^*).
//   - If |k| >= OrderOf(i), the ciphertext is multiplied by the constant 0.
//   - If k is a multiple of the order, this is a no-op.
//   - Otherwise ctxt is multiplied by a 0/1 mask from maskTable, zeroing the
//     slots where the mask is 0. The automorphism X -> X^val is then applied.
//     For k >= 0 the complemented mask (1 - mask) is used.
// NOTE(review): the `template<class type>` header for this definition is not
// visible in this chunk; presumably it precedes this line. Confirm.
// NOTE(review): the early returns skip FHE_TIMER_STOP; confirm that the timer
// macros tolerate this.
void EncryptedArrayDerived<type>::shift1D(Ctxt& ctxt, long i, long k) const
{
  FHE_TIMER_START;
  const PAlgebra& al = context.zMStar;

  const vector< vector< RX > >& maskTable = tab.getMaskTable();

  // Save the current RX modulus, then switch to the table's context.
  // Presumably RBak restores the saved modulus on destruction; confirm.
  RBak bak; bak.save(); tab.restoreContext();

  assert(&context == &ctxt.getContext());
  assert(i >= 0 && i < (long)al.numOfGens());

  long ord = al.OrderOf(i);

  // Shifting by the full order or more clears every slot.
  if (k <= -ord || k >= ord) {
    ctxt.multByConstant(to_ZZX(0));
    return;
  }

  // Make sure amt is in the range [1,ord-1]
  long amt = k % ord;
  if (amt == 0) return;
  if (amt < 0) amt += ord;

  RX mask = maskTable[i][ord-amt];

  long val;
  if (k < 0)
    // Negative exponent: PowerMod computes the corresponding power of the
    // inverse of the generator mod m.
    val = PowerMod(al.ZmStarGen(i), amt-ord, al.getM());
  else {
    mask = 1 - mask; // complement the mask for non-negative shifts
    val = PowerMod(al.ZmStarGen(i), amt, al.getM());
  }
  DoubleCRT m1(conv<ZZX>(mask), context, ctxt.getPrimeSet());
  ctxt.multByConstant(m1);   // zero out slots where mask=0
  ctxt.smartAutomorph(val);  // shift left by val
  FHE_TIMER_STOP;
}
// Assumes current zz_p modulus is p^r // computes S = F^{-1} mod G via Hensel lifting void InvModpr(zz_pX& S, const zz_pX& F, const zz_pX& G, long p, long r) { ZZX ff, gg, ss, tt; ff = to_ZZX(F); gg = to_ZZX(G); zz_pBak bak; bak.save(); zz_p::init(p); zz_pX f, g, s, t; f = to_zz_pX(ff); g = to_zz_pX(gg); s = InvMod(f, g); t = (1-s*f)/g; assert(s*f + t*g == 1); ss = to_ZZX(s); tt = to_ZZX(t); ZZ pk = to_ZZ(1); for (long k = 1; k < r; k++) { // lift from p^k to p^{k+1} pk = pk * p; assert(divide(ss*ff + tt*gg - 1, pk)); zz_pX d = to_zz_pX( (1 - (ss*ff + tt*gg))/pk ); zz_pX s1, t1; s1 = (s * d) % g; t1 = (d-s1*f)/g; ss = ss + pk*to_ZZX(s1); tt = tt + pk*to_ZZX(t1); } bak.restore(); S = to_zz_pX(ss); assert((S*F) % G == 1); }
// prime power solver // zz_p::modulus() is assumed to be p^r, for p prime, r >= 1 // A is an n x n matrix, b is a length n (row) vector, // and a solution for the matrix-vector equation x A = b is found. // If A is not inverible mod p, then error is raised. void ppsolve(vec_zz_pE& x, const mat_zz_pE& A, const vec_zz_pE& b, long p, long r) { if (r == 1) { zz_pE det; solve(det, x, A, b); if (det == 0) Error("ppsolve: matrix not invertible"); return; } long n = A.NumRows(); if (n != A.NumCols()) Error("ppsolve: matrix not square"); if (n == 0) Error("ppsolve: matrix of dimension 0"); zz_pContext pr_context; pr_context.save(); zz_pEContext prE_context; prE_context.save(); zz_pX G = zz_pE::modulus(); ZZX GG = to_ZZX(G); vector< vector<ZZX> > AA; convert(AA, A); vector<ZZX> bb; convert(bb, b); zz_pContext p_context(p); p_context.restore(); zz_pX G1 = to_zz_pX(GG); zz_pEContext pE_context(G1); pE_context.restore(); // we are now working mod p... // invert A mod p mat_zz_pE A1; convert(A1, AA); mat_zz_pE I1; zz_pE det; inv(det, I1, A1); if (det == 0) { Error("ppsolve: matrix not invertible"); } vec_zz_pE b1; convert(b1, bb); vec_zz_pE y1; y1 = b1 * I1; vector<ZZX> yy; convert(yy, y1); // yy is a solution mod p for (long k = 1; k < r; k++) { // lift solution yy mod p^k to a solution mod p^{k+1} pr_context.restore(); prE_context.restore(); // we are now working mod p^r vec_zz_pE d, y; convert(y, yy); d = b - y * A; vector<ZZX> dd; convert(dd, d); long pk = power_long(p, k); vector<ZZX> ee; div(ee, dd, pk); p_context.restore(); pE_context.restore(); // we are now working mod p vec_zz_pE e1; convert(e1, ee); vec_zz_pE z1; z1 = e1 * I1; vector<ZZX> zz, ww; convert(zz, z1); mul(ww, zz, pk); add(yy, yy, ww); } pr_context.restore(); prE_context.restore(); convert(x, yy); assert(x*A == b); }
int main(int argc, char *argv[]) { if (argc<2) { cout << "\nUsage: " << argv[0] << " L [c=2 w=64 k=80 d=1]" << endl; cout << " L is the number of levels\n"; cout << " optional c is number of columns in the key-switching matrices (default=2)\n"; cout << " optional w is Hamming weight of the secret key (default=64)\n"; cout << " optional k is the security parameter (default=80)\n"; cout << " optional d specifies GF(2^d) arithmetic (default=1, must be <=16)\n"; // cout << " k is the security parameter\n"; // cout << " m determines the ring mod Phi_m(X)" << endl; cout << endl; exit(0); } cout.unsetf(ios::floatfield); cout.precision(4); long L = atoi(argv[1]); long c = 2; long w = 64; long k = 80; long d = 1; if (argc>2) c = atoi(argv[2]); if (argc>3) w = atoi(argv[3]); if (argc>4) k = atoi(argv[4]); if (argc>5) d = atoi(argv[5]); if (d>16) Error("d cannot be larger than 16\n"); cout << "\nTesting FHE with parameters L="<<L << ", c="<<c<<", w="<<w<<", k="<<k<<", d="<<d<< endl; // get a lower-bound on the parameter N=phi(m): // 1. Empirically, we use ~20-bit small primes in the modulus chain (the main // constraints is that 2m must divide p-1 for every prime p). The first // prime is larger, a 40-bit prime. (If this is a 32-bit machine then we // use two 20-bit primes instead.) // 2. With L levels, the largest modulus for "fresh ciphertexts" has size // q0 ~ p0 * p^{L} ~ 2^{40+20L} // 3. We break each ciphertext into upto c digits, do each digit is as large // as D=2^{(40+20L)/c} // 4. The added noise variance term from the key-switching operation is // c*N*sigma^2*D^2, and this must be mod-switched down to w*N (so it is // on part with the added noise from modulus-switching). Hence the ratio // P that we use for mod-switching must satisfy c*N*sigma^2*D^2/P^2<w*N, // or P > sqrt(c/w) * sigma * 2^{(40+20L)/c} // 5. With this extra P factor, the key-switching matrices are defined // relative to a modulus of size // Q0 = q0*P ~ sqrt{c/w} sigma 2^{(40+20L)(1+1/c)} // 6. 
To get k-bit security we need N>log(Q0/sigma)(k+110)/7.2, i.e. roughly // N > (40+20L)(1+1/c)(k+110) / 7.2 long ptxtSpace = 2; double cc = 1.0+(1.0/(double)c); long N = (long) ceil((pSize*L+p0Size)*cc*(k+110)/7.2); cout << " bounding phi(m) > " << N << endl; #if 0 // A small m for debugging purposes long m = 15; #else // pre-computed values of [phi(m),m,d] long ms[][4] = { //phi(m) m ord(2) c_m*1000 { 1176, 1247, 28, 3736}, { 1936, 2047, 11, 3870}, { 2880, 3133, 24, 3254}, { 4096, 4369, 16, 3422}, { 5292, 5461, 14, 4160}, { 5760, 8435, 24, 8935}, { 8190, 8191, 13, 1273}, {10584, 16383, 14, 8358}, {10752, 11441, 48, 3607}, {12000, 13981, 20, 2467}, {11520, 15665, 24, 14916}, {14112, 18415, 28, 11278}, {15004, 15709, 22, 3867}, {15360, 20485, 24, 12767}, // {16384, 21845, 16, 12798}, {17208 ,21931, 24, 18387}, {18000, 18631, 25, 4208}, {18816, 24295, 28, 16360}, {19200, 21607, 40, 35633}, {21168, 27305, 28, 15407}, {23040, 23377, 48, 5292}, {24576, 24929, 48, 5612}, {27000, 32767, 15, 20021}, {31104, 31609, 71, 5149}, {42336, 42799, 21, 5952}, {46080, 53261, 24, 33409}, {49140, 57337, 39, 2608}, {51840, 59527, 72, 21128}, {61680, 61681, 40, 1273}, {65536, 65537, 32, 1273}, {75264, 82603, 56, 36484}, {84672, 92837, 56, 38520} }; #if 0 for (long i = 0; i < 25; i++) { long m = ms[i][1]; PAlgebra alg(m); alg.printout(); cout << "\n"; // compute phi(m) directly long phim = 0; for (long j = 0; j < m; j++) if (GCD(j, m) == 1) phim++; if (phim != alg.phiM()) cout << "ERROR\n"; } exit(0); #endif // find the first m satisfying phi(m)>=N and d | ord(2) in Z_m^* long m = 0; for (unsigned i=0; i<sizeof(ms)/sizeof(long[3]); i++) if (ms[i][0]>=N && (ms[i][2] % d) == 0) { m = ms[i][1]; c_m = 0.001 * (double) ms[i][3]; break; } if (m==0) Error("Cannot support this L,d combination"); #endif // m = 257; FHEcontext context(m); #if 0 context.stdev = to_xdouble(0.5); // very low error #endif activeContext = &context; // Mark this as the "current" context context.zMstar.printout(); cout << 
endl; // Set the modulus chain #if 1 // The first 1-2 primes of total p0size bits #if (NTL_SP_NBITS > p0Size) AddPrimesByNumber(context, 1, 1UL<<p0Size); // add a single prime #else AddPrimesByNumber(context, 2, 1UL<<(p0Size/2)); // add two primes #endif #endif // The next L primes, as small as possible AddPrimesByNumber(context, L); ZZ productOfCtxtPrimes = context.productOfPrimes(context.ctxtPrimes); double productSize = context.logOfProduct(context.ctxtPrimes); // might as well test that the answer is roughly correct cout << " context.logOfProduct(...)-log(context.productOfPrimes(...)) = " << productSize-log(productOfCtxtPrimes) << endl; // calculate the size of the digits context.digits.resize(c); IndexSet s1; #if 0 for (long i=0; i<c-1; i++) context.digits[i] = IndexSet(i,i); context.digits[c-1] = context.ctxtPrimes / IndexSet(0,c-2); AddPrimesByNumber(context, 2, 1, true); #else double sizeSoFar = 0.0; double maxDigitSize = 0.0; if (c>1) { // break ciphetext into a few digits double dsize = productSize/c; // initial estimate double target = dsize-(pSize/3.0); long idx = context.ctxtPrimes.first(); for (long i=0; i<c-1; i++) { // compute next digit IndexSet s; while (idx <= context.ctxtPrimes.last() && sizeSoFar < target) { s.insert(idx); sizeSoFar += log((double)context.ithPrime(idx)); idx = context.ctxtPrimes.next(idx); } context.digits[i] = s; s1.insert(s); double thisDigitSize = context.logOfProduct(s); if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize; cout << " digit #"<<i+1<< " " <<s << ": size " << thisDigitSize << endl; target += dsize; } IndexSet s = context.ctxtPrimes / s1; // all the remaining primes context.digits[c-1] = s; double thisDigitSize = context.logOfProduct(s); if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize; cout << " digit #"<<c<< " " <<s << ": size " << thisDigitSize << endl; } else { maxDigitSize = context.logOfProduct(context.ctxtPrimes); context.digits[0] = context.ctxtPrimes; } // Add primes to the 
chain for the P factor of key-switching double sizeOfSpecialPrimes = maxDigitSize + log(c/(double)w)/2 + log(context.stdev *2); AddPrimesBySize(context, sizeOfSpecialPrimes, true); #endif cout << "* ctxtPrimes: " << context.ctxtPrimes << ", log(q0)=" << context.logOfProduct(context.ctxtPrimes) << endl; cout << "* specialPrimes: " << context.specialPrimes << ", log(P)=" << context.logOfProduct(context.specialPrimes) << endl; for (long i=0; i<context.numPrimes(); i++) { cout << " modulus #" << i << " " << context.ithPrime(i) << endl; } cout << endl; setTimersOn(); const ZZX& PhimX = context.zMstar.PhimX(); // The polynomial Phi_m(X) long phim = context.zMstar.phiM(); // The integer phi(m) FHESecKey secretKey(context); const FHEPubKey& publicKey = secretKey; #if 0 // Debug mode: use sk=1,2 DoubleCRT newSk(to_ZZX(2), context); long id1 = secretKey.ImportSecKey(newSk, 64, ptxtSpace); newSk -= 1; long id2 = secretKey.ImportSecKey(newSk, 64, ptxtSpace); #else long id1 = secretKey.GenSecKey(w,ptxtSpace); // A Hamming-weight-w secret key long id2 = secretKey.GenSecKey(w,ptxtSpace); // A second Hamming-weight-w secret key #endif ZZX zero = to_ZZX(0); // Ctxt zeroCtxt(publicKey); /******************************************************************/ /** TESTS BEGIN HERE ***/ /******************************************************************/ cout << "ptxtSpace = " << ptxtSpace << endl; GF2X G; // G is the AES polynomial, G(X)= X^8 +X^4 +X^3 +X +1 SetCoeff(G,8); SetCoeff(G,4); SetCoeff(G,3); SetCoeff(G,1); SetCoeff(G,0); GF2X X; SetX(X); #if 1 // code for rotations... 
{ GF2X::HexOutput = 1; const PAlgebra& al = context.zMstar; const PAlgebraModTwo& al2 = context.modTwo; long ngens = al.numOfGens(); long nslots = al.NSlots(); DoubleCRT tmp(context); vector< vector< DoubleCRT > > maskTable; maskTable.resize(ngens); for (long i = 0; i < ngens; i++) { if (i==0 && al.SameOrd(i)) continue; long ord = al.OrderOf(i); maskTable[i].resize(ord+1, tmp); for (long j = 0; j <= ord; j++) { // initialize the mask that is 1 whenever // the ith coordinate is at least j vector<GF2X> maps, alphas, betas; al2.mapToSlots(maps, G); // Change G to X to get bits in the slots alphas.resize(nslots); for (long k = 0; k < nslots; k++) if (coordinate(al, i, k) >= j) alphas[k] = 1; else alphas[k] = 0; GF2X ptxt; al2.embedInSlots(ptxt, alphas, maps); // Sanity-check, make sure that encode/decode works as expected al2.decodePlaintext(betas, ptxt, G, maps); for (long k = 0; k < nslots; k++) { if (alphas[k] != betas[k]) { cout << " Mask computation failed, i="<<i<<", j="<<j<<"\n"; return 0; } } maskTable[i][j] = to_ZZX(ptxt); } } vector<GF2X> maps; al2.mapToSlots(maps, G); vector<GF2X> alphas(nslots); for (long i=0; i < nslots; i++) random(alphas[i], 8); // random degree-7 polynomial mod 2 for (long amt = 0; amt < 20; amt++) { cout << "."; GF2X ptxt; al2.embedInSlots(ptxt, alphas, maps); DoubleCRT pp(context); pp = to_ZZX(ptxt); rotate(pp, amt, maskTable); GF2X ptxt1 = to_GF2X(to_ZZX(pp)); vector<GF2X> betas; al2.decodePlaintext(betas, ptxt1, G, maps); for (long i = 0; i < nslots; i++) { if (alphas[i] != betas[(i+amt)%nslots]) { cout << " amt="<<amt<<" oops\n"; return 0; } } } cout << "\n"; #if 0 long ord0 = al.OrderOf(0); for (long i = 0; i < nslots; i++) { cout << alphas[i] << " "; if ((i+1) % (nslots/ord0) == 0) cout << "\n"; } cout << "\n\n"; cout << betas.size() << "\n"; for (long i = 0; i < nslots; i++) { cout << betas[i] << " "; if ((i+1) % (nslots/ord0) == 0) cout << "\n"; } #endif return 0; } #endif // an initial sanity check on noise estimates, // 
comparing the estimated variance to the actual average cout << "pk:"; checkCiphertext(publicKey.pubEncrKey, zero, secretKey); ZZX ptxt[6]; // first four are plaintext, last two are constants std::vector<Ctxt> ctxt(4, Ctxt(publicKey)); // Initialize the plaintext and constants to random 0-1 polynomials for (size_t j=0; j<6; j++) { ptxt[j].rep.SetLength(phim); for (long i = 0; i < phim; i++) ptxt[j].rep[i] = RandomBnd(ptxtSpace); ptxt[j].normalize(); if (j<4) { publicKey.Encrypt(ctxt[j], ptxt[j], ptxtSpace); cout << "c"<<j<<":"; checkCiphertext(ctxt[j], ptxt[j], secretKey); } } // perform upto 2L levels of computation, each level computing: // 1. c0 += c1 // 2. c1 *= c2 // L1' = max(L1,L2)+1 // 3. c1.reLinearlize // 4. c2 *= p4 // 5. c2.automorph(k) // k is the first generator of Zm^* /(2) // 6. c2.reLinearlize // 7. c3 += p5 // 8. c3 *= c0 // L3' = max(L3,L0,L1)+1 // 9. c2 *= c3 // L2' = max(L2,L0+1,L1+1,L3+1)+1 // 10. c0 *= c0 // L0' = max(L0,L1)+1 // 11. c0.reLinearlize // 12. c2.reLinearlize // 13. c3.reLinearlize // // The levels of the four ciphertexts behave as follows: // 0, 0, 0, 0 => 1, 1, 2, 1 => 2, 3, 3, 2 // => 4, 4, 5, 4 => 5, 6, 6, 5 // => 7, 7, 8, 7 => 8,,9, 9, 10 => [...] // // We perform the same operations on the plaintext, and after each operation // we check that decryption still works, and print the curretn modulus and // noise estimate. We stop when we get the first decryption error, or when // we reach 2L levels (which really should not happen). 
zz_pContext zzpc; zz_p::init(ptxtSpace); zzpc.save(); const zz_pXModulus F = to_zz_pX(PhimX); long g = context.zMstar.ZmStarGen(0); // the first generator in Zm* zz_pX x2g(g, 1); zz_pX p2; // generate a key-switching matrix from s(X^g) to s(X) secretKey.GenKeySWmatrix(/*powerOfS= */ 1, /*powerOfX= */ g, 0, 0, /*ptxtSpace=*/ ptxtSpace); // generate a key-switching matrix from s^2 to s secretKey.GenKeySWmatrix(/*powerOfS= */ 2, /*powerOfX= */ 1, 0, 0, /*ptxtSpace=*/ ptxtSpace); // generate a key-switching matrix from s^3 to s secretKey.GenKeySWmatrix(/*powerOfS= */ 3, /*powerOfX= */ 1, 0, 0, /*ptxtSpace=*/ ptxtSpace); for (long lvl=0; lvl<2*L; lvl++) { cout << "=======================================================\n"; ctxt[0] += ctxt[1]; ptxt[0] += ptxt[1]; PolyRed(ptxt[0], ptxtSpace, true); cout << "c0+=c1: "; checkCiphertext(ctxt[0], ptxt[0], secretKey); ctxt[1].multiplyBy(ctxt[2]); ptxt[1] = (ptxt[1] * ptxt[2]) % PhimX; PolyRed(ptxt[1], ptxtSpace, true); cout << "c1*=c2: "; checkCiphertext(ctxt[1], ptxt[1], secretKey); ctxt[2].multByConstant(ptxt[4]); ptxt[2] = (ptxt[2] * ptxt[4]) % PhimX; PolyRed(ptxt[2], ptxtSpace, true); cout << "c2*=p4: "; checkCiphertext(ctxt[2], ptxt[2], secretKey); ctxt[2] >>= g; zzpc.restore(); p2 = to_zz_pX(ptxt[2]); CompMod(p2, p2, x2g, F); ptxt[2] = to_ZZX(p2); cout << "c2>>="<<g<<":"; checkCiphertext(ctxt[2], ptxt[2], secretKey); ctxt[2].reLinearize(); cout << "c2.relin:"; checkCiphertext(ctxt[2], ptxt[2], secretKey); ctxt[3].addConstant(ptxt[5]); ptxt[3] += ptxt[5]; PolyRed(ptxt[3], ptxtSpace, true); cout << "c3+=p5: "; checkCiphertext(ctxt[3], ptxt[3], secretKey); ctxt[3].multiplyBy(ctxt[0]); ptxt[3] = (ptxt[3] * ptxt[0]) % PhimX; PolyRed(ptxt[3], ptxtSpace, true); cout << "c3*=c0: "; checkCiphertext(ctxt[3], ptxt[3], secretKey); ctxt[0].square(); ptxt[0] = (ptxt[0] * ptxt[0]) % PhimX; PolyRed(ptxt[0], ptxtSpace, true); cout << "c0*=c0: "; checkCiphertext(ctxt[0], ptxt[0], secretKey); ctxt[2].multiplyBy(ctxt[3]); ptxt[2] = (ptxt[2] 
* ptxt[3]) % PhimX; PolyRed(ptxt[2], ptxtSpace, true); cout << "c2*=c3: "; checkCiphertext(ctxt[2], ptxt[2], secretKey); } /******************************************************************/ /** TESTS END HERE ***/ /******************************************************************/ cout << endl; return 0; }
// bootstrap a ciphertext to reduce noise
// Steps:
//   1. "Dummy" ciphertexts (one part with skHandle.isOne()) are only reduced
//      mod the plaintext space and re-wrapped with DummyEncrypt.
//   2. Otherwise: bring ctxt to canonical form and mod-switch down to at most
//      the two lowest non-special primes.
//   3. Key-switch to the bootstrapping key and raw-mod-switch to q = p^e + 1.
//   4. Make the parts divisible by p^{e'} (makeDivisible) and divide by p^{e'}.
//   5. Multiply by the encrypted key recryptEkey, then apply firstMap,
//      extractDigitsPacked and secondMap.
// Requires recryptKeyID >= 0, and a ciphertext plaintext space that divides
// p^r (both asserted).
// NOTE(review): zzParts[0] and zzParts[1] are accessed unconditionally. This
// presumably relies on rawModSwitch producing exactly two parts after
// reLinearize; confirm.
// NOTE(review): q = p^e + 1 is computed in a long and may overflow for large
// p^e; confirm the parameter ranges.
void FHEPubKey::reCrypt(Ctxt &ctxt)
{
  FHE_TIMER_START;

  // Some sanity checks for dummy ciphertext
  long ptxtSpace = ctxt.getPtxtSpace();
  if (ctxt.isEmpty()) return;
  if (ctxt.parts.size()==1 && ctxt.parts[0].skHandle.isOne()) {
    // Dummy encryption, just ensure that it is reduced mod p
    ZZX poly = to_ZZX(ctxt.parts[0]);
    for (long i=0; i<poly.rep.length(); i++)
      poly[i] = to_ZZ( rem(poly[i],ptxtSpace) );
    poly.normalize();
    ctxt.DummyEncrypt(poly);
    return;
  }

  assert(recryptKeyID>=0); // check that we have bootstrapping data

  long p = getContext().zMStar.getP();
  long r = getContext().alMod.getR();
  long p2r = getContext().alMod.getPPowR();

  // the bootstrapping key is encrypted relative to plaintext space p^{e-e'+r}.
  long e = getContext().rcData.e;
  long ePrime = getContext().rcData.ePrime;
  long p2ePrime = power_long(p,ePrime);
  long q = power_long(p,e)+1;
  assert(e>=r);

#ifdef DEBUG_PRINTOUT
  cerr << "reCrypt: p="<<p<<", r="<<r<<", e="<<e<<" ePrime="<<ePrime
       << ", q="<<q<<endl;
#endif

  // can only bootstrap ciphertext with plaintext-space dividing p^r
  assert(p2r % ptxtSpace == 0);

  FHE_NTIMER_START(preProcess);

  // Make sure that this ciphertext is in canonical form
  if (!ctxt.inCanonicalForm()) ctxt.reLinearize();

  // Mod-switch down if needed
  IndexSet s = ctxt.getPrimeSet() / getContext().specialPrimes; // set minus
  if (s.card()>2) { // leave only bottom two primes
    long frst = s.first();
    long scnd = s.next(frst);
    IndexSet s2(frst,scnd);
    s.retain(s2); // retain only first two primes
  }
  ctxt.modDownToSet(s);

  // key-switch to the bootstrapping key
  ctxt.reLinearize(recryptKeyID);

  // "raw mod-switch" to the bootstrapping modulus q=p^e+1.
  vector<ZZX> zzParts; // the mod-switched parts, in ZZX format
  double noise = ctxt.rawModSwitch(zzParts, q);
  noise = sqrt(noise);

  // Add multiples of p2r and q to make the zzParts divisible by p^{e'}
  long maxU=0;
  for (long i=0; i<(long)zzParts.size(); i++) {
    // make divisible by p^{e'}
    long newMax = makeDivisible(zzParts[i].rep, p2ePrime, p2r, q,
                                getContext().rcData.alpha);
    zzParts[i].normalize(); // normalize after working directly on the rep
    if (maxU < newMax) maxU = newMax;
  }

  // Check that the estimated noise is still low; only a warning is printed
  if (noise + maxU*p2r*(skHwts[recryptKeyID]+1) > q/2)
    cerr << " * noise/q after makeDivisible = "
         << ((noise + maxU*p2r*(skHwts[recryptKeyID]+1))/q) << endl;

  for (long i=0; i<(long)zzParts.size(); i++)
    zzParts[i] /= p2ePrime;   // divide by p^{e'}

  // Multiply the post-processed ciphertext by the encrypted sKey
#ifdef DEBUG_PRINTOUT
  cerr << "+ Before recryption ";
  decryptAndPrint(cerr, recryptEkey, *dbgKey, *dbgEa, printFlag);
#endif

  // The squared L2 norms are passed along with each constant
  // (presumably as a noise-size estimate; confirm in Ctxt).
  double p0size = to_double(coeffsL2Norm(zzParts[0]));
  double p1size = to_double(coeffsL2Norm(zzParts[1]));
  ctxt = recryptEkey;
  ctxt.multByConstant(zzParts[1], p1size*p1size);
  ctxt.addConstant(zzParts[0], p0size*p0size);

#ifdef DEBUG_PRINTOUT
  cerr << "+ Before linearTrans1 ";
  decryptAndPrint(cerr, ctxt, *dbgKey, *dbgEa, printFlag);
#endif
  FHE_NTIMER_STOP(preProcess);

  // Move the powerful-basis coefficients to the plaintext slots
  FHE_NTIMER_START(LinearTransform1);
  ctxt.getContext().rcData.firstMap->apply(ctxt);
  FHE_NTIMER_STOP(LinearTransform1);

#ifdef DEBUG_PRINTOUT
  cerr << "+ After linearTrans1 ";
  decryptAndPrint(cerr, ctxt, *dbgKey, *dbgEa, printFlag);
#endif

  // Extract the digits e-e'+r-1,...,e-e' (from fully packed slots)
  extractDigitsPacked(ctxt, e-ePrime, r, ePrime,
                      context.rcData.unpackSlotEncoding);

#ifdef DEBUG_PRINTOUT
  cerr << "+ Before linearTrans2 ";
  decryptAndPrint(cerr, ctxt, *dbgKey, *dbgEa, printFlag);
#endif

  // Move the slots back to powerful-basis coefficients
  FHE_NTIMER_START(LinearTransform2);
  ctxt.getContext().rcData.secondMap->apply(ctxt);
  FHE_NTIMER_STOP(LinearTransform2);
}
// Shift all slots of ctxt by k positions in the linear slot order.
//   - With a single generator this delegates to shift1D.
//   - If |k| >= nSlots, the ciphertext is multiplied by the constant 0.
//   - If k is a multiple of nSlots, this is a no-op.
//   - Otherwise the shift is decomposed per dimension. The highest dimension
//     is rotated first. For each lower dimension the ciphertext is split by
//     the current mask into two parts, which are rotated (or, for dimension
//     0, shifted) by v and v+1 and recombined; the mask is updated between
//     iterations.
// NOTE(review): the `template<class type>` header for this definition is not
// visible in this chunk; presumably it precedes this line. Confirm.
// NOTE(review): the early returns skip FHE_TIMER_STOP; confirm that the timer
// macros tolerate this.
void EncryptedArrayDerived<type>::shift(Ctxt& ctxt, long k) const
{
  FHE_TIMER_START;
  const PAlgebra& al = context.zMStar;

  const vector< vector< RX > >& maskTable = tab.getMaskTable();

  // Save the current RX modulus, then switch to the table's context.
  // Presumably RBak restores the saved modulus on destruction; confirm.
  RBak bak; bak.save(); tab.restoreContext();

  assert(&context == &ctxt.getContext());

  // Simple case: just one generator
  if (al.numOfGens()==1) {
    shift1D(ctxt, 0, k);
    return;
  }

  long nSlots = al.getNSlots();

  // Shifting by more than the number of slots gives an all-zero ciphertext
  if (k <= -nSlots || k >= nSlots) {
    ctxt.multByConstant(to_ZZX(0));
    return;
  }

  // Make sure that amt is in [1,nslots-1]
  long amt = k % nSlots;
  if (amt == 0) return;
  if (amt < 0) amt += nSlots;

  // rotate the ciphertext, one dimension at a time
  long i = al.numOfGens()-1;
  long v = al.coordinate(i, amt);
  RX mask = maskTable[i][v];
  Ctxt tmp(ctxt.getPubKey());
  const RXModulus& PhimXmod = tab.getPhimXMod();

  rotate1D(ctxt, i, v);
  for (i--; i >= 0; i--) {
    v = al.coordinate(i, amt);

    DoubleCRT m1(conv<ZZX>(mask), context, ctxt.getPrimeSet());
    tmp = ctxt;
    tmp.multByConstant(m1); // only the slots in which mask=1
    ctxt -= tmp;            // only the slots in which mask=0

    if (i>0) {
      rotate1D(ctxt, i, v+1);
      rotate1D(tmp, i, v);
      ctxt += tmp;                    // combine the two parts

      mask = ((mask * (maskTable[i][v] - maskTable[i][v+1])) % PhimXmod)
             + maskTable[i][v+1];  // update the mask before next iteration
    }
    else { // i == 0: the last dimension is shifted (not rotated)
      if (k < 0) v -= al.OrderOf(0);
      shift1D(tmp, 0, v);
      shift1D(ctxt, 0, v+1);
      ctxt += tmp;
    }
  }
  FHE_TIMER_STOP;
}