int main(void) { int i; mp_size_t j; flint_rand_t state; printf("split/combine_bits...."); fflush(stdout); flint_randinit(state); _flint_rand_init_gmp(state); for (i = 0; i < 10000; i++) { mp_size_t total_limbs = n_randint(state, 1000) + 1; mp_limb_t * in = flint_malloc(total_limbs*sizeof(mp_limb_t)); mp_limb_t * out = flint_calloc(total_limbs, sizeof(mp_limb_t)); mp_bitcnt_t bits = n_randint(state, 200) + 1; mp_size_t limbs = (2*bits - 1)/FLINT_BITS + 1; long length = (total_limbs*FLINT_BITS - 1)/bits + 1; mp_limb_t ** poly; poly = flint_malloc(length*sizeof(mp_limb_t *)); for (j = 0; j < length; j++) poly[j] = flint_malloc((limbs + 1)*sizeof(mp_limb_t)); mpn_urandomb(in, state->gmp_state, total_limbs*FLINT_BITS); fft_split_bits(poly, in, total_limbs, bits, limbs); fft_combine_bits(out, poly, length, bits, limbs, total_limbs); for (j = 0; j < total_limbs; j++) { if (in[j] != out[j]) { printf("FAIL:\n"); printf("Error in limb %ld, %lu != %lu\n", j, in[j], out[j]); abort(); } } flint_free(in); flint_free(out); for (j = 0; j < length; j++) flint_free(poly[j]); flint_free(poly); } flint_randclear(state); printf("PASS\n"); return 0; }
void mul_mfa_truncate_sqrt2(mp_ptr r1, mp_srcptr i1, mp_size_t n1, mp_srcptr i2, mp_size_t n2, mp_bitcnt_t depth, mp_bitcnt_t w) { mp_size_t n = (UWORD(1)<<depth); mp_bitcnt_t bits1 = (n*w - (depth+1))/2; mp_size_t sqrt = (UWORD(1)<<(depth/2)); mp_size_t r_limbs = n1 + n2; mp_size_t limbs = (n*w)/FLINT_BITS; mp_size_t size = limbs + 1; mp_size_t j1 = (n1*FLINT_BITS - 1)/bits1 + 1; mp_size_t j2 = (n2*FLINT_BITS - 1)/bits1 + 1; mp_size_t i, j, trunc; mp_limb_t ** ii, ** jj, * t1, * t2, * s1, * ptr; mp_limb_t * tt; ii = flint_malloc((4*(n + n*size) + 5*size)*sizeof(mp_limb_t)); for (i = 0, ptr = (mp_limb_t *) ii + 4*n; i < 4*n; i++, ptr += size) { ii[i] = ptr; } t1 = ptr; t2 = t1 + size; s1 = t2 + size; tt = s1 + size; if (i1 != i2) { jj = flint_malloc(4*(n + n*size)*sizeof(mp_limb_t)); for (i = 0, ptr = (mp_limb_t *) jj + 4*n; i < 4*n; i++, ptr += size) { jj[i] = ptr; } } else jj = ii; trunc = j1 + j2 - 1; if (trunc <= 2*n) trunc = 2*n + 1; trunc = 2*sqrt*((trunc + 2*sqrt - 1)/(2*sqrt)); /* trunc must be divisible by 2*sqrt */ j1 = fft_split_bits(ii, i1, n1, bits1, limbs); for (j = j1 ; j < 4*n; j++) flint_mpn_zero(ii[j], limbs + 1); fft_mfa_truncate_sqrt2_outer(ii, n, w, &t1, &t2, &s1, sqrt, trunc); if (i1 != i2) { j2 = fft_split_bits(jj, i2, n2, bits1, limbs); for (j = j2 ; j < 4*n; j++) flint_mpn_zero(jj[j], limbs + 1); fft_mfa_truncate_sqrt2_outer(jj, n, w, &t1, &t2, &s1, sqrt, trunc); } else j2 = j1; fft_mfa_truncate_sqrt2_inner(ii, jj, n, w, &t1, &t2, &s1, sqrt, trunc, tt); ifft_mfa_truncate_sqrt2_outer(ii, n, w, &t1, &t2, &s1, sqrt, trunc); flint_mpn_zero(r1, r_limbs); fft_combine_bits(r1, ii, j1 + j2 - 1, bits1, limbs, r_limbs); flint_free(ii); if (i1 != i2) flint_free(jj); }