int main(int argc, char *argv[]) { long seed = time(NULL); srand48(seed); SetSeed(to_ZZ(seed)); unsigned p = 2027; unsigned g = 3; unsigned logQ = 120; FHEcontext context(p-1, logQ, p, g); activeContext = &context; context.SetUpSIContext(); FHESISecKey secretKey(context); const FHESIPubKey &publicKey(secretKey); KeySwitchSI keySwitch(secretKey); long phim = context.zMstar.phiM(); long numSlots = context.GetPlaintextSpace().GetTotalSlots(); long rotAmt = rand() % numSlots; long rotDeg = 1; for (int i = 0; i < rotAmt; i++) { rotDeg *= context.Generator(); rotDeg %= context.zMstar.M(); } KeySwitchSI automorphKeySwitch(secretKey, rotDeg); Plaintext p0, p1, p2, p3; randomizePlaintext(p0, phim, p); randomizePlaintext(p1, phim, p); randomizePlaintext(p2, phim, p); randomizePlaintext(p3, phim, p); Plaintext const1, const2; randomizePlaintext(const1, phim, p); randomizePlaintext(const2, phim, p); Ciphertext c0(publicKey); Ciphertext c1(publicKey); Ciphertext c2(publicKey); Ciphertext c3(publicKey); publicKey.Encrypt(c0, p0); publicKey.Encrypt(c1, p1); publicKey.Encrypt(c2, p2); publicKey.Encrypt(c3, p3); p1 *= p2; p0 += const1; p2 *= const2; p3 >>= rotAmt; p1 *= -1; p3 *= p2; p0 -= p3; c1 *= c2; keySwitch.ApplyKeySwitch(c1); c0 += const1.message; c2 *= const2.message; c3 >>= rotDeg; automorphKeySwitch.ApplyKeySwitch(c3); c1 *= -1; c3 *= c2; keySwitch.ApplyKeySwitch(c3); Ciphertext tmp(c3); tmp *= -1; c0 += tmp; Plaintext pp0, pp1, pp2, pp3; secretKey.Decrypt(pp0, c0); secretKey.Decrypt(pp1, c1); secretKey.Decrypt(pp2, c2); secretKey.Decrypt(pp3, c3); if (!(pp0 == p0)) cerr << "oops 0" << endl; if (!(pp1 == p1)) cerr << "oops 1" << endl; if (!(pp2 == p2)) cerr << "oops 2" << endl; if (!(pp3 == p3)) cerr << "oops 3" << endl; cout << "All tests finished." << endl; }
//Comparison protocol based on "" bool COM(bool disp, long long seed, unsigned p, FHEcontext &context) { ZZ seedZZ; seedZZ = seed; srand48(seed); SetSeed(seedZZ); FHESISecKey secretKey(context); const FHESIPubKey &publicKey(secretKey); long phim = context.zMstar.phiM(); ZZ_pX ptxt1Poly, ptxt2Poly, sum, sumMult, prod, prod2, sumQuad; Plaintext resSum, resSumMult, resProd, resProdSwitch, resProd2, resSumQuad; //gen plaintext ptxt1Poly.rep.SetLength(phim); ptxt2Poly.rep.SetLength(phim); for (long i=0; i < phim; i++) { ptxt1Poly.rep[i] = RandomBnd(p); ptxt2Poly.rep[i] = RandomBnd(p); } //printpoly(ptxt1Poly, phim); ptxt1Poly.normalize(); ptxt2Poly.normalize(); #ifdef DEBUG cout<<"phim:"<<phim<<endl; cout<<"p1:"<<endl; printpoly(ptxt1Poly, phim); cout<<"p2:"<<endl; printpoly(ptxt2Poly, phim); #endif //plaintext operation sum = ptxt1Poly + ptxt2Poly; sumMult = ptxt2Poly * 7; prod = ptxt1Poly * ptxt2Poly; prod2 = prod * prod; sumQuad = prod2 * prod2 * 9; //\sum_((xy)^4) #ifdef DEBUG cout<<"sum:"<<endl; printpoly(sum, phim); cout<<"prod2:"<<endl; printpoly(prod2, phim); #endif rem(prod, prod, to_ZZ_pX(context.zMstar.PhimX())); rem(prod2, prod2, to_ZZ_pX(context.zMstar.PhimX())); rem(sumQuad, sumQuad, to_ZZ_pX(context.zMstar.PhimX())); //encryption start = std::clock(); Plaintext ptxt1(context, ptxt1Poly), ptxt2(context, ptxt2Poly); duration = ( std::clock() - start ) / (double) CLOCKS_PER_SEC; duration = duration/2; cout<<"Encryption:"<< duration <<'\n'; Ciphertext ctxt1(publicKey), ctxt2(publicKey); publicKey.Encrypt(ctxt1, ptxt1); publicKey.Encrypt(ctxt2, ptxt2); Ciphertext cSum = ctxt1; cSum += ctxt2; Ciphertext cSumMult = ctxt2; start = std::clock(); for (int i = 1; i < 7; i++) { cSumMult += ctxt2; } duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC); duration = duration/6; cout<<"Addition:"<< duration <<'\n'; Ciphertext cProd = ctxt1; cProd *= ctxt2; secretKey.Decrypt(resSum, cSum); secretKey.Decrypt(resSumMult, cSumMult); KeySwitchSI keySwitch(secretKey); keySwitch.ApplyKeySwitch(cProd); secretKey.Decrypt(resProd, cProd); cProd *= cProd; Ciphertext tmp = cProd; Ciphertext cSumQuad = cProd; keySwitch.ApplyKeySwitch(cProd); secretKey.Decrypt(resProd2, cProd); for (int i = 0; i < 8; i++) { cSumQuad += tmp; } keySwitch.ApplyKeySwitch(cSumQuad); //apply key switch after summing all prod start = std::clock(); cSumQuad *= cProd; duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC); cout<<"HMult without key switch:"<< duration <<'\n'; keySwitch.ApplyKeySwitch(cSumQuad); duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC); cout<<"HMult with key switch:"<< duration <<'\n'; start = std::clock(); secretKey.Decrypt(resSumQuad, cSumQuad); duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC); cout<<"Decryption:"<< duration <<'\n'; //comparison bool success = ((resSum.message == sum) && (resSumMult.message == sumMult) && (resProd.message == prod) && (resProd2.message == prod2) && (resSumQuad == sumQuad)); if (disp || !success) { cout << "Seed: " << seed << endl << endl; if (resSum.message != sum) { cout << "Add failed." << endl; } if (resSumMult.message != sumMult) { cout << "Adding multiple times failed." << endl; } if (resProd.message != prod) { cout << "Multiply failed." << endl; } if (resProd2.message != prod2) { cout << "Squaring failed." << endl; } if (resSumQuad.message != sumQuad) { cout << "Sum and quad failed." << endl; } } if (disp || !success) { cout << "Test " << (success ? "SUCCEEDED" : "FAILED") << endl; } return success; }