/*
 * prod = b1 * b2.
 *
 * prod is allowed to be the same mpint as b1 or b2.  In that case the
 * product is built in a temporary from mpnew, copied into prod with
 * mpassign, and the temporary is released with mpfree.
 */
void
mpmul(mpint *b1, mpint *b2, mpint *prod)
{
	mpint *out;

	/* never let mpvecmul write into storage it is still reading from */
	out = prod;
	if(prod == b1 || prod == b2)
		out = mpnew(0);

	/*
	 * NOTE(review): top is reset before mpbits, presumably so that
	 * mpbits clears every digit before mpvecmul fills them in --
	 * confirm against mpbits.
	 */
	out->top = 0;
	mpbits(out, (b1->top+b2->top+1)*Dbits);
	mpvecmul(b1->p, b1->top, b2->p, b2->top, out->p);
	out->top = b1->top+b2->top+1;
	out->sign = b1->sign*b2->sign;
	mpnorm(out);

	/* result was produced in place: nothing to copy back */
	if(out == prod)
		return;

	mpassign(out, prod);
	mpfree(out);
}
/*
 * Compute n! and return it in an mpint allocated here with mpnew.
 *
 * Consecutive factors are first multiplied together in the single
 * digit p for as long as the product fits in one digit.  Each time the
 * next factor would overflow p, the overflow count cnt is bumped and
 * the digit-sized partial product is either pushed on stk (odd cnt) or
 * multiplied into the top stack entry (even cnt); in the even case the
 * top two entries are then merged once for every further power of two
 * that divides cnt.  After the loop the leftover digit and all
 * remaining stack entries are multiplied into the result.
 *
 * Calls abort() if more than nelem(stk) stack entries would be needed.
 */
mpint*
mpfactorial(ulong n)
{
	int i;
	ulong k;
	unsigned cnt;
	int max, mmax;
	mpdigit p, pp[2];
	mpint *r, *s, *stk[31];

	cnt = 0;		/* number of times p has overflowed */
	max = mmax = -1;	/* top of stk in use / highest slot ever allocated */
	p = 1;
	r = mpnew(0);
	for(k=2; k<=n; k++) {
		/* pp receives p*k; pp[1] is the part that does not fit in one digit */
		pp[0] = 0;
		pp[1] = 0;
		mpvecdigmuladd(&p, 1, (mpdigit)k, pp);
		if(pp[1] == 0)	/* !overflow */
			p = pp[0];
		else {
			cnt++;
			if((cnt & 1) == 0) {
				/* fold p into the top entry: stk[max] *= p */
				s = stk[max];
				mpbits(r, Dbits*(s->top+1+1));
				memset(r->p, 0, Dbytes*(s->top+1+1));
				mpvecmul(s->p, s->top, &p, 1, r->p);
				r->sign = 1;
				r->top = s->top+1+1;		/* XXX: norm */
				mpassign(r, s);

				/* merge the top two entries once per extra factor of two in cnt */
				for(i=4; (cnt & (i-1)) == 0; i=i<<1) {
					mpmul(stk[max], stk[max-1], r);
					mpassign(r, stk[max-1]);
					max--;
				}
			} else {
				/* push p as a new one-digit entry, reusing an old slot if there is one */
				max++;
				if(max > mmax) {
					mmax++;
					/* valid indices are 0..nelem(stk)-1, so max == nelem(stk) is already out of range */
					if(max >= nelem(stk))
						abort();
					stk[max] = mpnew(Dbits);
				}
				stk[max]->top  = 1;
				stk[max]->p[0] = p;
			}
			/* k did not fit into the old p; it starts the next partial product */
			p = (mpdigit)k;
		}
	}

	if(max < 0) {
		/* p never overflowed: the whole result is the single digit p */
		mpbits(r, Dbits);
		r->top = 1;
		r->sign = 1;
		r->p[0] = p;
	} else {
		/* r = (top stack entry) * p */
		s = stk[max--];
		mpbits(r, Dbits*(s->top+1+1));
		memset(r->p, 0, Dbytes*(s->top+1+1));
		mpvecmul(s->p, s->top, &p, 1, r->p);
		r->sign = 1;
		r->top = s->top+1+1;		/* XXX: norm */
	}

	/* multiply in whatever is still on the stack */
	while(max >= 0)
		mpmul(r, stk[max--], r);

	/* release every slot that was ever allocated, not just the live ones */
	for(max=mmax; max>=0; max--)
		mpfree(stk[max]);

	mpnorm(r);
	return r;
}
// karatsuba like (see knuth pg 258) // prereq: p is already zeroed static void mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) { mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod; int u0len, u1len, v0len, v1len, reslen; int sign, n; // divide each piece in half n = alen/2; if(alen&1) n++; u0len = n; u1len = alen-n; if(blen > n){ v0len = n; v1len = blen-n; } else { v0len = blen; v1len = 0; } u0 = a; u1 = a + u0len; v0 = b; v1 = b + v0len; // room for the partial products t = mallocz(Dbytes*5*(2*n+1), 1); if(t == nil) sysfatal("mpkaratsuba: %r"); u0v0 = t; u1v1 = t + (2*n+1); diffprod = t + 2*(2*n+1); res = t + 3*(2*n+1); reslen = 4*n+1; // t[0] = (u1-u0) sign = 1; if(mpveccmp(u1, u1len, u0, u0len) < 0){ sign = -1; mpvecsub(u0, u0len, u1, u1len, u0v0); } else mpvecsub(u1, u1len, u0, u1len, u0v0); // t[1] = (v0-v1) if(mpveccmp(v0, v0len, v1, v1len) < 0){ sign *= -1; mpvecsub(v1, v1len, v0, v1len, u1v1); } else mpvecsub(v0, v0len, v1, v1len, u1v1); // t[4:5] = (u1-u0)*(v0-v1) mpvecmul(u0v0, u0len, u1v1, v0len, diffprod); // t[0:1] = u1*v1 memset(t, 0, 2*(2*n+1)*Dbytes); if(v1len > 0) mpvecmul(u1, u1len, v1, v1len, u1v1); // t[2:3] = u0v0 mpvecmul(u0, u0len, v0, v0len, u0v0); // res = u0*v0<<n + u0*v0 mpvecadd(res, reslen, u0v0, u0len+v0len, res); mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n); // res += u1*v1<<n + u1*v1<<2*n if(v1len > 0){ mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n); mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n); } // res += (u1-u0)*(v0-v1)<<n if(sign < 0) mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n); else mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n); memmove(p, res, (alen+blen)*Dbytes); free(t); }