Ejemplo n.º 1
0
/* w.b. hart */
/*
 * zz_sub -- set r to the signed difference a - b.
 *
 * Operands are sign-magnitude integers: ABS(size) is the number of limbs in
 * n[] that are used, and the sign of size is the sign of the value.
 *
 * If ZZ_ORDER exchanges the operands, the body actually forms b - a.  The
 * sign of the result is then flipped after normalisation.
 */
void zz_sub(zz_ptr r, zz_srcptr a, zz_srcptr b)
{
   long an = ABS(a->size);
   long bn = ABS(b->size);
   /* Recorded before ZZ_ORDER runs: non-zero when the operands get swapped. */
   int swapped = (an < bn);
   long len;

   /* Presumably swaps (a, an) with (b, bn) so that an >= bn -- TODO confirm. */
   ZZ_ORDER(a, an, b, bn);

   /* One spare limb for a possible carry out of the addition branch. */
   zz_fit(r, an + 1);

   if ((a->size ^ b->size) < 0) {
      /* Opposite signs: the magnitudes add, with the carry stored in limb an. */
      r->n[an] = nn_add(r->n, a->n, an, b->n, bn, 0);
      len = (a->size < 0) ? -(an + 1) : (an + 1);
   } else {
      /* Same signs: subtract magnitudes; the result carries the sign of a. */
      word_t borrow = nn_sub(r->n, a->n, an, b->n, bn, 0);

      len = a->size;
      /* A borrow means |b| > |a|.  nn_neg presumably turns the wrapped
         difference back into |b| - |a|, so the sign must be flipped too. */
      if (borrow) {
         nn_neg(r->n, r->n, an, 0);
         len = -len;
      }
   }

   r->size = len;
   zz_normalise(r);

   /* Undo the operand exchange: a - b == -(b - a). */
   if (swapped)
      r->size = -r->size;
}
/*
 * Overwrite len bytes at buf with zeros.
 *
 * The stores go through a volatile pointer so the compiler cannot treat them
 * as dead and drop them.  A plain memset() of a local array just before it
 * goes out of scope may legally be optimised away, which would leave secrets
 * on the stack.
 */
static void rsa_private_block_wipe(void *buf, size_t len)
{
  volatile unsigned char *bytes = (volatile unsigned char *) buf;

  while (len--)
    *bytes++ = 0;
}

/*
 * RSA private-key operation using the Chinese Remainder Theorem.
 *
 * The function decodes input into the big number c.  It computes
 *   mP = (c mod p)^dP mod p,  mQ = (c mod q)^dQ mod q
 * and recombines them as
 *   m = ((((mP - mQ) mod p) * qInv) mod p) * q + mQ,
 * then encodes m into output.
 *
 * input, input_len  value to transform; it must be numerically less than
 *                   the modulus.
 * output            receives *output_len bytes.  The caller must supply at
 *                   least (private_key->bits + 7) / 8 bytes; this is not
 *                   checked here.
 * output_len        set to (private_key->bits + 7) / 8 on success.
 * private_key       modulus, primes, prime exponents and CRT coefficient.
 *                   Assumes p > q, so q fits in pDigits digits.
 *
 * Returns 0 on success, or -1 if input >= modulus.  On failure, output and
 * *output_len are left untouched.
 *
 * Every intermediate buffer, including the decoded private-key material, is
 * wiped before returning on both paths.
 */
int16 rsa_private_block(uint8 *input, uint16 input_len, uint8 *output, uint16 *output_len, rsa_private_key *private_key)
{
  uint32 c[MAX_NN_DIGITS] = {0}, cP[MAX_NN_DIGITS] = {0}, cQ[MAX_NN_DIGITS] = {0},
  dP[MAX_NN_DIGITS] = {0}, dQ[MAX_NN_DIGITS] = {0}, mP[MAX_NN_DIGITS] = {0},
  mQ[MAX_NN_DIGITS] = {0}, n[MAX_NN_DIGITS] = {0}, p[MAX_NN_DIGITS] = {0}, q[MAX_NN_DIGITS] = {0},
  qInv[MAX_NN_DIGITS] = {0}, t[MAX_NN_DIGITS] = {0};
  uint16 cDigits = 0, nDigits = 0, pDigits = 0;
  int16 status = -1;

  nn_decode (c, MAX_NN_DIGITS, input, input_len);
  nn_decode (n, MAX_NN_DIGITS, private_key->modulus, MAX_RSA_MODULUS_LEN);
  nn_decode (p, MAX_NN_DIGITS, private_key->prime[0], MAX_RSA_PRIME_LEN);
  nn_decode (q, MAX_NN_DIGITS, private_key->prime[1], MAX_RSA_PRIME_LEN);
  nn_decode (dP, MAX_NN_DIGITS, private_key->prime_exponent[0], MAX_RSA_PRIME_LEN);
  nn_decode (dQ, MAX_NN_DIGITS, private_key->prime_exponent[1], MAX_RSA_PRIME_LEN);
  nn_decode (qInv, MAX_NN_DIGITS, private_key->coefficient, MAX_RSA_PRIME_LEN);

  cDigits = nn_digits (c, MAX_NN_DIGITS);
  nDigits = nn_digits (n, MAX_NN_DIGITS);
  pDigits = nn_digits (p, MAX_NN_DIGITS);

  /* The input must be strictly less than the modulus.  nn_cmp() only compares
     the low nDigits digits.  An input with more significant digits than n
     (input_len longer than the modulus) is therefore > n, and must be
     rejected explicitly. */
  if (cDigits > nDigits || nn_cmp (c, n, nDigits) >= 0)
    goto cleanup;

  /* Compute mP = cP^dP mod p  and  mQ = cQ^dQ mod q. (Assumes q has length at most pDigits, i.e., p > q.) */
  nn_mod (cP, c, cDigits, p, pDigits);
  nn_mod (cQ, c, cDigits, q, pDigits);
  nn_mod_exp (mP, cP, dP, pDigits, p, pDigits);

  /* mQ is later added over nDigits digits, but nn_mod_exp only writes
     pDigits of them; keep the upper digits zero. */
  nn_assign_zero (mQ, nDigits);
  nn_mod_exp (mQ, cQ, dQ, pDigits, q, pDigits);

  /* Chinese Remainder Theorem:  m = ((((mP - mQ) mod p) * qInv) mod p) * q + mQ.   */
  if (nn_cmp (mP, mQ, pDigits) >= 0)
  {
    nn_sub (t, mP, mQ, pDigits);
  }
  else
  {
    /* (mP - mQ) mod p == p - (mQ - mP) when mP < mQ. */
    nn_sub (t, mQ, mP, pDigits);
    nn_sub (t, p, t, pDigits);
  }
  nn_mod_mult (t, t, qInv, p, pDigits);
  nn_mult (t, t, q, pDigits);
  nn_add (t, t, mQ, nDigits);

  /* Cast the quotient, not the numerator, so a large bit count cannot be
     truncated before the division. */
  *output_len = (uint16) ((private_key->bits + 7) / 8);
  nn_encode (output, *output_len, t, nDigits);
  status = 0;

cleanup:
  /* Zeroise everything derived from the private key (and the result) so no
     secret material is left behind on the stack. */
  rsa_private_block_wipe (c, sizeof (c));
  rsa_private_block_wipe (cP, sizeof (cP));
  rsa_private_block_wipe (cQ, sizeof (cQ));
  rsa_private_block_wipe (dP, sizeof (dP));
  rsa_private_block_wipe (dQ, sizeof (dQ));
  rsa_private_block_wipe (mP, sizeof (mP));
  rsa_private_block_wipe (mQ, sizeof (mQ));
  rsa_private_block_wipe (n, sizeof (n));
  rsa_private_block_wipe (p, sizeof (p));
  rsa_private_block_wipe (q, sizeof (q));
  rsa_private_block_wipe (qInv, sizeof (qInv));
  rsa_private_block_wipe (t, sizeof (t));

  return status;
}