Example #1
0
// Randomized end-to-end self-test of the homomorphic operations set up below.
//
// Four random plaintexts p0..p3 are encrypted into c0..c3. The same sequence
// of operations is then applied in the clear and homomorphically, so that (in
// terms of the ORIGINAL random values) the four results are:
//   0: p0 + const1 - rot(p3) * (p2 * const2)
//   1: -(p1 * p2)
//   2: p2 * const2
//   3: rot(p3) * (p2 * const2)
// rot() is `>>= rotAmt` on the plaintext side and `>>= rotDeg` followed by a
// key switch on the ciphertext side.
//
// Each ciphertext is decrypted and compared with its plaintext twin. Every
// mismatch prints "oops <i>" to stderr. If any check fails, the seed is
// printed and the program exits with status 1; otherwise it exits with 0.
//
// argc/argv are unused.
int main(int argc, char *argv[]) {
	long seed = time(NULL);

	// srand48() seeds the lrand48() stream used for rotAmt below;
	// SetSeed() seeds NTL's generator.
	srand48(seed);
	SetSeed(to_ZZ(seed));

	// NOTE(review): presumably m = p-1, log2 of the ciphertext modulus,
	// the plaintext modulus and a generator -- confirm against FHEcontext.
	unsigned p = 2027;
	unsigned g = 3;
	unsigned logQ = 120;

	FHEcontext context(p-1, logQ, p, g);
	activeContext = &context;
	context.SetUpSIContext();

	FHESISecKey secretKey(context);
	const FHESIPubKey &publicKey(secretKey);
	KeySwitchSI keySwitch(secretKey);

	long phim = context.zMstar.phiM();
	long numSlots = context.GetPlaintextSpace().GetTotalSlots();

	// lrand48(), not rand(): srand48() above does not seed rand(), so rand()
	// would yield the same rotation on every run regardless of `seed`.
	long rotAmt = lrand48() % numSlots;

	// rotDeg = Generator()^rotAmt mod M. The ciphertext shift by rotDeg is
	// checked against a plaintext rotation by rotAmt slots.
	long rotDeg = 1;
	for (long i = 0; i < rotAmt; i++) {
		rotDeg *= context.Generator();
		rotDeg %= context.zMstar.M();
	}
	KeySwitchSI automorphKeySwitch(secretKey, rotDeg);

	Plaintext p0, p1, p2, p3;
	randomizePlaintext(p0, phim, p);
	randomizePlaintext(p1, phim, p);
	randomizePlaintext(p2, phim, p);
	randomizePlaintext(p3, phim, p);

	Plaintext const1, const2;
	randomizePlaintext(const1, phim, p);
	randomizePlaintext(const2, phim, p);

	Ciphertext c0(publicKey);
	Ciphertext c1(publicKey);
	Ciphertext c2(publicKey);
	Ciphertext c3(publicKey);

	publicKey.Encrypt(c0, p0);
	publicKey.Encrypt(c1, p1);
	publicKey.Encrypt(c2, p2);
	publicKey.Encrypt(c3, p3);

	// Reference computation in the clear.
	p1 *= p2;
	p0 += const1;
	p2 *= const2;
	p3 >>= rotAmt;
	p1 *= -1;
	p3 *= p2;
	p0 -= p3;

	// The same computation on ciphertexts. A key switch follows every
	// ciphertext-by-ciphertext product and the rotation.
	c1 *= c2;
	keySwitch.ApplyKeySwitch(c1);
	c0 += const1.message;

	c2 *= const2.message;

	c3 >>= rotDeg;
	automorphKeySwitch.ApplyKeySwitch(c3);

	c1 *= -1;
	c3 *= c2;
	keySwitch.ApplyKeySwitch(c3);

	// c0 -= c3, expressed as c0 += (-1) * c3.
	Ciphertext tmp(c3);
	tmp *= -1;
	c0 += tmp;

	Plaintext pp0, pp1, pp2, pp3;
	secretKey.Decrypt(pp0, c0);
	secretKey.Decrypt(pp1, c1);
	secretKey.Decrypt(pp2, c2);
	secretKey.Decrypt(pp3, c3);

	int failures = 0;
	if (!(pp0 == p0)) { cerr << "oops 0" << endl; ++failures; }
	if (!(pp1 == p1)) { cerr << "oops 1" << endl; ++failures; }
	if (!(pp2 == p2)) { cerr << "oops 2" << endl; ++failures; }
	if (!(pp3 == p3)) { cerr << "oops 3" << endl; ++failures; }

	// Print the seed on failure so the run can be reproduced.
	if (failures != 0) {
		cerr << "Seed: " << seed << endl;
	}
	cout << "All tests finished." << endl;
	return failures == 0 ? 0 : 1;
}
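// Comparison protocol based on "" (NOTE(review): the cited source is missing from the quotes).
// The function below compares decrypted homomorphic results against the same operations done in the clear.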
// Randomized correctness check, with rough CPU timings, of homomorphic
// addition and multiplication under `context`.
//
// Two random polynomials x, y (degree < phim, coefficients from RandomBnd(p))
// are encrypted. The following are evaluated both in the clear and
// homomorphically, then compared after decryption:
//   x + y, 7*y (ctxt2 plus six ciphertext additions), x*y, (x*y)^2,
//   9*(x*y)^4.
// The clear-text products are reduced modulo context.zMstar.PhimX() before
// the comparison.
//
// disp    - print the seed and the result line even when every check passes.
// seed    - seeds both srand48() and SetSeed(). A failing run prints it, so
//           it can be reproduced by passing the same value back in.
// p       - exclusive bound for the random coefficients. Presumably the
//           plaintext modulus of `context` -- TODO confirm.
// context - context used for key generation, encoding and PhimX().
//
// Timings ("Encryption:", "Addition:", "HMult ...", "Decryption:") are
// measured with std::clock() and printed to stdout. The outer-scope
// `start` / `duration` variables are used as scratch.
//
// Returns true iff every decrypted result equals its clear-text counterpart.
bool COM(bool disp, long long seed, unsigned p, FHEcontext &context) {
  ZZ seedZZ;
  seedZZ = seed;

  srand48(seed);
  SetSeed(seedZZ);

  FHESISecKey secretKey(context);
  const FHESIPubKey &publicKey(secretKey);

  long phim = context.zMstar.phiM();

  ZZ_pX ptxt1Poly, ptxt2Poly, sum, sumMult, prod, prod2, sumQuad;
  Plaintext resSum, resSumMult, resProd, resProd2, resSumQuad;

  // Random plaintext polynomials with phim coefficients each.
  ptxt1Poly.rep.SetLength(phim);
  ptxt2Poly.rep.SetLength(phim);
  for (long i = 0; i < phim; i++) {
    ptxt1Poly.rep[i] = RandomBnd(p);
    ptxt2Poly.rep[i] = RandomBnd(p);
  }

  // NTL normalize(): drop leading zero coefficients left by SetLength.
  ptxt1Poly.normalize();
  ptxt2Poly.normalize();

  #ifdef DEBUG
  cout<<"phim:"<<phim<<endl;
  cout<<"p1:"<<endl;
  printpoly(ptxt1Poly, phim);
  cout<<"p2:"<<endl;
  printpoly(ptxt2Poly, phim);
  #endif

  // Expected results, computed in the clear.
  sum = ptxt1Poly + ptxt2Poly;
  sumMult = ptxt2Poly * 7;
  prod = ptxt1Poly * ptxt2Poly;
  prod2 = prod * prod;
  sumQuad = prod2 * prod2 * 9;  // 9 * (x*y)^4, matching the 9-term sum below

  #ifdef DEBUG
  cout<<"sum:"<<endl;
  printpoly(sum, phim);
  cout<<"prod2:"<<endl;
  printpoly(prod2, phim);
  #endif

  // Products can reach degree >= phim. Reduce them modulo PhimX, presumably
  // the form Decrypt() yields.
  rem(prod, prod, to_ZZ_pX(context.zMstar.PhimX()));
  rem(prod2, prod2, to_ZZ_pX(context.zMstar.PhimX()));
  rem(sumQuad, sumQuad, to_ZZ_pX(context.zMstar.PhimX()));

  // Encoding + encryption of both inputs; report the per-input average.
  // (Previously only the Plaintext encoding was inside the timed region.)
  // The construction/encryption order is unchanged, so a given seed yields
  // the same values.
  start = std::clock();
  Plaintext ptxt1(context, ptxt1Poly), ptxt2(context, ptxt2Poly);
  Ciphertext ctxt1(publicKey), ctxt2(publicKey);
  publicKey.Encrypt(ctxt1, ptxt1);
  publicKey.Encrypt(ctxt2, ptxt2);
  duration = (std::clock() - start) / static_cast<double>(CLOCKS_PER_SEC) / 2;
  cout << "Encryption:" << duration << '\n';

  Ciphertext cSum = ctxt1;
  cSum += ctxt2;

  // 7*y as ctxt2 plus six more additions; report the per-addition average.
  Ciphertext cSumMult = ctxt2;
  start = std::clock();
  for (int i = 1; i < 7; i++) {
    cSumMult += ctxt2;
  }
  duration = (std::clock() - start) / static_cast<double>(CLOCKS_PER_SEC) / 6;
  cout << "Addition:" << duration << '\n';

  Ciphertext cProd = ctxt1;
  cProd *= ctxt2;

  secretKey.Decrypt(resSum, cSum);
  secretKey.Decrypt(resSumMult, cSumMult);

  KeySwitchSI keySwitch(secretKey);
  keySwitch.ApplyKeySwitch(cProd);
  secretKey.Decrypt(resProd, cProd);

  // Square x*y. `tmp` and cSumQuad keep copies taken BEFORE the key switch;
  // cProd itself is key-switched and decrypted as (x*y)^2.
  cProd *= cProd;
  Ciphertext tmp = cProd;
  Ciphertext cSumQuad = cProd;

  keySwitch.ApplyKeySwitch(cProd);
  secretKey.Decrypt(resProd2, cProd);

  // cSumQuad = 9 * (x*y)^2 (one copy plus eight additions). It is key-switched
  // once after summing all terms rather than once per term.
  for (int i = 0; i < 8; i++) {
    cSumQuad += tmp;
  }
  keySwitch.ApplyKeySwitch(cSumQuad);

  // cSumQuad = 9 * (x*y)^4. Both numbers are measured before anything is
  // printed, so console output does not skew the "with key switch" figure.
  start = std::clock();
  cSumQuad *= cProd;
  double multSeconds = (std::clock() - start) / static_cast<double>(CLOCKS_PER_SEC);
  keySwitch.ApplyKeySwitch(cSumQuad);
  duration = (std::clock() - start) / static_cast<double>(CLOCKS_PER_SEC);
  cout << "HMult without key switch:" << multSeconds << '\n';
  cout << "HMult with key switch:" << duration << '\n';

  start = std::clock();
  secretKey.Decrypt(resSumQuad, cSumQuad);
  duration = (std::clock() - start) / static_cast<double>(CLOCKS_PER_SEC);
  cout << "Decryption:" << duration << '\n';

  // Every check compares the decrypted message polynomial with the clear-text
  // result. The sum-of-quads check previously compared the Plaintext object
  // itself.
  bool sumOk     = (resSum.message == sum);
  bool sumMultOk = (resSumMult.message == sumMult);
  bool prodOk    = (resProd.message == prod);
  bool prod2Ok   = (resProd2.message == prod2);
  bool sumQuadOk = (resSumQuad.message == sumQuad);
  bool success = sumOk && sumMultOk && prodOk && prod2Ok && sumQuadOk;

  if (disp || !success) {
    cout << "Seed: " << seed << endl << endl;

    if (!sumOk) {
      cout << "Add failed." << endl;
    }
    if (!sumMultOk) {
      cout << "Adding multiple times failed." << endl;
    }
    if (!prodOk) {
      cout << "Multiply failed." << endl;
    }
    if (!prod2Ok) {
      cout << "Squaring failed." << endl;
    }
    if (!sumQuadOk) {
      cout << "Sum and quad failed." << endl;
    }

    cout << "Test " << (success ? "SUCCEEDED" : "FAILED") << endl;
  }

  return success;
}