Пример #1
0
/* Reduces a modulo n in place, where n is of the form 2**p - d.
 *
 * This differs from mp_reduce_2k in that "d" is a full mp_int and can
 * therefore be larger than a single digit.
 *
 * a - value to reduce; overwritten with the reduced result
 * n - the modulus
 * d - the "d" in n = 2**p - d
 *
 * Returns MP_OKAY on success, otherwise the error code reported by the
 * first helper call that failed.
 */
int mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
{
   mp_int q;
   int    p, res;

   if ((res = mp_init(&q)) != MP_OKAY) {
      return res;
   }

   /* p = number of bits in n */
   p = mp_count_bits(n);
top:
   /* q = a/2**p, a = a mod 2**p */
   if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
      goto LBL_ERR;
   }

   /* q = q * d */
   if ((res = mp_mul(&q, d, &q)) != MP_OKAY) {
      goto LBL_ERR;
   }

   /* a = a + q */
   if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
      goto LBL_ERR;
   }

   /* still >= n in magnitude: subtract n once and run another round */
   if (mp_cmp_mag(a, n) != MP_LT) {
      /* a failed subtraction must be reported, not silently ignored */
      if ((res = s_mp_sub(a, n, a)) != MP_OKAY) {
         goto LBL_ERR;
      }
      goto top;
   }

LBL_ERR:
   /* reached on both success and failure; res holds the last status */
   mp_clear(&q);
   return res;
}
Пример #2
0
/* High level subtraction, c = a - b, taking the operand signs into
 * account.  The magnitude work is delegated to s_mp_add / s_mp_sub and
 * their status code is returned unchanged.
 */
int mp_sub(mp_int *a, mp_int *b, mp_int *c)
{
  /* capture both signs before c (which may alias a or b) is written */
  int sign_a = a->sign;
  int sign_b = b->sign;

  /* Opposite signs: the magnitudes are added and the result takes the
   * sign of the first operand.
   */
  if (sign_a != sign_b) {
    c->sign = sign_a;
    return s_mp_add(a, b, c);
  }

  /* Same sign, |a| < |b|: compute |b| - |a| and flip the sign of a. */
  if (mp_cmp_mag(a, b) == MP_LT) {
    if (sign_a == MP_ZPOS) {
      c->sign = MP_NEG;
    } else {
      c->sign = MP_ZPOS;
    }
    return s_mp_sub(b, a, c);
  }

  /* Same sign, |a| >= |b|: compute |a| - |b| and keep the sign of a. */
  c->sign = sign_a;
  return s_mp_sub(a, b, c);
}
Пример #3
0
/* High level addition, c = a + b, taking the operand signs into
 * account.  The magnitude work is delegated to s_mp_add / s_mp_sub and
 * their status code is returned unchanged.
 */
int mp_add(const mp_int *a, const mp_int *b, mp_int *c)
{
   const mp_int *larger, *smaller;
   int           sign_a, sign_b, sign_out;

   /* capture both signs before c (which may alias a or b) is written */
   sign_a = a->sign;
   sign_b = b->sign;

   /* Equal signs: add the magnitudes and keep the common sign. */
   if (sign_a == sign_b) {
      c->sign = sign_a;
      return s_mp_add(a, b, c);
   }

   /* Mixed signs: subtract the smaller magnitude from the larger one.
    * The result carries the sign of the operand with the larger
    * magnitude (a wins when the magnitudes are equal).
    */
   larger   = a;
   smaller  = b;
   sign_out = sign_a;
   if (mp_cmp_mag(a, b) == MP_LT) {
      larger   = b;
      smaller  = a;
      sign_out = sign_b;
   }

   c->sign = sign_out;
   return s_mp_sub(larger, smaller, c);
}
Пример #4
0
/* Reduces a modulo n in place, where n is of the form 2**p - k.
 *
 * a - value to reduce; overwritten with the reduced result
 * n - the modulus
 * k - the single-digit "k" in n = 2**p - k
 *
 * Returns MP_OKAY on success, otherwise the error code reported by the
 * first helper call that failed.
 */
int
mp_reduce_2k(mp_int *a, mp_int *n, mp_digit k)
{
   mp_int q;
   int    p, res;

   if ((res = mp_init(&q)) != MP_OKAY) {
      return res;
   }

   /* p = number of bits in n */
   p = mp_count_bits(n);
top:
   /* q = a/2**p, a = a mod 2**p */
   if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
      goto LBL_ERR;
   }

   /* multiplying by one is a no-op, so the call is skipped for k == 1 */
   if (k != 1) {
      /* q = q * k */
      if ((res = mp_mul_d(&q, k, &q)) != MP_OKAY) {
         goto LBL_ERR;
      }
   }

   /* a = a + q */
   if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
      goto LBL_ERR;
   }

   /* still >= n in magnitude: subtract n once and run another round */
   if (mp_cmp_mag(a, n) != MP_LT) {
      /* a failed subtraction must be reported, not silently ignored */
      if ((res = s_mp_sub(a, n, a)) != MP_OKAY) {
         goto LBL_ERR;
      }
      goto top;
   }

LBL_ERR:
   /* reached on both success and failure; res holds the last status */
   mp_clear(&q);
   return res;
}
Пример #5
0
/* Karatsuba squaring: computes b = a*a using three half size squarings.
 *
 * The digits of a are split at position B = a->used / 2 into a low part
 * x0 and a high part x1.  The result is then assembled as
 *
 *    x0*x0 + (((x1 + x0)**2 - (x0*x0 + x1*x1)) << B) + (x1*x1 << 2*B)
 *
 * where "<<" is a shift by whole digits (mp_lshd).  See the comments of
 * karatsuba_mul for details; this is the same scheme specialised for
 * squaring.
 *
 * Returns MP_OKAY on success and MP_MEM if any step fails.
 */
int mp_karatsuba_sqr(const mp_int *a, mp_int *b)
{
   mp_int  lo, hi, mid, sq_sum, lo_sq, hi_sq;
   int     total, half, i;
   int     err = MP_MEM;   /* pessimistic default, cleared on success */

   /* number of digits in a, and the split point */
   total = a->used;
   half  = total >> 1;

   /* the two halves of a */
   if (mp_init_size(&lo, half) != MP_OKAY) {
      goto DONE;
   }
   if (mp_init_size(&hi, total - half) != MP_OKAY) {
      goto FREE_LO;
   }

   /* temporaries */
   if (mp_init_size(&mid, total * 2) != MP_OKAY) {
      goto FREE_HI;
   }
   if (mp_init_size(&sq_sum, total * 2) != MP_OKAY) {
      goto FREE_MID;
   }
   if (mp_init_size(&lo_sq, half * 2) != MP_OKAY) {
      goto FREE_SQ_SUM;
   }
   if (mp_init_size(&hi_sq, (total - half) * 2) != MP_OKAY) {
      goto FREE_LO_SQ;
   }

   /* copy the low digits into lo and the remaining ones into hi */
   for (i = 0; i < half; i++) {
      lo.dp[i] = a->dp[i];
   }
   for (i = half; i < total; i++) {
      hi.dp[i - half] = a->dp[i];
   }

   lo.used = half;
   hi.used = total - half;

   /* only the low half is clamped */
   mp_clamp(&lo);

   /* lo_sq = lo*lo */
   if (mp_sqr(&lo, &lo_sq) != MP_OKAY) {
      goto FREE_ALL;
   }
   /* hi_sq = hi*hi */
   if (mp_sqr(&hi, &hi_sq) != MP_OKAY) {
      goto FREE_ALL;
   }

   /* mid = hi + lo */
   if (s_mp_add(&hi, &lo, &mid) != MP_OKAY) {
      goto FREE_ALL;
   }
   /* mid = (hi + lo)**2 */
   if (mp_sqr(&mid, &mid) != MP_OKAY) {
      goto FREE_ALL;
   }

   /* sq_sum = lo_sq + hi_sq */
   if (s_mp_add(&lo_sq, &hi_sq, &sq_sum) != MP_OKAY) {
      goto FREE_ALL;
   }
   /* mid = (hi + lo)**2 - (lo_sq + hi_sq) */
   if (s_mp_sub(&mid, &sq_sum, &mid) != MP_OKAY) {
      goto FREE_ALL;
   }

   /* mid = mid << half digits */
   if (mp_lshd(&mid, half) != MP_OKAY) {
      goto FREE_ALL;
   }
   /* hi_sq = hi_sq << 2*half digits */
   if (mp_lshd(&hi_sq, half * 2) != MP_OKAY) {
      goto FREE_ALL;
   }

   /* mid = lo_sq + mid */
   if (mp_add(&lo_sq, &mid, &mid) != MP_OKAY) {
      goto FREE_ALL;
   }
   /* b = lo_sq + mid + hi_sq */
   if (mp_add(&mid, &hi_sq, b) != MP_OKAY) {
      goto FREE_ALL;
   }

   err = MP_OKAY;

   /* release in reverse order of initialisation */
FREE_ALL:
   mp_clear(&hi_sq);
FREE_LO_SQ:
   mp_clear(&lo_sq);
FREE_SQ_SUM:
   mp_clear(&sq_sum);
FREE_MID:
   mp_clear(&mid);
FREE_HI:
   mp_clear(&hi);
FREE_LO:
   mp_clear(&lo);
DONE:
   return err;
}
Пример #6
0
/* c = |a| * |b| using Karatsuba multiplication, i.e. three half size
 * multiplications.
 *
 * Let B represent the radix [e.g. 2**DIGIT_BIT] and let n represent
 * half of the number of digits in min(a,b).
 *
 *    a = a1 * B**n + a0
 *    b = b1 * B**n + b0
 *
 * Then
 *
 *    a * b = a1b1 * B**2n
 *          + ((a1 + a0)(b1 + b0) - (a0b0 + a1b1)) * B**n
 *          + a0b0
 *
 * a1b1 and a0b0 each appear twice but only need to be computed once,
 * so in total three half size (half # of digits) multiplications are
 * performed: a0b0, a1b1 and (a1+a0)(b1+b0).
 *
 * A multiplication of half the digits requires 1/4th the number of
 * single precision multiplications, so after one call 25% of the
 * single precision multiplications are saved.  The calls to mp_mul can
 * end up back in this function if a0, a1, b0 or b1 are above the
 * threshold.  This is known as divide-and-conquer and leads to the
 * famous O(N**lg(3)) or O(N**1.584) work, which is asymptotically lower
 * than the standard O(N**2) that the baseline/comba methods use.
 * Generally though the overhead of this method doesn't pay off until a
 * certain size (N ~ 80) is reached.
 *
 * Returns MP_OKAY on success and MP_MEM if any step fails.
 */
int mp_karatsuba_mul (mp_int * a, mp_int * b, mp_int * c)
{
    mp_int  a_lo, a_hi, b_lo, b_hi, mid, lo_prod, hi_prod;
    int     half, i;
    int     err = MP_MEM;   /* pessimistic default, cleared on success */

    /* split point: half of the smaller digit count */
    half = MIN (a->used, b->used);
    half >>= 1;

    /* low and high parts of both operands */
    if (mp_init_size (&a_lo, half) != MP_OKAY) {
        goto DONE;
    }
    if (mp_init_size (&a_hi, a->used - half) != MP_OKAY) {
        goto FREE_A_LO;
    }
    if (mp_init_size (&b_lo, half) != MP_OKAY) {
        goto FREE_A_HI;
    }
    if (mp_init_size (&b_hi, b->used - half) != MP_OKAY) {
        goto FREE_B_LO;
    }

    /* temporaries */
    if (mp_init_size (&mid, half * 2) != MP_OKAY) {
        goto FREE_B_HI;
    }
    if (mp_init_size (&lo_prod, half * 2) != MP_OKAY) {
        goto FREE_MID;
    }
    if (mp_init_size (&hi_prod, half * 2) != MP_OKAY) {
        goto FREE_LO_PROD;
    }

    /* digit counts of the four parts */
    a_lo.used = half;
    b_lo.used = half;
    a_hi.used = a->used - half;
    b_hi.used = b->used - half;

    /* The digits are copied directly instead of through higher level
     * functions, since the high parts also have to be shifted down.
     */
    for (i = 0; i < half; i++) {
        a_lo.dp[i] = a->dp[i];
        b_lo.dp[i] = b->dp[i];
    }
    for (i = half; i < a->used; i++) {
        a_hi.dp[i - half] = a->dp[i];
    }
    for (i = half; i < b->used; i++) {
        b_hi.dp[i - half] = b->dp[i];
    }

    /* Only the low parts are clamped; the high parts keep the digit
     * counts assigned above.
     */
    mp_clamp (&a_lo);
    mp_clamp (&b_lo);

    /* lo_prod = a_lo * b_lo */
    if (mp_mul (&a_lo, &b_lo, &lo_prod) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* hi_prod = a_hi * b_hi */
    if (mp_mul (&a_hi, &b_hi, &hi_prod) != MP_OKAY) {
        goto FREE_ALL;
    }

    /* mid = a_hi + a_lo */
    if (s_mp_add (&a_hi, &a_lo, &mid) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* a_lo is no longer needed from here on and is reused as scratch:
     * a_lo = b_hi + b_lo
     */
    if (s_mp_add (&b_hi, &b_lo, &a_lo) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* mid = (a_hi + a_lo) * (b_hi + b_lo) */
    if (mp_mul (&mid, &a_lo, &mid) != MP_OKAY) {
        goto FREE_ALL;
    }

    /* scratch = lo_prod + hi_prod */
    if (mp_add (&lo_prod, &hi_prod, &a_lo) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* mid = (a_hi + a_lo) * (b_hi + b_lo) - (lo_prod + hi_prod) */
    if (s_mp_sub (&mid, &a_lo, &mid) != MP_OKAY) {
        goto FREE_ALL;
    }

    /* mid = mid << half digits */
    if (mp_lshd (&mid, half) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* hi_prod = hi_prod << 2*half digits */
    if (mp_lshd (&hi_prod, half * 2) != MP_OKAY) {
        goto FREE_ALL;
    }

    /* mid = lo_prod + mid */
    if (mp_add (&lo_prod, &mid, &mid) != MP_OKAY) {
        goto FREE_ALL;
    }
    /* c = lo_prod + mid + hi_prod */
    if (mp_add (&mid, &hi_prod, c) != MP_OKAY) {
        goto FREE_ALL;
    }

    /* every step succeeded */
    err = MP_OKAY;

    /* release in reverse order of initialisation */
FREE_ALL:
    mp_clear (&hi_prod);
FREE_LO_PROD:
    mp_clear (&lo_prod);
FREE_MID:
    mp_clear (&mid);
FREE_B_HI:
    mp_clear (&b_hi);
FREE_B_LO:
    mp_clear (&b_lo);
FREE_A_HI:
    mp_clear (&a_hi);
FREE_A_LO:
    mp_clear (&a_lo);
DONE:
    return err;
}
Пример #7
0
/* Karatsuba squaring: computes b = a*a using three half size squarings.
 *
 * The digits of a are split at position B = USED(a) / 2 into a low part
 * x0 and a high part x1.  The result is then assembled as
 *
 *    x0*x0 + (((x1 + x0)**2 - (x0*x0 + x1*x1)) << B) + (x1*x1 << 2*B)
 *
 * where "<<" is a shift by whole digits (mp_lshd).  See the comments of
 * karatsuba_mul for details; this is the same scheme specialised for
 * squaring.
 *
 * Returns MP_OKAY on success and MP_MEM if any step fails.
 */
int mp_karatsuba_sqr (mp_int * a, mp_int * b)
{
  mp_int    lo, hi, mid, sq_sum, lo_sq, hi_sq;
  mp_digit *from, *to;
  int       total, half, i;
  int       err = MP_MEM;   /* pessimistic default, cleared on success */

  /* number of digits in a, and the split point */
  total = USED(a);
  half  = total >> 1;

  /* the two halves of a */
  if (mp_init_size (&lo, half) != MP_OKAY) {
    goto DONE;
  }
  if (mp_init_size (&hi, total - half) != MP_OKAY) {
    goto FREE_LO;
  }

  /* temporaries */
  if (mp_init_size (&mid, total * 2) != MP_OKAY) {
    goto FREE_HI;
  }
  if (mp_init_size (&sq_sum, total * 2) != MP_OKAY) {
    goto FREE_MID;
  }
  if (mp_init_size (&lo_sq, half * 2) != MP_OKAY) {
    goto FREE_SQ_SUM;
  }
  if (mp_init_size (&hi_sq, (total - half) * 2) != MP_OKAY) {
    goto FREE_LO_SQ;
  }

  /* copy the low digits into lo and the remaining ones into hi */
  from = DIGITS(a);

  to = DIGITS(&lo);
  for (i = 0; i < half; i++) {
    to[i] = from[i];
  }

  to = DIGITS(&hi);
  for (i = half; i < total; i++) {
    to[i - half] = from[i];
  }

  SET_USED(&lo, half);
  SET_USED(&hi, total - half);

  /* only the low half is clamped */
  mp_clamp (&lo);

  /* lo_sq = lo*lo */
  if (mp_sqr (&lo, &lo_sq) != MP_OKAY) {
    goto FREE_ALL;
  }
  /* hi_sq = hi*hi */
  if (mp_sqr (&hi, &hi_sq) != MP_OKAY) {
    goto FREE_ALL;
  }

  /* mid = hi + lo */
  if (s_mp_add (&hi, &lo, &mid) != MP_OKAY) {
    goto FREE_ALL;
  }
  /* mid = (hi + lo)**2 */
  if (mp_sqr (&mid, &mid) != MP_OKAY) {
    goto FREE_ALL;
  }

  /* sq_sum = lo_sq + hi_sq */
  if (s_mp_add (&lo_sq, &hi_sq, &sq_sum) != MP_OKAY) {
    goto FREE_ALL;
  }
  /* mid = (hi + lo)**2 - (lo_sq + hi_sq) */
  if (s_mp_sub (&mid, &sq_sum, &mid) != MP_OKAY) {
    goto FREE_ALL;
  }

  /* mid = mid << half digits */
  if (mp_lshd (&mid, half) != MP_OKAY) {
    goto FREE_ALL;
  }
  /* hi_sq = hi_sq << 2*half digits */
  if (mp_lshd (&hi_sq, half * 2) != MP_OKAY) {
    goto FREE_ALL;
  }

  /* mid = lo_sq + mid */
  if (mp_add (&lo_sq, &mid, &mid) != MP_OKAY) {
    goto FREE_ALL;
  }
  /* b = lo_sq + mid + hi_sq */
  if (mp_add (&mid, &hi_sq, b) != MP_OKAY) {
    goto FREE_ALL;
  }

  err = MP_OKAY;

  /* release in reverse order of initialisation */
FREE_ALL:
  mp_clear (&hi_sq);
FREE_LO_SQ:
  mp_clear (&lo_sq);
FREE_SQ_SUM:
  mp_clear (&sq_sum);
FREE_MID:
  mp_clear (&mid);
FREE_HI:
  mp_clear (&hi);
FREE_LO:
  mp_clear (&lo);
DONE:
  return err;
}