// Example #1
// 0
// Testing the I/O of the important classes of the library
// (context, keys, ciphertexts).
int main(int argc, char *argv[])
{
  ArgMapping amap;

  long r=1;
  long p=2;
  long c = 2;
  long w = 64;
  long L = 5;
  long mm=0;
  amap.arg("p", p, "plaintext base");
  amap.arg("r", r,  "lifting");
  amap.arg("c", c, "number of columns in the key-switching matrices");
  amap.arg("m", mm, "cyclotomic index","{31,127,1023}");
  amap.parse(argc, argv);

  bool useTable = (mm==0 && p==2);
  long ptxtSpace = power_long(p,r);
  long numTests = useTable? N_TESTS : 1;

  std::unique_ptr<FHEcontext> contexts[numTests];
  std::unique_ptr<FHESecKey> sKeys[numTests];
  std::unique_ptr<Ctxt> ctxts[numTests];
  std::unique_ptr<EncryptedArray> eas[numTests];
  vector<ZZX> ptxts[numTests];

  // first loop: generate stuff and write it to cout

  // open file for writing
  {fstream keyFile("iotest.txt", fstream::out|fstream::trunc);
   assert(keyFile.is_open());
  for (long i=0; i<numTests; i++) {
    long m = (mm==0)? ms[i][1] : mm;

    cout << "Testing IO: m="<<m<<", p^r="<<p<<"^"<<r<<endl;

    Vec<long> mvec(INIT_SIZE,2);
    mvec[0] = ms[i][4];  mvec[1] = ms[i][5];
    vector<long> gens(2);
    gens[0] = ms[i][6];  gens[1] = ms[i][7];
    vector<long> ords(2);
    ords[0] = ms[i][8];  ords[1] = ms[i][9];

    if (useTable && gens[0]>0)
      contexts[i].reset(new FHEcontext(m, p, r, gens, ords));
    else
      contexts[i].reset(new FHEcontext(m, p, r));
    contexts[i]->zMStar.printout();

    buildModChain(*contexts[i], L, c);  // Set the modulus chain
    if (mm==0 && m==1023) contexts[i]->makeBootstrappable(mvec);

    // Output the FHEcontext to file
    writeContextBase(keyFile, *contexts[i]);
    writeContextBase(cout, *contexts[i]);
    keyFile << *contexts[i] << endl;

    sKeys[i].reset(new FHESecKey(*contexts[i]));
    const FHEPubKey& publicKey = *sKeys[i];
    sKeys[i]->GenSecKey(w,ptxtSpace); // A Hamming-weight-w secret key
    addSome1DMatrices(*sKeys[i]);// compute key-switching matrices that we need
    eas[i].reset(new EncryptedArray(*contexts[i]));

    long nslots = eas[i]->size();

    // Output the secret key to file, twice. Below we will have two copies
    // of most things.
    keyFile << *sKeys[i] << endl;;
    keyFile << *sKeys[i] << endl;;

    vector<ZZX> b;
    long p2r = eas[i]->getContext().alMod.getPPowR();
    ZZX poly = RandPoly(0,to_ZZ(p2r)); // choose a random constant polynomial
    eas[i]->decode(ptxts[i], poly);

    ctxts[i].reset(new Ctxt(publicKey));
    eas[i]->encrypt(*ctxts[i], publicKey, ptxts[i]);
    eas[i]->decrypt(*ctxts[i], *sKeys[i], b);
    assert(ptxts[i].size() == b.size());
    for (long j = 0; j < nslots; j++) assert (ptxts[i][j] == b[j]);

    // output the plaintext
    keyFile << "[ ";
    for (long j = 0; j < nslots; j++) keyFile << ptxts[i][j] << " ";
    keyFile << "]\n";

    eas[i]->encode(poly,ptxts[i]);
    keyFile << poly << endl;

    // Output the ciphertext to file
    keyFile << *ctxts[i] << endl;
    keyFile << *ctxts[i] << endl;
    cerr << "okay " << i << endl<< endl;
  }
  keyFile.close();}
  cerr << "so far, so good\n\n";

  // second loop: read from input and repeat the computation

  // open file for read
  {fstream keyFile("iotest.txt", fstream::in);
  for (long i=0; i<numTests; i++) {

    // Read context from file
    unsigned long m1, p1, r1;
    vector<long> gens, ords;
    readContextBase(keyFile, m1, p1, r1, gens, ords);
    FHEcontext tmpContext(m1, p1, r1, gens, ords);
    keyFile >> tmpContext;
    assert (*contexts[i] == tmpContext);
    cerr << i << ": context matches input\n";

    // We define some things below wrt *contexts[i], not tmpContext.
    // This is because the various operator== methods check equality of
    // references, not equality of the referenced FHEcontext objects.
    FHEcontext& context = *contexts[i];
    FHESecKey secretKey(context);
    FHESecKey secretKey2(tmpContext);
    const FHEPubKey& publicKey = secretKey;
    const FHEPubKey& publicKey2 = secretKey2;

    keyFile >> secretKey;
    keyFile >> secretKey2;
    assert(secretKey == *sKeys[i]);
    cerr << "   secret key matches input\n";

    EncryptedArray ea(context);
    EncryptedArray ea2(tmpContext);

    long nslots = ea.size();

    // Read the plaintext from file
    vector<ZZX> a;
    a.resize(nslots);
    assert(nslots == (long)ptxts[i].size());
    seekPastChar(keyFile, '['); // defined in NumbTh.cpp
    for (long j = 0; j < nslots; j++) {
      keyFile >> a[j];
      assert(a[j] == ptxts[i][j]);
    }
    seekPastChar(keyFile, ']');
    cerr << "   ptxt matches input\n";

    // Read the encoded plaintext from file
    ZZX poly1, poly2;
    keyFile >> poly1;
    eas[i]->encode(poly2,a);
    assert(poly1 == poly2);
    cerr << "   eas[i].encode(a)==poly1 okay\n";

    ea.encode(poly2,a);
    assert(poly1 == poly2);
    cerr << "   ea.encode(a)==poly1 okay\n";

    ea2.encode(poly2,a);
    assert(poly1 == poly2);
    cerr << "   ea2.encode(a)==poly1 okay\n";

    eas[i]->decode(a,poly1);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   eas[i].decode(poly1)==ptxts[i] okay\n";

    ea.decode(a,poly1);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   ea.decode(poly1)==ptxts[i] okay\n";

    ea2.decode(a,poly1);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   ea2.decode(poly1)==ptxts[i] okay\n";

    // Read ciperhtext from file
    Ctxt ctxt(publicKey);
    Ctxt ctxt2(publicKey2);
    keyFile >> ctxt;
    keyFile >> ctxt2;
    assert(ctxts[i]->equalsTo(ctxt,/*comparePkeys=*/false));
    cerr << "   ctxt matches input\n";

    sKeys[i]->Decrypt(poly2,*ctxts[i]);
    assert(poly1 == poly2);
    cerr << "   sKeys[i]->decrypt(*ctxts[i]) == poly1 okay\n";

    secretKey.Decrypt(poly2,*ctxts[i]);
    assert(poly1 == poly2);
    cerr << "   secretKey.decrypt(*ctxts[i]) == poly1 okay\n";

    secretKey.Decrypt(poly2,ctxt);
    assert(poly1 == poly2);
    cerr << "   secretKey.decrypt(ctxt) == poly1 okay\n";

    secretKey2.Decrypt(poly2,ctxt2);
    assert(poly1 == poly2);
    cerr << "   secretKey2.decrypt(ctxt2) == poly1 okay\n";

    eas[i]->decrypt(ctxt, *sKeys[i], a);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   eas[i].decrypt(ctxt, *sKeys[i])==ptxts[i] okay\n";

    ea.decrypt(ctxt, secretKey, a);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   ea.decrypt(ctxt, secretKey)==ptxts[i] okay\n";

    ea2.decrypt(ctxt2, secretKey2, a);
    assert(nslots == (long)a.size());
    for (long j = 0; j < nslots; j++) assert(a[j] == ptxts[i][j]);
    cerr << "   ea2.decrypt(ctxt2, secretKey2)==ptxts[i] okay\n";

    cerr << "test "<<i<<" okay\n\n";
  }}
  unlink("iotest.txt"); // clean up before exiting
}
// Comparison test: runs a fixed set of homomorphic operations on two random
// plaintext polynomials and checks each decrypted result against the same
// operation carried out directly on the plaintexts.
//
// Operations checked: sum, 7-fold repeated addition, product, squared
// product, and 9 * (product^4).
//
// Parameters:
//   disp    - when true, always print the seed and the verdict; when false
//             they are printed only if the test fails.
//   seed    - seeds both srand48() and SetSeed(), so a run can be reproduced.
//   p       - bound passed to RandomBnd() for every plaintext coefficient.
//   context - FHE context used for keys, plaintexts and the reduction
//             polynomial (context.zMstar.PhimX()).
//
// Returns true iff all five decrypted results equal the expected polynomials.
//
// Timing lines ("Encryption:", "Addition:", ...) are printed to cout
// unconditionally. `start` and `duration` are not declared in this function;
// presumably they are file-scope variables -- TODO confirm.
bool COM(bool disp, long long seed, unsigned p, FHEcontext &context) {
  ZZ seedZZ;
  seedZZ = seed;

  srand48(seed);
  SetSeed(seedZZ);

  FHESISecKey secretKey(context);
  const FHESIPubKey &publicKey(secretKey);

  long phim = context.zMstar.phiM();

  ZZ_pX ptxt1Poly, ptxt2Poly, sum, sumMult, prod, prod2, sumQuad;
  Plaintext resSum, resSumMult, resProd, resProdSwitch, resProd2, resSumQuad;

  // Generate two random plaintext polynomials with phim coefficients each
  ptxt1Poly.rep.SetLength(phim);
  ptxt2Poly.rep.SetLength(phim);
  for (long i=0; i < phim; i++) {
    ptxt1Poly.rep[i] = RandomBnd(p);
    ptxt2Poly.rep[i] = RandomBnd(p);
  }

  //printpoly(ptxt1Poly, phim);

  ptxt1Poly.normalize();
  ptxt2Poly.normalize();

  #ifdef DEBUG
  cout<<"phim:"<<phim<<endl;
  cout<<"p1:"<<endl;
  printpoly(ptxt1Poly, phim);
  cout<<"p2:"<<endl;
  printpoly(ptxt2Poly, phim);
  #endif

  // Expected results, computed directly on the plaintext polynomials
  sum = ptxt1Poly + ptxt2Poly;
  sumMult = ptxt2Poly * 7;
  prod = ptxt1Poly * ptxt2Poly;
  prod2 = prod * prod;
  sumQuad = prod2 * prod2 * 9; //\sum_((xy)^4)

  #ifdef DEBUG
  cout<<"sum:"<<endl;
  printpoly(sum, phim);
  cout<<"prod2:"<<endl;
  printpoly(prod2, phim);
  #endif

  // Reduce the multiplicative results modulo PhimX
  rem(prod, prod, to_ZZ_pX(context.zMstar.PhimX()));
  rem(prod2, prod2, to_ZZ_pX(context.zMstar.PhimX()));
  rem(sumQuad, sumQuad, to_ZZ_pX(context.zMstar.PhimX()));

  // NOTE(review): the interval reported as "Encryption:" brackets the
  // construction of the two Plaintext objects (averaged over the two), not
  // the publicKey.Encrypt() calls further down -- confirm this is intended.
  start = std::clock();
  Plaintext ptxt1(context, ptxt1Poly), ptxt2(context, ptxt2Poly);
  duration = ( std::clock() - start ) / (double) CLOCKS_PER_SEC;
  duration = duration/2;
  cout<<"Encryption:"<< duration <<'\n';

  Ciphertext ctxt1(publicKey), ctxt2(publicKey);
  publicKey.Encrypt(ctxt1, ptxt1);
  publicKey.Encrypt(ctxt2, ptxt2);

  Ciphertext cSum = ctxt1;
  cSum += ctxt2;

  // ctxt2 added to itself six times => 7 * ctxt2; time is averaged per add
  Ciphertext cSumMult = ctxt2;
  start = std::clock();
  for (int i = 1; i < 7; i++) {
    cSumMult += ctxt2;
  }
  duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC);
  duration = duration/6;
  cout<<"Addition:"<< duration <<'\n';

  Ciphertext cProd = ctxt1;
  cProd *= ctxt2;

  secretKey.Decrypt(resSum, cSum);
  secretKey.Decrypt(resSumMult, cSumMult);

  KeySwitchSI keySwitch(secretKey);
  keySwitch.ApplyKeySwitch(cProd);
  secretKey.Decrypt(resProd, cProd);

  // Square the product; keep two copies taken before the key switch so the
  // 9-fold sum below can be key-switched once, after the additions.
  cProd *= cProd;
  Ciphertext tmp = cProd;
  Ciphertext cSumQuad = cProd;

  keySwitch.ApplyKeySwitch(cProd);
  secretKey.Decrypt(resProd2, cProd);

  // cSumQuad starts as one copy and receives eight more => 9 * prod^2
  for (int i = 0; i < 8; i++) {
    cSumQuad += tmp;
  }
  keySwitch.ApplyKeySwitch(cSumQuad);  //apply key switch after summing all prod

  start = std::clock();
  cSumQuad *= cProd;
  duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC);
  cout<<"HMult without key switch:"<< duration <<'\n';

  keySwitch.ApplyKeySwitch(cSumQuad);

  // Measured from the same start as above, so this interval also covers the
  // multiplication and the preceding print.
  duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC);
  cout<<"HMult with key switch:"<< duration <<'\n';

  start = std::clock();
  secretKey.Decrypt(resSumQuad, cSumQuad);
  duration = ( std::clock() - start ) / (double) (CLOCKS_PER_SEC);
  cout<<"Decryption:"<< duration <<'\n';

  // Compare every decrypted message with its expected polynomial. All five
  // comparisons go through .message so that the verdict uses exactly the
  // same predicates as the per-operation failure report below.
  bool success = ((resSum.message == sum) && (resSumMult.message == sumMult) &&
                  (resProd.message == prod) && (resProd2.message == prod2) &&
                  (resSumQuad.message == sumQuad));

  if (disp || !success) {
    cout << "Seed: " << seed << endl << endl;

    if (resSum.message != sum) {
      cout << "Add failed." << endl;
    }
    if (resSumMult.message != sumMult) {
      cout << "Adding multiple times failed." << endl;
    }
    if (resProd.message != prod) {
      cout << "Multiply failed." << endl;
    }
    if (resProd2.message != prod2) {
      cout << "Squaring failed." << endl;
    }
    if (resSumQuad.message != sumQuad) {
      cout << "Sum and quad failed." << endl;
    }

    cout << "Test " << (success ? "SUCCEEDED" : "FAILED") << endl;
  }

  return success;
}