/*
 * mpmagadd: sum = abs(b1) + abs(b2), i.e., add the magnitudes.
 * The result is always non-negative (sign 1).
 */
void
mpmagadd(mpint *b1, mpint *b2, mpint *sum)
{
	int m, n;
	mpint *t;

	/* get the sizes right: b1 is the operand with more digits */
	if(b2->top > b1->top){
		t = b1;
		b1 = b2;
		b2 = t;
	}
	n = b1->top;
	m = b2->top;
	if(n == 0){
		/* both operands are zero */
		mpassign(mpzero, sum);
		return;
	}
	if(m == 0){
		mpassign(b1, sum);
		/* b1 may be negative; the result is its magnitude */
		sum->sign = 1;
		return;
	}
	/* one extra digit for the final carry */
	mpbits(sum, (n+1)*Dbits);
	sum->top = n+1;

	mpvecadd(b1->p, n, b2->p, m, sum->p);
	sum->sign = 1;

	mpnorm(sum);
}
/*
 * strtoec: parse an encoded point on the curve dom from the string s.
 *
 * The leading value returned by octet() selects the encoding:
 *	0	the point at infinity
 *	2, 3	x only; y is recovered from the curve equation and the
 *		low bit of y must equal 0 (for 2) or 1 (for 3)
 *	4	x followed by y
 * (presumably the SEC 1 point-encoding prefix — confirm against octet()).
 *
 * If ret is nil a new point is allocated, otherwise ret is overwritten.
 * If rptr is non-nil, *rptr is set to where parsing stopped, on success
 * and on failure.  Returns the point, or nil on a malformed encoding or
 * a point that fails ecverify (a point allocated here is freed then).
 */
ECpoint*
strtoec(ECdomain *dom, char *s, char **rptr, ECpoint *ret)
{
	int allocd, o;
	mpint *r;

	allocd = 0;
	if(ret == nil){
		allocd = 1;
		ret = mallocz(sizeof(*ret), 1);
		if(ret == nil)
			return nil;
		ret->x = mpnew(0);
		ret->y = mpnew(0);
	}
	/* a caller-supplied point may still be marked as infinity */
	ret->inf = 0;
	o = 0;
	switch(octet(&s)){
	case 0:
		ret->inf = 1;
		break;
	case 3:
		o = 1;
		/* fall through */
	case 2:
		if(halfpt(dom, s, &s, ret->x) == nil)
			goto err;
		/* r = x^3 + a*x + b */
		r = mpnew(0);
		mpmul(ret->x, ret->x, r);
		mpadd(r, dom->a, r);
		mpmul(r, ret->x, r);
		mpadd(r, dom->b, r);
		if(!mpsqrt(r, dom->p, r)){
			mpfree(r);
			goto err;
		}
		if(r->top == 0){
			/* y == 0 is even; an odd-parity encoding of it is invalid */
			if(o){
				mpfree(r);
				goto err;
			}
		}else if((r->p[0] & 1) != o)
			mpsub(dom->p, r, r);	/* pick the root with the requested parity */
		mpassign(r, ret->y);
		mpfree(r);
		if(!ecverify(dom, ret))
			goto err;
		break;
	case 4:
		if(halfpt(dom, s, &s, ret->x) == nil)
			goto err;
		if(halfpt(dom, s, &s, ret->y) == nil)
			goto err;
		if(!ecverify(dom, ret))
			goto err;
		break;
	default:
		goto err;
	}
	if(rptr)
		*rptr = s;
	return ret;

err:
	if(rptr)
		*rptr = s;
	if(allocd){
		mpfree(ret->x);
		mpfree(ret->y);
		free(ret);
	}
	return nil;
}
/*
 * vtomp: convert the signed vlong v to an mpint.
 * If b is nil a new mpint is allocated, otherwise b is overwritten.
 * Returns the resulting mpint.
 *
 * This code assumes that a vlong is an integral number of
 * mpdigits long (VLDIGITS digits).
 */
mpint*
vtomp(vlong v, mpint *b)
{
	int s;
	uvlong uv;

	/* mpnew and mpbits take a size in bits, not bytes */
	if(b == nil)
		b = mpnew(VLDIGITS*Dbits);
	else
		mpbits(b, VLDIGITS*Dbits);
	mpassign(mpzero, b);
	if(v == 0)
		return b;
	if(v < 0){
		b->sign = -1;
		/* negate in unsigned arithmetic: -v overflows for the minimum vlong */
		uv = -(uvlong)v;
	} else
		uv = v;
	for(s = 0; s < VLDIGITS && uv != 0; s++){
		b->p[s] = uv;
		/*
		 * Shift in two halves: a single shift by the full digit width
		 * is undefined when mpdigit is as wide as uvlong.
		 */
		uv >>= Dbits/2;
		uv >>= Dbits/2;
	}
	b->top = s;
	return b;
}
/*
 * uitomp: store the unsigned int i in b and return b.
 * A fresh mpint is allocated when b is nil.  The value always
 * occupies at most one digit; zero leaves top at 0.
 */
mpint*
uitomp(uint i, mpint *b)
{
	if(b == nil)
		b = mpnew(0);
	mpassign(mpzero, b);
	b->p[0] = i;
	if(i != 0)
		b->top = 1;
	return b;
}
/*
 * uvtomp: convert the unsigned 64-bit value v to an mpint.
 * If b is nil a new mpint is allocated, otherwise b is overwritten.
 * Returns the resulting mpint.
 *
 * This code assumes that a vlong is an integral number of
 * mpdigits long (VLDIGITS digits).
 */
mpint*
uvtomp(uint64_t v, mpint *b)
{
	int s;

	/* mpnew and mpbits take a size in bits, not bytes */
	if(b == nil)
		b = mpnew(VLDIGITS*Dbits);
	else
		mpbits(b, VLDIGITS*Dbits);
	mpassign(mpzero, b);
	if(v == 0)
		return b;
	for(s = 0; s < VLDIGITS && v != 0; s++){
		b->p[s] = v;
		/*
		 * Shift in two halves: a single shift by the full digit width
		 * is undefined when mpdigit is 64 bits wide.
		 */
		v >>= Dbits/2;
		v >>= Dbits/2;
	}
	b->top = s;
	return b;
}
/*
 * crtout: Garner's algorithm — rebuild x in linear form from the
 * residues res->r[0..crt->n-1].
 *
 * Each step adds ((r[i] - x) * c[i] mod m[i]) * p[i-1] to x.
 * NOTE(review): c[i] is presumably the inverse of p[i-1] modulo m[i],
 * and p[i-1] the product of the earlier moduli, as set up by crtpre —
 * confirm there.
 */
void
crtout(CRTpre *crt, CRTres *res, mpint *x)
{
	mpint *step;
	int i;

	step = mpnew(0);
	mpassign(res->r[0], x);
	i = 1;
	while(i < crt->n){
		mpsub(res->r[i], x, step);
		mpmul(step, crt->c[i], step);
		mpmod(step, crt->m[i], step);
		mpmul(step, crt->p[i-1], step);
		mpadd(x, step, x);
		i++;
	}
	mpfree(step);
}
/*
 * mpmul: prod = b1 * b2.
 * When prod is the same object as b1 or b2, the product is built in
 * a temporary and copied into prod at the end.
 */
void
mpmul(mpint *b1, mpint *b2, mpint *prod)
{
	mpint *dst;
	int ndigits;

	dst = (prod == b1 || prod == b2) ? mpnew(0) : prod;
	ndigits = b1->top + b2->top + 1;

	/*
	 * Reset top before growing.  This is presumably so mpbits clears the
	 * old digits, including the extra high one mpvecmul does not write —
	 * confirm against mpbits.
	 */
	dst->top = 0;
	mpbits(dst, ndigits*Dbits);
	mpvecmul(b1->p, b1->top, b2->p, b2->top, dst->p);
	dst->top = ndigits;
	dst->sign = b1->sign*b2->sign;
	mpnorm(dst);

	if(dst != prod){
		mpassign(dst, prod);
		mpfree(dst);
	}
}
/*
 * mpfactorial: return a newly allocated mpint holding n!.
 *
 * Consecutive factors are multiplied into the single digit p until the
 * next factor would overflow it.  Each full digit is then folded into a
 * small stack of partial products, driven by the counter cnt:
 *  - on odd cnt, the digit is pushed;
 *  - on even cnt, it multiplies the top entry;
 *  - for each further trailing zero bit of cnt, the top two entries
 *    are merged.
 * This keeps the big multiplications between operands of similar size.
 */
mpint*
mpfactorial(ulong n)
{
	int i;
	ulong k;
	unsigned cnt;
	int max, mmax;		/* current top of stk, highest slot ever allocated */
	mpdigit p, pp[2];
	mpint *r, *s, *stk[31];

	cnt = 0;
	max = mmax = -1;
	p = 1;
	r = mpnew(0);
	for(k=2; k<=n; k++){
		/* pp = p * k as a two-digit value (mpvecdigmuladd adds into pp) */
		pp[0] = 0;
		pp[1] = 0;
		mpvecdigmuladd(&p, 1, (mpdigit)k, pp);
		if(pp[1] == 0)	/* !overflow */
			p = pp[0];
		else{
			cnt++;
			if((cnt & 1) == 0){
				/* stk[max] *= p */
				s = stk[max];
				mpbits(r, Dbits*(s->top+1+1));
				memset(r->p, 0, Dbytes*(s->top+1+1));
				mpvecmul(s->p, s->top, &p, 1, r->p);
				r->sign = 1;
				r->top = s->top+1+1;
				mpnorm(r);
				mpassign(r, s);
				/* merge the top two entries once per further trailing zero bit of cnt */
				for(i=4; (cnt & (i-1)) == 0; i=i<<1){
					mpmul(stk[max], stk[max-1], r);
					mpassign(r, stk[max-1]);
					max--;
				}
			}else{
				max++;
				if(max > mmax){
					mmax++;
					/* valid indices of stk are 0 .. nelem(stk)-1 */
					if(max >= nelem(stk))
						abort();
					stk[max] = mpnew(Dbits);
				}
				stk[max]->top = 1;
				stk[max]->p[0] = p;
			}
			p = (mpdigit)k;
		}
	}
	if(max < 0){
		/* no digit overflow happened: the whole product is in p */
		mpbits(r, Dbits);
		r->top = 1;
		r->sign = 1;
		r->p[0] = p;
	}else{
		/* r = stk[max] * p */
		s = stk[max--];
		mpbits(r, Dbits*(s->top+1+1));
		memset(r->p, 0, Dbytes*(s->top+1+1));
		mpvecmul(s->p, s->top, &p, 1, r->p);
		r->sign = 1;
		r->top = s->top+1+1;
		mpnorm(r);
	}
	while(max >= 0)
		mpmul(r, stk[max--], r);
	for(max=mmax; max>=0; max--)
		mpfree(stk[max]);
	mpnorm(r);
	return r;
}
/*
 * mpsqrt: set r to a square root of n modulo p using Cipolla's algorithm.
 * Returns 0 (r untouched) when mpleg reports n as a non-residue, else 1.
 * r may be the same object as n.  p is assumed to be an odd prime.
 *
 * A random a is chosen with t = a^2 - n a non-residue.  The code then
 * computes (a + w)^((p+1)/2) in F_p[w]/(w^2 - t):
 *  - x = xp + xq*w is the running square;
 *  - y = yp + yq*w is the accumulator.
 * The w-part of the result must be zero.
 */
static int
mpsqrt(mpint *n, mpint *p, mpint *r)
{
	mpint *a, *t, *s, *xp, *xq, *yp, *yq, *zp, *zq, *N;

	if(mpleg(n, p) == -1)
		return 0;

	/*
	 * n ≡ 0 (mod p) has the single root 0.  It must be handled first:
	 * then a^2 - n = a^2 is always a square, and the search below for a
	 * non-residue would never terminate.
	 */
	t = mpnew(0);
	mpmod(n, p, t);
	if(mpcmp(t, mpzero) == 0){
		mpfree(t);
		mpassign(mpzero, r);
		return 1;
	}

	a = mpnew(0);
	s = mpnew(0);
	N = mpnew(0);
	xp = mpnew(0);
	xq = mpnew(0);
	yp = mpnew(0);
	yq = mpnew(0);
	zp = mpnew(0);
	zq = mpnew(0);
	for(;;){
		/* random a with 0 < a < p */
		for(;;){
			mprand(mpsignif(p), genrandom, a);
			if(mpcmp(a, mpzero) > 0 && mpcmp(a, p) < 0)
				break;
		}
		mpmul(a, a, t);
		mpsub(t, n, t);
		mpmod(t, p, t);
		if(mpleg(t, p) == -1)
			break;
	}
	/* t now holds (a^2 - n) mod p, i.e. w^2 */

	/* N = (p+1)/2 */
	mpadd(p, mpone, N);
	mpright(N, 1, N);
	/* x = a + w, y = 1 */
	mpassign(a, xp);
	uitomp(1, xq);
	uitomp(1, yp);
	uitomp(0, yq);
	while(mpcmp(N, mpzero) != 0){
		if(N->p[0] & 1){
			/* y = x * y */
			mpmul(xp, yp, zp);
			mpmul(xq, yq, zq);
			mpmul(zq, t, zq);
			mpadd(zp, zq, zp);
			mpmod(zp, p, zp);
			mpmul(xp, yq, zq);
			mpmul(xq, yp, s);
			mpadd(zq, s, zq);
			mpmod(zq, p, yq);
			mpassign(zp, yp);
		}
		/* x = x * x */
		mpmul(xp, xp, zp);
		mpmul(xq, xq, zq);
		mpmul(zq, t, zq);
		mpadd(zp, zq, zp);
		mpmod(zp, p, zp);
		mpmul(xp, xq, zq);
		mpadd(zq, zq, zq);
		mpmod(zq, p, xq);
		mpassign(zp, xp);
		mpright(N, 1, N);
	}
	/* the w-component of (a + w)^((p+1)/2) must vanish */
	if(mpcmp(yq, mpzero) != 0)
		abort();
	mpassign(yp, r);
	mpfree(a);
	mpfree(t);
	mpfree(s);
	mpfree(N);
	mpfree(xp);
	mpfree(xq);
	mpfree(yp);
	mpfree(yq);
	mpfree(zp);
	mpfree(zq);
	return 1;
}
/*
 * ecadd: s = a + b on the curve y^2 = x^3 + a*x + b over F_p described
 * by dom.  s may be the same object as a or b.
 *
 * The result is the point at infinity in two cases:
 *  - a and b have the same x but different y;
 *  - a point with y = 0 is doubled.
 */
void
ecadd(ECdomain *dom, ECpoint *a, ECpoint *b, ECpoint *s)
{
	mpint *l, *k, *sx, *sy;

	if(a->inf && b->inf){
		s->inf = 1;
		return;
	}
	if(a->inf){
		ecassign(dom, b, s);
		return;
	}
	if(b->inf){
		ecassign(dom, a, s);
		return;
	}
	if(mpcmp(a->x, b->x) == 0 && (mpcmp(a->y, mpzero) == 0 || mpcmp(a->y, b->y) != 0)){
		s->inf = 1;
		return;
	}
	l = mpnew(0);
	k = mpnew(0);
	sx = mpnew(0);
	sy = mpnew(0);
	if(mpcmp(a->x, b->x) == 0 && mpcmp(a->y, b->y) == 0){
		/* doubling: l = (3*x^2 + a) / (2*y) */
		mpadd(mpone, mptwo, k);
		mpmul(a->x, a->x, l);
		mpmul(l, k, l);
		mpadd(l, dom->a, l);
		mpleft(a->y, 1, k);
		mpmod(k, dom->p, k);
		mpinvert(k, dom->p, k);
		mpmul(k, l, l);
		mpmod(l, dom->p, l);

		/* sx = l^2 - 2*x */
		mpleft(a->x, 1, k);
		mpmul(l, l, sx);
		mpsub(sx, k, sx);
		mpmod(sx, dom->p, sx);
	}else{
		/* addition: l = (b.y - a.y) / (b.x - a.x) */
		mpsub(b->y, a->y, l);
		mpmod(l, dom->p, l);
		mpsub(b->x, a->x, k);
		mpmod(k, dom->p, k);
		mpinvert(k, dom->p, k);
		mpmul(k, l, l);
		mpmod(l, dom->p, l);

		/* sx = l^2 - a.x - b.x */
		mpmul(l, l, sx);
		mpsub(sx, a->x, sx);
		mpsub(sx, b->x, sx);
		mpmod(sx, dom->p, sx);
	}
	/* sy = l*(a.x - sx) - a.y, the same formula in both cases */
	mpsub(a->x, sx, sy);
	mpmul(l, sy, sy);
	mpsub(sy, a->y, sy);
	mpmod(sy, dom->p, sy);

	mpassign(sx, s->x);
	mpassign(sy, s->y);
	/* s may previously have held the point at infinity */
	s->inf = 0;
	mpfree(sx);
	mpfree(sy);
	mpfree(l);
	mpfree(k);
}
// extended binary gcd // // For a anv b it solves, v = gcd(a,b) and finds x and y s.t. // ax + by = v // // Handbook of Applied Cryptography, Menezes et al, 1997, pg 608. void mpextendedgcd(mpint *a, mpint *b, mpint *v, mpint *x, mpint *y) { mpint *u, *A, *B, *C, *D; int g; if(a->top == 0){ mpassign(b, v); mpassign(mpone, y); mpassign(mpzero, x); return; } if(b->top == 0){ mpassign(a, v); mpassign(mpone, x); mpassign(mpzero, y); return; } g = 0; a = mpcopy(a); b = mpcopy(b); while(iseven(a) && iseven(b)){ mpright(a, 1, a); mpright(b, 1, b); g++; } u = mpcopy(a); mpassign(b, v); A = mpcopy(mpone); B = mpcopy(mpzero); C = mpcopy(mpzero); D = mpcopy(mpone); for(;;) { // print("%B %B %B %B %B %B\n", u, v, A, B, C, D); while(iseven(u)){ mpright(u, 1, u); if(!iseven(A) || !iseven(B)) { mpadd(A, b, A); mpsub(B, a, B); } mpright(A, 1, A); mpright(B, 1, B); } // print("%B %B %B %B %B %B\n", u, v, A, B, C, D); while(iseven(v)){ mpright(v, 1, v); if(!iseven(C) || !iseven(D)) { mpadd(C, b, C); mpsub(D, a, D); } mpright(C, 1, C); mpright(D, 1, D); } // print("%B %B %B %B %B %B\n", u, v, A, B, C, D); if(mpcmp(u, v) >= 0){ mpsub(u, v, u); mpsub(A, C, A); mpsub(B, D, B); } else { mpsub(v, u, v); mpsub(C, A, C); mpsub(D, B, D); } if(u->top == 0) break; } mpassign(C, x); mpassign(D, y); mpleft(v, g, v); mpfree(A); mpfree(B); mpfree(C); mpfree(D); mpfree(u); mpfree(a); mpfree(b); }
/*
 * mpexp: res = b^e, reduced modulo m when m is non-nil.
 * res may be the same object as b, e or m.  A negative exponent is a
 * fatal error.
 *
 * The exponent is scanned left to right (square and multiply).  t[0]
 * and t[1] alternate as source and destination, with j indexing the
 * current value.
 */
void
mpexp(mpint *b, mpint *e, mpint *m, mpint *res)
{
	mpint *t[2];
	int tofree;
	mpdigit d, bit;
	int i, j;

	/* b^0 = 1; the bit scan below would otherwise read e->p[-1] */
	if(e->top == 0){
		if(m != nil)
			mpmod(mpone, m, res);
		else
			mpassign(mpone, res);
		return;
	}
	if(e->sign < 0)
		sysfatal("mpexp: negative exponent");

	t[0] = mpcopy(b);
	t[1] = res;

	/* res is used as scratch, so copy any operand it aliases */
	tofree = 0;
	if(res == b){
		b = mpcopy(b);
		tofree |= Freeb;
	}
	if(res == e){
		e = mpcopy(e);
		tofree |= Freee;
	}
	if(res == m){
		m = mpcopy(m);
		tofree |= Freem;
	}

	// skip first bit: t[0] already holds b for the leading 1 bit of e
	i = e->top-1;
	d = e->p[i];
	for(bit = mpdighi; (bit & d) == 0; bit >>= 1)
		;
	bit >>= 1;

	j = 0;
	for(;;){
		for(; bit != 0; bit >>= 1){
			mpmul(t[j], t[j], t[j^1]);
			if(bit & d)
				mpmul(t[j^1], b, t[j]);
			else
				j ^= 1;
			/* keep intermediates from growing beyond the modulus */
			if(m != nil && t[j]->top > m->top){
				mpmod(t[j], m, t[j^1]);
				j ^= 1;
			}
		}
		if(--i < 0)
			break;
		bit = mpdighi;
		d = e->p[i];
	}
	if(m != nil){
		mpmod(t[j], m, t[j^1]);
		j ^= 1;
	}
	/* the result may have ended up in the copy rather than in res */
	if(t[j] == res){
		mpfree(t[j^1]);
	} else {
		mpassign(t[j], res);
		mpfree(t[j]);
	}

	if(tofree){
		if(tofree & Freeb)
			mpfree(b);
		if(tofree & Freee)
			mpfree(e);
		if(tofree & Freem)
			mpfree(m);
	}
}
/*
 * mpeuclid: extended Euclidean algorithm for non-negative a and b.
 * Sets d = gcd(a, b) and x, y such that a*x + b*y = d.
 * Calls sysfatal if either argument is negative.
 */
void
mpeuclid(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y)
{
	mpint *swap, *q, *rem;
	mpint *xnext, *xcur, *xprev;	/* coefficient of a: scratch, current, previous */
	mpint *ynext, *ycur, *yprev;	/* coefficient of b: scratch, current, previous */

	if(a->sign<0 || b->sign<0)
		sysfatal("mpeuclid: negative arg");

	/* arrange a >= b, swapping the matching outputs as well */
	if(mpcmp(a, b) < 0){
		swap = a;
		a = b;
		b = swap;
		swap = x;
		x = y;
		y = swap;
	}

	if(b->top == 0){
		mpassign(a, d);
		mpassign(mpone, x);
		mpassign(mpzero, y);
		return;
	}

	a = mpcopy(a);
	b = mpcopy(b);
	xnext = mpnew(0);
	xcur = mpcopy(mpzero);
	xprev = mpcopy(mpone);
	ynext = mpnew(0);
	ycur = mpcopy(mpone);
	yprev = mpcopy(mpzero);
	q = mpnew(0);
	rem = mpnew(0);

	for(;;){
		if(b->top == 0 || b->sign <= 0)
			break;

		/* a = q*b + rem */
		mpdiv(a, b, q, rem);

		/* next coefficients: prev - q*cur */
		mpmul(q, xcur, xnext);
		mpsub(xprev, xnext, xnext);
		mpmul(q, ycur, ynext);
		mpsub(yprev, ynext, ynext);

		/* shift everything down one step, recycling the oldest as scratch */
		swap = a;
		a = b;
		b = rem;
		rem = swap;

		swap = xprev;
		xprev = xcur;
		xcur = xnext;
		xnext = swap;

		swap = yprev;
		yprev = ycur;
		ycur = ynext;
		ynext = swap;
	}

	mpassign(a, d);
	mpassign(xprev, x);
	mpassign(yprev, y);

	mpfree(xnext);
	mpfree(xcur);
	mpfree(xprev);
	mpfree(ynext);
	mpfree(ycur);
	mpfree(yprev);
	mpfree(q);
	mpfree(rem);
	mpfree(a);
	mpfree(b);
}
/*
 * mpdiv: dividend = quotient*divisor + remainder (Knuth, Algorithm D).
 * quotient and/or remainder may be nil when not wanted.
 * The quotient is negative when the operand signs differ.  A non-zero
 * remainder takes the dividend's sign; a zero result is always positive.
 * Division by zero is a fatal error.
 */
void
mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder)
{
	int j, s, vn, sign, qsign, rsign;
	mpdigit qd, *up, *vp, *qp;
	mpint *u, *v, *t;

	// divide by zero
	if(divisor->top == 0)
		sysfatal("mpdiv: divide by zero");

	// compute the result signs up front so they never depend on
	// outputs that may alias the operands
	qsign = (dividend->sign == divisor->sign) ? 1 : -1;
	rsign = dividend->sign;

	// quick check
	if(mpmagcmp(dividend, divisor) < 0){
		if(remainder != nil)
			mpassign(dividend, remainder);
		if(quotient != nil)
			mpassign(mpzero, quotient);
		return;
	}

	// D1: shift until divisor, v, has hi bit set (needed to make trial
	// divisor accurate)
	qd = divisor->p[divisor->top-1];
	for(s = 0; (qd & mpdighi) == 0; s++)
		qd <<= 1;
	u = mpnew((dividend->top+2)*Dbits + s);
	if(s == 0 && divisor != quotient && divisor != remainder) {
		mpassign(dividend, u);
		v = divisor;
	} else {
		mpleft(dividend, s, u);
		v = mpnew(divisor->top*Dbits);
		mpleft(divisor, s, v);
	}
	up = u->p+u->top-1;
	vp = v->p+v->top-1;
	vn = v->top;

	// D1a: make sure high digit of dividend is less than high digit of divisor
	if(*up >= *vp){
		*++up = 0;
		u->top++;
	}

	// storage for multiplies
	t = mpnew(4*Dbits);

	qp = nil;
	if(quotient != nil){
		mpbits(quotient, (u->top - v->top)*Dbits);
		quotient->top = u->top - v->top;
		qp = quotient->p+quotient->top-1;
	}

	// D2, D7: loop on length of dividend
	for(j = u->top; j > vn; j--){

		// D3: calculate trial divisor
		mpdigdiv(up-1, *vp, &qd);

		// D3a: rule out trial divisors 2 greater than real divisor
		if(vn > 1) for(;;){
			memset(t->p, 0, 3*Dbytes);	// mpvecdigmuladd adds to what's there
			mpvecdigmuladd(vp-1, 2, qd, t->p);
			if(mpveccmp(t->p, 3, up-2, 3) > 0)
				qd--;
			else
				break;
		}

		// D4: u -= v*qd << j*Dbits
		sign = mpvecdigmulsub(v->p, vn, qd, up-vn);
		if(sign < 0){

			// D6: trial divisor was too high, add back borrowed
			// value and decrease divisor
			mpvecadd(up-vn, vn+1, v->p, vn, up-vn);
			qd--;
		}

		// D5: save quotient digit
		if(qp != nil)
			*qp-- = qd;

		// push top of u down one
		u->top--;
		*up-- = 0;
	}
	if(qp != nil){
		mpnorm(quotient);
		// set the sign both ways: quotient may carry a stale sign
		quotient->sign = (quotient->top != 0) ? qsign : 1;
	}

	if(remainder != nil){
		mpright(u, s, remainder);	// u is the remainder shifted
		mpnorm(remainder);
		// never produce a negative zero remainder
		remainder->sign = (remainder->top != 0) ? rsign : 1;
	}

	mpfree(t);
	mpfree(u);
	if(v != divisor)
		mpfree(v);
}