// Adds several primes to the chain. If byNumber=true then totalSize specifies
// the number of primes to add. If byNumber=false then totalSize specifies the
// target natural log of the product of all the added primes (i.e. the sum of
// log(p) over the added primes -- this is NOT a bit count).
// The function returns the natural log of the product of all the added primes.
//
// The 'special' flag is passed through to AddFFTPrime/AddPrime; when adding by
// size (byNumber=false) the code also tries to make all special primes of
// similar size.
double AddManyPrimes(FHEcontext& context, double totalSize,
                     bool byNumber, bool special)
{
    double sizeLogSoFar = 0.0; // natural log of the product of primes added so far
    double addedSoFar = 0.0;   // #primes or sizeLogSoFar, depending on byNumber
    if (!context.zMStar.getM() || context.zMStar.getM()>(1<<20))// sanity checks
        Error("AddManyPrimes: m undefined or larger than 2^20");

    // Nothing to add. Returning early also avoids a 0/0 (and the undefined
    // double->long conversion of its NaN result) when sizing special primes
    // below; the negated comparison also catches a NaN totalSize.
    if (!(totalSize > 0.0)) return 0.0;

    if (ALT_CRT) {
        while (addedSoFar < totalSize) {
            long p = context.AddFFTPrime(special);
            sizeLogSoFar += log((double)p);
            addedSoFar = byNumber? (addedSoFar+1.0) : sizeLogSoFar;
        }
    }
    else {
#ifdef NO_HALF_SIZE_PRIME
        long sizeBits = context.bitsPerLevel;
#else
        long sizeBits = 2*context.bitsPerLevel;
#endif
        // Try to use similar size for all the special primes. Only meaningful
        // when totalSize is a size; if byNumber it is a count of primes.
        if (special && !byNumber) {
            // totalSize is a natural log: convert to bits before comparing
            // with NTL_SP_NBITS.
            double totalBits = totalSize/log(2.0);
            long numPrimes = ceil(totalBits/NTL_SP_NBITS);// how many special primes
            sizeBits = 1+ceil(totalBits/numPrimes);       // what's the size of each
            // Added one so we don't undershoot our target
        }
        long twoM = 2 * context.zMStar.getM();

        if (sizeBits>NTL_SP_NBITS) sizeBits = NTL_SP_NBITS;
        long sizeBound = 1L << sizeBits;
        if (sizeBound < twoM*log2(twoM)*8) { // bound too small to have such primes
            sizeBits = ceil(log2(twoM*log2(twoM)))+3; // increase prime size-bound
            sizeBound = 1L << sizeBits;
        }

        // make p-1 divisible by m*2^k for as large k as possible
        while (twoM < sizeBound/(sizeBits*2)) twoM *= 2;

        long bigP = sizeBound - (sizeBound%twoM) +1; // 1 mod 2m
        long p = bigP+twoM; // The twoM is subtracted in the AddPrime function

        // FIXME: The last prime could be smaller
        while (addedSoFar < totalSize) {
            if ((p = context.AddPrime(p,-twoM,special))) { // found a prime
                sizeLogSoFar += log((double)p);
                addedSoFar = byNumber? (addedSoFar+1.0) : sizeLogSoFar;
            }
            else { // we ran out of primes, try a lower power of two
                twoM /= 2;
                // can we go lower? Checked explicitly because an assert is
                // compiled out under NDEBUG, letting twoM keep shrinking silently.
                if (twoM <= (long)context.zMStar.getM())
                    Error("AddManyPrimes: ran out of primes p = 1 mod 2m");
                p = bigP;
            }
        }
    }
    return sizeLogSoFar;
}
// Adds several primes to the chain. If byNumber=true then totalSize specifies
// the number of primes to add. If byNumber=false then totalSize specifies the
// target natural log of the product of all the added primes.
// Returns the natural log of the product of all added primes.
//
// When special=true and byNumber=false, the code tries to make all the added
// special primes of similar size.
double AddManyPrimes(FHEcontext& context, double totalSize, 
                     bool byNumber, bool special)
{
  if (!context.zMStar.getM() || context.zMStar.getM()>(1<<20))// sanity checks
    Error("AddManyPrimes: m undefined or larger than 2^20");
  // NOTE: Below we are ensured that 16m*log(m) << NTL_SP_BOUND

  // Nothing to add. Returning early also avoids a 0/0 (and the undefined
  // double->long conversion of its NaN result) when sizing special primes
  // below; the negated comparison also catches a NaN totalSize.
  if (!(totalSize > 0.0)) return 0.0;

  double sizeLogSoFar = 0.0; // log of added primes so far
  double addedSoFar = 0.0;   // Either size or number, depending on 'byNumber'

#ifdef NO_HALF_SIZE_PRIME
  long sizeBits = context.bitsPerLevel;
#else
  long sizeBits = 2*context.bitsPerLevel;
#endif
  // Try to use similar size for all the special primes. This only makes
  // sense when totalSize is a size: if byNumber, it is a count of primes.
  if (special && !byNumber) {
    double totalBits = totalSize/log(2.0);        // natural log -> bits
    long numPrimes = ceil(totalBits/NTL_SP_NBITS);// how many special primes
    sizeBits = 1+ceil(totalBits/numPrimes);       // what's the size of each
    // Added one so we don't undershoot our target
  }
  if (sizeBits>NTL_SP_NBITS) sizeBits = NTL_SP_NBITS;
  long sizeBound = 1L << sizeBits;

  // Make sure that you have enough primes such that p-1 is divisible by 2m
  long twoM = 2 * context.zMStar.getM();
  if (sizeBound < twoM*log2(twoM)*8) { // bound too small to have such primes
    sizeBits = ceil(log2(twoM*log2(twoM)))+3; // increase prime size-bound
    sizeBound = 1L << sizeBits;
  }

  // make p-1 divisible by m*2^k for as large k as possible
  // (not needed when m itself a power of two)
  if (context.zMStar.getM() & 1) // m is odd, so not power of two
    while (twoM < sizeBound/(sizeBits*2)) twoM *= 2;

  long bigP = sizeBound - (sizeBound%twoM) +1; // 1 mod 2m
  long p = bigP+twoM; // twoM is subtracted in the AddPrime function

  // FIXME: The last prime could sometimes be slightly smaller
  while (addedSoFar < totalSize) {
    if ((p = context.AddPrime(p,-twoM,special))) { // found a prime
      sizeLogSoFar += log((double)p);
      addedSoFar = byNumber? (addedSoFar+1.0) : sizeLogSoFar;
    }
    else { // we ran out of primes, try a lower power of two
      twoM /= 2;
      // can we go lower? Checked explicitly because an assert is compiled out
      // under NDEBUG, which would let twoM keep shrinking silently.
      if (twoM <= (long)context.zMStar.getM())
        Error("AddManyPrimes: ran out of primes p = 1 mod 2m");
      p = bigP;
    }
  }
  return sizeLogSoFar;
}
// Builds the modulus chain: a first prime q0, then the remaining ciphertext
// primes (AddPrimesByNumber), then the special primes for the P factor of
// key-switching (AddPrimesBySize). nDgts is the requested number of
// key-switching digits; it is clamped to [1, nPrimes] and is reduced by one
// if the last digit comes out empty.
void buildModChain(FHEcontext &context, long nLevels, long nDgts)
{
#ifdef NO_HALF_SIZE_PRIME
    long nPrimes = nLevels;
#else
    long nPrimes = (nLevels+1)/2;
#endif
    // The first prime should be of half the size. The code below tries to find
    // a prime q0 of this size where q0-1 is divisible by 2^k * m for some k>1
    // (under ALT_CRT, only by a power of two). All the other ciphertext primes
    // are added by AddPrimesByNumber below.
    long twoM;
    if (ALT_CRT)
        twoM = 2;
    else
        twoM = 2 * context.zMStar.getM();

    long bound = (1L << (context.bitsPerLevel-1));
    while (twoM < bound/(2*context.bitsPerLevel))
        twoM *= 2; // divisible by 2^k * m  for a larger k

    bound = bound - (bound % twoM) +1; // = 1 mod 2m
    long q0 = context.AddPrime(bound, twoM, false, !ALT_CRT);
    // add next prime to chain

    assert(q0 != 0);

    // Choose the next primes as large as possible
    if (nPrimes>0) AddPrimesByNumber(context, nPrimes);

    // calculate the size of the digits

    if (nDgts > nPrimes) nDgts = nPrimes; // sanity checks
    if (nDgts <= 0) nDgts = 1;
    context.digits.resize(nDgts); // allocate space

    IndexSet s1; // union of all the digits formed so far
    double sizeSoFar = 0.0;
    double maxDigitSize = 0.0;
    if (nDgts>1) { // we break ciphertext into a few digits when key-switching
        double dsize = context.logOfProduct(context.ctxtPrimes)/nDgts; // estimate
        // NOTE(review): dsize appears to be a natural log (it is compared with
        // sums of log(p) below) while bitsPerLevel is in bits, so this slack is
        // bitsPerLevel/3 nats -- confirm whether a log(2.0) factor was intended.
        double target = dsize-(context.bitsPerLevel/3.0);
        long idx = context.ctxtPrimes.first();
        for (long i=0; i<nDgts-1; i++) { // compute next digit
            IndexSet s;
            while (idx <= context.ctxtPrimes.last() && (empty(s)||sizeSoFar<target)) {
                s.insert(idx); // without this, s stays empty and the assert fires
                sizeSoFar += log((double)context.ithPrime(idx));
                idx = context.ctxtPrimes.next(idx);
            }
            assert (!empty(s));
            s1.insert(s); // record used primes so the last digit excludes them
            context.digits[i] = s;
            double thisDigitSize = context.logOfProduct(s);
            if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize;
            target += dsize;
        }
        IndexSet s = context.ctxtPrimes / s1; // all the remaining primes
        if (!empty(s)) {
            context.digits[nDgts-1] = s;
            double thisDigitSize = context.logOfProduct(s);
            if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize;
        }
        else { // If last digit is empty, remove it
            nDgts--;
            context.digits.resize(nDgts);
        }
    }
    else {
        maxDigitSize = context.logOfProduct(context.ctxtPrimes);
        context.digits[0] = context.ctxtPrimes;
    }

    // Add primes to the chain for the P factor of key-switching
    long p2r = (context.rcData.alMod)? context.rcData.alMod->getPPowR()
               : context.alMod.getPPowR();
    double sizeOfSpecialPrimes
        = maxDigitSize + log(nDgts/32.0)/2 + log(context.stdev *2)
          + log((double)p2r);

    AddPrimesBySize(context, sizeOfSpecialPrimes, true);
}
// Builds the modulus chain: a first prime q0, then the remaining ciphertext
// primes (AddPrimesByNumber), then the special primes for the P factor of
// key-switching (AddPrimesBySize). nDgts is the requested number of
// key-switching digits; it is clamped to [1, nPrimes] and is reduced by one
// if the last digit comes out empty. extraBits adds that many bits to the
// target size of the special primes.
void buildModChain(FHEcontext &context, long nLevels, long nDgts,long extraBits)
{
#ifdef NO_HALF_SIZE_PRIME
  long nPrimes = nLevels;
#else
  long nPrimes = (nLevels+1)/2;
#endif
  // The first prime should be of half the size. The code below tries to find
  // a prime q0 of this size where q0-1 is divisible by 2^k * m for some k>1.

  long twoM = 2 * context.zMStar.getM();
  long bound = (1L << (context.bitsPerLevel-1));
  while (twoM < bound/(2*context.bitsPerLevel))
    twoM *= 2; // divisible by 2^k * m  for a larger k

  bound = bound - (bound % twoM) +1; // = 1 mod 2m
  long q0 = context.AddPrime(bound, twoM, false); 
  // add next prime to chain
  assert(q0 != 0);

  // Choose the next primes as large as possible
  if (nPrimes>0) AddPrimesByNumber(context, nPrimes);

  // calculate the size of the digits

  if (nDgts > nPrimes) nDgts = nPrimes; // sanity checks
  if (nDgts <= 0) nDgts = 1;
  context.digits.resize(nDgts); // allocate space

  IndexSet s1; // union of all the digits formed so far
  double sizeSoFar = 0.0;
  double maxDigitSize = 0.0;
  if (nDgts>1) { // we break ciphertext into a few digits when key-switching
    double dsize = context.logOfProduct(context.ctxtPrimes)/nDgts; // estimate

    // A hack: we break the current digit after the total size of all digits
    // so far "almost reaches" the next multiple of dsize, up to 1/3 of a level.
    // NOTE(review): dsize appears to be a natural log (it is compared with
    // sums of log(p) below) while bitsPerLevel is in bits, so this slack is
    // bitsPerLevel/3 nats -- confirm whether a log(2.0) factor was intended.
    double target = dsize-(context.bitsPerLevel/3.0);
    long idx = context.ctxtPrimes.first();
    for (long i=0; i<nDgts-1; i++) { // set all digits but the last
      IndexSet s;
      while (idx <= context.ctxtPrimes.last() && (empty(s)||sizeSoFar<target)) {
        s.insert(idx); // without this, s stays empty and the assert fires
        sizeSoFar += log((double)context.ithPrime(idx));
        idx = context.ctxtPrimes.next(idx);
      }
      assert (!empty(s));
      s1.insert(s); // record used primes so the last digit excludes them
      context.digits[i] = s;
      double thisDigitSize = context.logOfProduct(s);
      if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize;
      target += dsize;
    }
    // The ctxt primes that are left (if any) form the last digit
    IndexSet s = context.ctxtPrimes / s1;
    if (!empty(s)) {
      context.digits[nDgts-1] = s;
      double thisDigitSize = context.logOfProduct(s);
      if (maxDigitSize < thisDigitSize) maxDigitSize = thisDigitSize;
    }
    else { // If last digit is empty, remove it
      nDgts--;
      context.digits.resize(nDgts);
    }
  }
  else { // only one digit
    maxDigitSize = context.logOfProduct(context.ctxtPrimes);
    context.digits[0] = context.ctxtPrimes;
  }

  // Add special primes to the chain for the P factor of key-switching
  long p2r = context.alMod.getPPowR();
  double sizeOfSpecialPrimes
    = maxDigitSize + log(nDgts) + log(context.stdev *2)
      + log((double)p2r) + (extraBits*log(2.0));

  AddPrimesBySize(context, sizeOfSpecialPrimes, true);
}