예제 #1
0
/************** Each round consists of the following:
1. c1.multiplyBy(c0)
2. c0 += random constant
3. c2 *= random constant
4. tmp = c1
5. ea.rotate(tmp, random amount in [-nSlots/2, nSlots/2])
6. c2 += tmp
7. ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
8. c1.negate()
9. c3.multiplyBy(c2) 
10. c0 -= c3
**************/
void testGeneralOps(const FHEPubKey& publicKey, const FHESecKey& secretKey,
                    const EncryptedArrayCx& ea, double epsilon,
                    long nRounds)
{
  long nslots = ea.size();
  char buffer[32];

  vector<cx_double> p0, p1, p2, p3;
  ea.random(p0);
  ea.random(p1);
  ea.random(p2);
  ea.random(p3);

  Ctxt c0(publicKey), c1(publicKey), c2(publicKey), c3(publicKey);
  ea.encrypt(c0, publicKey, p0, /*size=*/1.0);
  ea.encrypt(c1, publicKey, p1, /*size=*/1.0);
  ea.encrypt(c2, publicKey, p2, /*size=*/1.0);
  ea.encrypt(c3, publicKey, p3, /*size=*/1.0);

  resetAllTimers();
  FHE_NTIMER_START(Circuit);

  for (long i = 0; i < nRounds; i++) {

    if (verbose) std::cout << "*** round " << i << "..."<<endl;

     long shamt = RandomBnd(2*(nslots/2) + 1) - (nslots/2);
                  // random number in [-nslots/2..nslots/2]
     long rotamt = RandomBnd(2*nslots - 1) - (nslots - 1);
                  // random number in [-(nslots-1)..nslots-1]

     // two random constants
     vector<cx_double> const1, const2;
     ea.random(const1);
     ea.random(const2);

     ZZX const1_poly, const2_poly;
     ea.encode(const1_poly, const1, /*size=*/1.0);
     ea.encode(const2_poly, const2, /*size=*/1.0);

     mul(p1, p0);     // c1.multiplyBy(c0)
     c1.multiplyBy(c0);
     if (verbose) {
       CheckCtxt(c1, "c1*=c0");
       debugCompare(ea, secretKey, p1, c1, epsilon);
     }

     add(p0, const1); // c0 += random constant
     c0.addConstant(const1_poly);
     if (verbose) {
       CheckCtxt(c0, "c0+=k1");
       debugCompare(ea, secretKey, p0, c0, epsilon);
     }
     mul(p2, const2); // c2 *= random constant
     c2.multByConstant(const2_poly);
     if (verbose) {
       CheckCtxt(c2, "c2*=k2");
       debugCompare(ea, secretKey, p2, c2, epsilon);
     }
     vector<cx_double> tmp_p(p1); // tmp = c1
     Ctxt tmp(c1);
     sprintf(buffer, "tmp=c1>>=%d", (int)shamt);
     rotate(tmp_p, shamt); // ea.shift(tmp, random amount in [-nSlots/2,nSlots/2])
     ea.rotate(tmp, shamt);
     if (verbose) {
       CheckCtxt(tmp, buffer);
       debugCompare(ea, secretKey, tmp_p, tmp, epsilon);
     }
     add(p2, tmp_p);  // c2 += tmp
     c2 += tmp;
     if (verbose) {
       CheckCtxt(c2, "c2+=tmp");
       debugCompare(ea, secretKey, p2, c2, epsilon);
     }
     sprintf(buffer, "c2>>>=%d", (int)rotamt);
     rotate(p2, rotamt); // ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
     ea.rotate(c2, rotamt);
     if (verbose) {
       CheckCtxt(c2, buffer);
       debugCompare(ea, secretKey, p2, c2, epsilon);
     }
     negateVec(p1); // c1.negate()
     c1.negate();
     if (verbose) {
       CheckCtxt(c1, "c1=-c1");
       debugCompare(ea, secretKey, p1, c1, epsilon);
     }
     mul(p3, p2); // c3.multiplyBy(c2) 
     c3.multiplyBy(c2);
     if (verbose) {
       CheckCtxt(c3, "c3*=c2");
       debugCompare(ea, secretKey, p3, c3, epsilon);
     }
     sub(p0, p3); // c0 -= c3
     c0 -= c3;
     if (verbose) {
       CheckCtxt(c0, "c0=-c3");
       debugCompare(ea, secretKey, p0, c0, epsilon);
     }
  }

  c0.cleanUp();
  c1.cleanUp();
  c2.cleanUp();
  c3.cleanUp();

  FHE_NTIMER_STOP(Circuit);

  vector<cx_double> pp0, pp1, pp2, pp3;
   
  ea.decrypt(c0, secretKey, pp0);
  ea.decrypt(c1, secretKey, pp1);
  ea.decrypt(c2, secretKey, pp2);
  ea.decrypt(c3, secretKey, pp3);

  std::cout << "Test "<<nRounds<<" rounds of mixed operations, ";
  if (cx_equals(pp0, p0,conv<double>(epsilon*c0.getPtxtMag()))
      && cx_equals(pp1, p1,conv<double>(epsilon*c1.getPtxtMag()))
      && cx_equals(pp2, p2,conv<double>(epsilon*c2.getPtxtMag()))
      && cx_equals(pp3, p3,conv<double>(epsilon*c3.getPtxtMag())))
    std::cout << "PASS\n\n";
  else {
    std::cout << "FAIL:\n";
    std::cout << "  max(p0)="<<largestCoeff(p0)
              << ", max(pp0)="<<largestCoeff(pp0)
              << ", maxDiff="<<calcMaxDiff(p0,pp0) << endl;
    std::cout << "  max(p1)="<<largestCoeff(p1)
              << ", max(pp1)="<<largestCoeff(pp1)
              << ", maxDiff="<<calcMaxDiff(p1,pp1) << endl;
    std::cout << "  max(p2)="<<largestCoeff(p2)
              << ", max(pp2)="<<largestCoeff(pp2)
              << ", maxDiff="<<calcMaxDiff(p2,pp2) << endl;
    std::cout << "  max(p3)="<<largestCoeff(p3)
              << ", max(pp3)="<<largestCoeff(pp3)
              << ", maxDiff="<<calcMaxDiff(p3,pp3) << endl<<endl;
  }

  if (verbose) {
    std::cout << endl;
    printAllTimers();
    std::cout << endl;
  }
  resetAllTimers();
   }
예제 #2
0
  void AngleHarmonicChain_2d::interact() const {
    // Call parent class.
    AngleHarmonicChain::interact();
    // Get simdata, check if the local ids need updating
    SimData *sd = Base::simData;
    RealType **x = sd->X();
    RealType **f = sd->F();
    if (sd->getNeedsRemake()) updateLocalIDs();
    // If there are not enough particles in the buffer
    if (local_ids.size()<3) return;
    // Variables and buffers
    RealType dx1[2], dx2[2], rsqr1, rsqr2, r1, r2, invr1, invr2, angle_cos, a11, a12, a22;
    RealType f1[2], f3[2];
    // Get the bounds and boundary conditions
    Bounds bounds = Base::gflow->getBounds(); // Simulation bounds
    BCFlag boundaryConditions[2];
    copyVec(Base::gflow->getBCs(), boundaryConditions, 2); // Keep a local copy of the bcs
    // Extract bounds related data
    RealType bnd_x = bounds.wd(0);
    RealType bnd_y = bounds.wd(1);

    // If there are fewer than three particles, this loop will not execute
    for (int i=0; i+2<local_ids.size(); ++i) {
      // Get the global, then local ids of the particles.
      int id1 = local_ids[i], id2 = local_ids[i+1], id3=local_ids[i+2];
      // id1     id2     id3
      //  A <---- B ----> C
      //     dx1    dx2
      //
      // First bond
      dx1[0] = x[id1][0] - x[id2][0];
      dx1[1] = x[id1][1] - x[id2][1];
      // Second bond
      dx2[0] = x[id3][0] - x[id2][0];
      dx2[1] = x[id3][1] - x[id2][1];
      // Minimum image under harmonic b.c.s
      gflow->minimumImage(dx1);
      gflow->minimumImage(dx2);

      /*
      // Harmonic corrections to distance.
      if (boundaryConditions[0]==BCFlag::WRAP) {
        RealType dX = bnd_x - fabs(dx1);
        if (dX<fabs(dx1)) dx1 = dx1>0 ? -dX : dX;
        dX = bnd_x - fabs(dx2);
        if (dX<fabs(dx2)) dx2 = dx2>0 ? -dX : dX;
      }  
      if (boundaryConditions[1]==BCFlag::WRAP) {
        RealType dY = bnd_y - fabs(dy1);
        if (dY<fabs(dy1)) dy1 = dy1>0 ? -dY : dY;
        dY = bnd_y - fabs(dy2);
        if (dY<fabs(dy2)) dy2 = dy1>0 ? -dY : dY;
      } 
      */

      
      // Get distance squared, distance
      rsqr1 = dx1[0]*dx1[0] + dx1[1]*dx1[1];
      r1 = sqrt(rsqr1);
      rsqr2 = dx2[0]*dx2[0] + dx2[1]*dx2[1];
      r2 = sqrt(rsqr2);
      // Find the cosine of the angle between the atoms
      angle_cos = dx1[0]*dx2[0] + dx1[1]*dx2[1];
      angle_cos /= (r1*r2);

      RealType theta = acos(angle_cos);
      //cout << theta / (PI) << endl;


      /*
      // Calculate the force
      a11 = angleConstant * angle_cos / rsqr1;
      a12 = -angleConstant / (r1*r2);
      a22 = angleConstant * angle_cos / rsqr2;
      // Force buffers
      f1[0] = a11*dx1 + a12*dx2;
      f1[1] = a11*dy1 + a12*dy2;
      f3[0] = a22*dx2 + a12*dx1;
      f3[1] = a22*dy2 + a12*dy1;
      */

      RealType pa[2], pb[2];
      tripleProduct2(dx1, dx1, dx2, pa);
      tripleProduct2(dx2, dx1, dx2, pb);
      negateVec(pb, 2);

      RealType str = -2*angleConstant*(theta - PI);
      RealType strength1 = str/r1;
      RealType strength2 = str/r2;
      f1[0] = strength1*pa[0];
      f1[1] = strength1*pa[1];

      f3[0] = strength2*pb[0];
      f3[1] = strength2*pb[1];


      // First atom
      f[id1][0] += f1[0];
      f[id1][1] += f1[1];
      // Second atom
      f[id2][0] -= (f1[0] + f3[0]);
      f[id2][1] -= (f1[1] + f3[1]);
      // Third atom
      f[id3][0] += f3[0];
      f[id3][1] += f3[1];
    }
    // For two atoms, there is no angle. Just compute the bond force.
  }