/*
   Row-wise pointwise-multiplication stage of a truncated matrix Fourier
   convolution.

   The coefficient arrays ii and jj are treated as rows of n1 entries each.
   For every row that is processed the function

      1. runs fft_radix2 on the row of ii (and of jj when ii != jj),
      2. normalises each entry with mpn_normmod_2expp1 and overwrites
         ii[k] with fft_mulmod_2expp1(ii[k], jj[k]),
      3. runs ifft_radix2 on the row of ii.

   Two groups of rows are handled:

      - entries starting at index 2*n: (trunc - 2*n)/n1 rows, selected in
        bit-reversed order via n_revbin(s, depth);
      - entries 0 .. 2*n - 1: all (2*n)/n1 rows, in natural order.

   When ii == jj the jj-side transforms and normalisations are skipped, so
   the entries are squared.

   Parameters
      ii, jj   arrays of coefficient pointers; ii receives the products
      n, w     forwarded to fft_mulmod_2expp1; each coefficient is handled
               as (n*w)/FLINT_BITS limbs
      t1, t2   forwarded to fft_radix2 / ifft_radix2 (presumably scratch
               coefficients -- confirm against those routines)
      temp     not referenced by this function
      n1       number of entries per row
      trunc    truncation length; only (trunc - 2*n)/n1 rows past index 2*n
               are processed
      tt       forwarded to fft_mulmod_2expp1 (presumably scratch space --
               confirm against that routine)
*/
void fft_mfa_truncate_sqrt2_inner(mp_limb_t ** ii, mp_limb_t ** jj,
                   mp_size_t n, mp_bitcnt_t w, mp_limb_t ** t1,
                   mp_limb_t ** t2, mp_limb_t ** temp, mp_size_t n1,
                   mp_size_t trunc, mp_limb_t * tt)
{
   const mp_size_t rows = (2*n)/n1;
   const mp_size_t extra_rows = (trunc - 2*n)/n1;
   const mp_size_t limbs = (n*w)/FLINT_BITS;
   const mp_size_t half_row = n1/2;
   const mp_bitcnt_t row_w = w*rows;
   const int distinct = (ii != jj);

   /* the region starting at index 2*n, addressed without moving ii/jj */
   mp_limb_t ** const ii_hi = ii + 2*n;
   mp_limb_t ** const jj_hi = jj + 2*n;

   mp_bitcnt_t depth = 0;
   mp_bitcnt_t depth2 = 0;
   mp_size_t s, r, k;

   /* depth = smallest exponent with 2^depth >= rows */
   while ((UWORD(1)<<depth) < rows)
      depth++;

   /* same for n1; the result is not read anywhere below */
   while ((UWORD(1)<<depth2) < n1)
      depth2++;

   /* convolutions on relevant rows */
   for (s = 0; s < extra_rows; s++)
   {
      mp_size_t row = n_revbin(s, depth);
      mp_limb_t ** a = ii_hi + row*n1;
      mp_limb_t ** b = jj_hi + row*n1;

      fft_radix2(a, half_row, row_w, t1, t2);
      if (distinct)
         fft_radix2(b, half_row, row_w, t1, t2);

      for (k = 0; k < n1; k++)
      {
         mpn_normmod_2expp1(a[k], limbs);
         if (distinct)
            mpn_normmod_2expp1(b[k], limbs);

         fft_mulmod_2expp1(a[k], a[k], b[k], n, w, tt);
      }

      ifft_radix2(a, half_row, row_w, t1, t2);
   }

   /* convolutions on rows */
   for (r = 0; r < rows; r++)
   {
      mp_limb_t ** a = ii + r*n1;
      mp_limb_t ** b = jj + r*n1;

      fft_radix2(a, half_row, row_w, t1, t2);
      if (distinct)
         fft_radix2(b, half_row, row_w, t1, t2);

      for (k = 0; k < n1; k++)
      {
         mpn_normmod_2expp1(a[k], limbs);
         if (distinct)
            mpn_normmod_2expp1(b[k], limbs);

         fft_mulmod_2expp1(a[k], a[k], b[k], n, w, tt);
      }

      ifft_radix2(a, half_row, row_w, t1, t2);
   }
}
/*
   Convolve the coefficient arrays ii and jj, leaving the result in ii.
   When ii == jj the jj-side transforms are skipped and ii is squared.

   With n = 2^depth and w = (limbs*FLINT_BITS)/n, one of two strategies is
   chosen:

      depth <= 6   trunc is rounded up to an even number, then
                   fft_truncate_sqrt2, pointwise fft_mulmod_2expp1,
                   ifft_truncate_sqrt2, and finally every one of the trunc
                   outputs is divided by 2^(depth + 2) with
                   mpn_div_2expmod_2expp1 and normalised.

      depth > 6    trunc is rounded up to a multiple of 2*2^(depth/2), then
                   the matrix Fourier routines are used:
                   fft_mfa_truncate_sqrt2_outer, fft_mfa_truncate_sqrt2_inner
                   and ifft_mfa_truncate_sqrt2_outer.  No scaling loop is
                   run here (presumably the inverse outer routine handles
                   it -- confirm against its implementation).

   Parameters
      ii, jj       arrays of coefficient pointers
      depth        log2 of n
      limbs        limbs per coefficient, passed to the normalisation and
                   division routines
      trunc        requested truncation length (rounded up as described)
      t1, t2, s1   forwarded to the transform routines
      tt           forwarded to the pointwise multiplication
*/
void fft_convolution(mp_limb_t ** ii, mp_limb_t ** jj, long depth,
                     long limbs, long trunc, mp_limb_t ** t1,
                     mp_limb_t ** t2, mp_limb_t ** s1, mp_limb_t * tt)
{
   const long n = (1L<<depth);
   const long w = (limbs*FLINT_BITS)/n;
   const long root = (1L<<(depth/2));
   const int distinct = (ii != jj);
   long k;

   if (depth > 6)
   {
      /* matrix Fourier path: trunc must be a multiple of 2*root */
      trunc = 2*root*((trunc + 2*root - 1)/(2*root));

      fft_mfa_truncate_sqrt2_outer(ii, n, w, t1, t2, s1, root, trunc);
      if (distinct)
         fft_mfa_truncate_sqrt2_outer(jj, n, w, t1, t2, s1, root, trunc);

      fft_mfa_truncate_sqrt2_inner(ii, jj, n, w, t1, t2, s1, root, trunc, tt);

      ifft_mfa_truncate_sqrt2_outer(ii, n, w, t1, t2, s1, root, trunc);

      return;
   }

   /* direct path: trunc must be even */
   trunc = 2*((trunc + 1)/2);

   fft_truncate_sqrt2(ii, n, w, t1, t2, s1, trunc);
   if (distinct)
      fft_truncate_sqrt2(jj, n, w, t1, t2, s1, trunc);

   /* pointwise products in the transformed domain */
   for (k = 0; k < trunc; k++)
   {
      mpn_normmod_2expp1(ii[k], limbs);
      if (distinct)
         mpn_normmod_2expp1(jj[k], limbs);

      fft_mulmod_2expp1(ii[k], ii[k], jj[k], n, w, tt);
   }

   ifft_truncate_sqrt2(ii, n, w, t1, t2, s1, trunc);

   /* divide each output by 2^(depth + 2) and normalise */
   for (k = 0; k < trunc; k++)
   {
      mpn_div_2expmod_2expp1(ii[k], ii[k], limbs, depth + 2);
      mpn_normmod_2expp1(ii[k], limbs);
   }
}
int main(void) { mp_bitcnt_t depth, w; int iters; flint_rand_t state; printf("mulmod_2expp1...."); fflush(stdout); flint_randinit(state); _flint_rand_init_gmp(state); for (iters = 0; iters < 100; iters++) { for (depth = 6; depth <= 18; depth++) { for (w = 1; w <= 2; w++) { mp_size_t n = (1UL<<depth); mp_bitcnt_t bits = n*w; mp_size_t int_limbs = bits/FLINT_BITS; mp_size_t j; mp_limb_t c, * i1, * i2, * r1, * r2, * tt; i1 = flint_malloc(6*(int_limbs+1)*sizeof(mp_limb_t)); i2 = i1 + int_limbs + 1; r1 = i2 + int_limbs + 1; r2 = r1 + int_limbs + 1; tt = r2 + int_limbs + 1; random_fermat(i1, state, int_limbs); random_fermat(i2, state, int_limbs); mpn_normmod_2expp1(i1, int_limbs); mpn_normmod_2expp1(i2, int_limbs); fft_mulmod_2expp1(r2, i1, i2, n, w, tt); c = i1[int_limbs] + 2*i2[int_limbs]; c = mpn_mulmod_2expp1(r1, i1, i2, c, int_limbs*FLINT_BITS, tt); for (j = 0; j < int_limbs; j++) { if (r1[j] != r2[j]) { printf("error in limb %ld, %lx != %lx\n", j, r1[j], r2[j]); abort(); } } if (c != r2[int_limbs]) { printf("error in limb %ld, %lx != %lx\n", j, c, r2[j]); abort(); } flint_free(i1); } } } /* test squaring */ for (iters = 0; iters < 100; iters++) { for (depth = 6; depth <= 18; depth++) { for (w = 1; w <= 2; w++) { mp_size_t n = (1UL<<depth); mp_bitcnt_t bits = n*w; mp_size_t int_limbs = bits/FLINT_BITS; mp_size_t j; mp_limb_t c, * i1, * r1, * r2, * tt; i1 = flint_malloc(5*(int_limbs+1)*sizeof(mp_limb_t)); r1 = i1 + int_limbs + 1; r2 = r1 + int_limbs + 1; tt = r2 + int_limbs + 1; random_fermat(i1, state, int_limbs); mpn_normmod_2expp1(i1, int_limbs); fft_mulmod_2expp1(r2, i1, i1, n, w, tt); c = i1[int_limbs] + 2*i1[int_limbs]; c = mpn_mulmod_2expp1(r1, i1, i1, c, int_limbs*FLINT_BITS, tt); for (j = 0; j < int_limbs; j++) { if (r1[j] != r2[j]) { printf("error in limb %ld, %lx != %lx\n", j, r1[j], r2[j]); abort(); } } if (c != r2[int_limbs]) { printf("error in limb %ld, %lx != %lx\n", j, c, r2[j]); abort(); } flint_free(i1); } } } flint_randclear(state); 
printf("PASS\n"); return 0; }