// Recursive helper for evaluating a polynomial with encrypted coefficients.
// Evaluates the polynomial whose coefficients are poly[0..nCoeffs-1] and
// stores the result in ret. The caller supplies powers[] with
// powers[i] = x^{2^i}; the polynomial is split as lo(X) + hi(X)*X^{2^e},
// where 2^e is the split point derived from NextPowerOfTwo(nCoeffs).
static void recursivePolyEval(Ctxt& ret, const Ctxt poly[], long nCoeffs,
                              const Vec<Ctxt>& powers)
{
  // Edge conditions: nothing to multiply
  if (nCoeffs == 0) { // empty polynomial
    ret.clear();
    return;
  }
  if (nCoeffs < 2) {  // constant polynomial
    ret = poly[0];
    return;
  }

  const long splitLog = NextPowerOfTwo(nCoeffs) - 1;
  const long split = 1L << splitLog;

  // High part first (into a scratch ciphertext), then low part into ret
  Ctxt high(ZeroCtxtLike, ret);
  recursivePolyEval(high, poly + split, nCoeffs - split, powers);
  recursivePolyEval(ret, poly, split, powers);

  // ret = lo(x) + hi(x) * x^{split}
  high.multiplyBy(powers[splitLog]);
  ret += high;
}
// The input is a plaintext table T[] and an array of encrypted bits // I[], holding the binary representation of an index i into T. // The output is the encrypted value T[i]. void tableLookup(Ctxt& out, const vector<zzX>& table, const CtPtrs& idx, std::vector<zzX>* unpackSlotEncoding) { FHE_TIMER_START; out.clear(); vector<Ctxt> products(lsize(table), out); // to hold subset products of idx CtPtrs_vectorCt pWrap(products); // A wrapper // Compute all products of ecnrypted bits =: b_i computeAllProducts(pWrap, idx, unpackSlotEncoding); // Compute the sum b_i * T[i] NTL_EXEC_RANGE(lsize(table), first, last) for(long i=first; i<last; i++) products[i].multByConstant(table[i]); // p[i] = p[i]*T[i] NTL_EXEC_RANGE_END for(long i=0; i<lsize(table); i++) out += products[i]; }
// Main entry point: Evaluate an encrypted polynomial on an encrypted input // return in ret = sum_i poly[i] * x^i void polyEval(Ctxt& ret, const Vec<Ctxt>& poly, const Ctxt& x) { if (poly.length()<=1) { // Some special cases if (poly.length()==0) ret.clear(); // empty polynomial else ret = poly[0]; // constant polynomial return; } long deg = poly.length()-1; long logD = NextPowerOfTwo(divc(poly.length(),3)); long d = 1L << logD; // We have d <= deg(poly) < 3d assert(d <= deg && deg < 3*d); Vec<Ctxt> powers(INIT_SIZE, logD+1, x); if (logD>0) { powers[1].square(); for (long i=2; i<=logD; i++) { // powers[i] = x^{2^i} powers[i] = powers[i-1]; powers[i].square(); } } // Compute in three parts p0(X) + ( p1(X) + p2(X)*X^d )*X^d Ctxt tmp(ZeroCtxtLike, ret); recursivePolyEval(ret, &poly[d], min(d,poly.length()-d), powers); // p1(X) if (poly.length() > 2*d) { // p2 is not empty recursivePolyEval(tmp, &poly[2*d], poly.length()-2*d, powers); // p2(X) tmp.multiplyBy(powers[logD]); ret += tmp; } ret.multiplyBy(powers[logD]); // ( p1(X) + p2(X)*X^d )*X^d recursivePolyEval(tmp, &poly[0], d, powers); // p0(X) ret += tmp; }
// Simple evaluation sum f_i * X^i, assuming that babyStep has enough powers static void simplePolyEval(Ctxt& ret, const ZZX& poly, DynamicCtxtPowers& babyStep) { ret.clear(); if (deg(poly)<0) return; // the zero polynomial always returns zero assert (deg(poly)<=babyStep.size()); // ensure that we have enough powers ZZ coef; ZZ p = to_ZZ(babyStep[0].getPtxtSpace()); for (long i=1; i<=deg(poly); i++) { rem(coef, coeff(poly,i),p); if (coef > p/2) coef -= p; Ctxt tmp = babyStep.getPower(i); // X^i tmp.multByConstant(coef); // f_i X^i ret += tmp; } // Add the free term rem(coef, ConstTerm(poly), p); if (coef > p/2) coef -= p; ret.addConstant(coef); // if (verbose) checkPolyEval(ret, babyStep[0], poly); }
// Main entry point: Evaluate a cleartext polynomial on an encrypted input void polyEval(Ctxt& ret, ZZX poly, const Ctxt& x, long k) // Note: poly is passed by value, so caller keeps the original { if (deg(poly)<=2) { // nothing to optimize here if (deg(poly)<1) { // A constant ret.clear(); ret.addConstant(coeff(poly, 0)); } else { // A linear or quadratic polynomial DynamicCtxtPowers babyStep(x, deg(poly)); simplePolyEval(ret, poly, babyStep); } return; } // How many baby steps: set k~sqrt(n/2), rounded up/down to a power of two // FIXME: There may be some room for optimization here: it may be possible // to choose k as something other than a power of two and still maintain // optimal depth, in principle we can try all possible values of k between // two consecutive powers of two and choose the one that gives the least // number of multiplies, conditioned on minimum depth. if (k<=0) { long kk = (long) sqrt(deg(poly)/2.0); k = 1L << NextPowerOfTwo(kk); // heuristic: if k>>kk then use a smaler power of two if ((k==16 && deg(poly)>167) || (k>16 && k>(1.44*kk))) k /= 2; } #ifdef DEBUG_PRINTOUT cerr << " k="<<k; #endif long n = divc(deg(poly),k); // n = ceil(deg(p)/k), deg(p) >= k*n DynamicCtxtPowers babyStep(x, k); const Ctxt& x2k = babyStep.getPower(k); // Special case when deg(p)>k*(2^e -1) if (n==(1L << NextPowerOfTwo(n))) { // n is a power of two DynamicCtxtPowers giantStep(x2k, n/2); degPowerOfTwo(ret, poly, k, babyStep, giantStep); return; } // If n is not a power of two, ensure that poly is monic and that // its degree is divisible by k, then call the recursive procedure const ZZ p = to_ZZ(x.getPtxtSpace()); ZZ top = LeadCoeff(poly); ZZ topInv; // the inverse mod p of the top coefficient of poly (if any) bool divisible = (n*k == deg(poly)); // is the degree divisible by k? 
long nonInvertibe = InvModStatus(topInv, top, p); // 0 if invertible, 1 if not // FIXME: There may be some room for optimization below: instead of // adding a term X^{n*k} we can add X^{n'*k} for some n'>n, so long // as n' is smaller than the next power of two. We could save a few // multiplications since giantStep[n'] may be easier to compute than // giantStep[n] when n' has fewer 1's than n in its binary expansion. ZZ extra = ZZ::zero(); // extra!=0 denotes an added term extra*X^{n*k} if (!divisible || nonInvertibe) { // need to add a term top = to_ZZ(1); // new top coefficient is one topInv = top; // also the new inverse is one // set extra = 1 - current-coeff-of-X^{n*k} extra = SubMod(top, coeff(poly,n*k), p); SetCoeff(poly, n*k); // set the top coefficient of X^{n*k} to one } long t = IsZero(extra)? divc(n,2) : n; DynamicCtxtPowers giantStep(x2k, t); if (!IsOne(top)) { poly *= topInv; // Multiply by topInv to make into a monic polynomial for (long i=0; i<=n*k; i++) rem(poly[i], poly[i], p); poly.normalize(); } recursivePolyEval(ret, poly, k, babyStep, giantStep); if (!IsOne(top)) { ret.multByConstant(top); } if (!IsZero(extra)) { // if we added a term, now is the time to subtract back Ctxt topTerm = giantStep.getPower(n); topTerm.multByConstant(extra); ret -= topTerm; } }