/*
 * Convolve one row of length n1 for fft_mfa_truncate_sqrt2_inner.
 *
 * Applies fft_radix2 to the n1 entries starting at ii + row*n1. If jj does
 * not alias ii, it also transforms the matching jj row and normalises its
 * entries. The ii entries are normalised and multiplied pointwise in place
 * (ii[k] = ii[k] * jj[k] via fft_mulmod_2expp1). Finally ifft_radix2 is
 * applied to the ii row.
 *
 * ii, jj  - row-major coefficient pointer arrays (already offset by caller)
 * row     - index of the row to process
 * n, w    - passed through to fft_mulmod_2expp1; w*n2 is passed through to
 *           fft_radix2 / ifft_radix2
 * n1      - row length; fft_radix2/ifft_radix2 are called with n1/2
 * n2      - number of rows (only used in the w*n2 argument)
 * limbs   - limb count handed to mpn_normmod_2expp1
 * t1, t2  - scratch passed to fft_radix2 / ifft_radix2
 * tt      - scratch passed to fft_mulmod_2expp1
 */
static void
fft_mfa_inner_row(mp_limb_t ** ii, mp_limb_t ** jj, mp_size_t row,
                  mp_size_t n, mp_bitcnt_t w, mp_size_t n1, mp_size_t n2,
                  mp_size_t limbs, mp_limb_t ** t1, mp_limb_t ** t2,
                  mp_limb_t * tt)
{
   mp_size_t j;
   mp_limb_t ** a = ii + row*n1;
   mp_limb_t ** b = jj + row*n1;

   fft_radix2(a, n1/2, w*n2, t1, t2);
   if (ii != jj) fft_radix2(b, n1/2, w*n2, t1, t2);

   for (j = 0; j < n1; j++)
   {
      mpn_normmod_2expp1(a[j], limbs);
      if (ii != jj) mpn_normmod_2expp1(b[j], limbs);
      fft_mulmod_2expp1(a[j], a[j], b[j], n, w, tt);
   }

   ifft_radix2(a, n1/2, w*n2, t1, t2);
}

/*
 * Pointwise stage of the truncated "mfa" (presumably matrix Fourier
 * algorithm) convolution.
 *
 * The 2n coefficient pointers are treated as n2 = 2n/n1 rows of length n1.
 * Each row of ii is convolved with the corresponding row of jj; the result
 * is written back into ii.
 *
 * - First, (trunc - 2n)/n1 rows of the block starting at ii + 2n are
 *   processed. Row indices come from n_revbin(s, depth), presumably
 *   bit-reversed order.
 * - Then all n2 rows of the first 2n entries are processed.
 * - When ii == jj (presumably squaring), jj is not transformed separately.
 *
 * ii, jj  - coefficient pointer arrays; results are written into ii
 * n, w    - transform parameters; limbs = n*w / FLINT_BITS
 * t1, t2  - scratch passed to fft_radix2 / ifft_radix2
 * temp    - unused; kept for interface compatibility
 * n1      - row length
 * trunc   - truncation length; entries beyond 2n contribute (trunc-2n)/n1 rows
 * tt      - scratch passed to fft_mulmod_2expp1
 */
void fft_mfa_truncate_sqrt2_inner(mp_limb_t ** ii, mp_limb_t ** jj, mp_size_t n, 
                   mp_bitcnt_t w, mp_limb_t ** t1, mp_limb_t ** t2, 
                  mp_limb_t ** temp, mp_size_t n1, mp_size_t trunc, mp_limb_t * tt)
{
   mp_size_t i, s;
   mp_size_t n2 = (2*n)/n1;
   mp_size_t trunc2 = (trunc - 2*n)/n1;
   mp_size_t limbs = (n*w)/FLINT_BITS;
   mp_bitcnt_t depth = 0;

   (void) temp; /* not needed by this stage */

   /* smallest depth with 2^depth >= n2, used as the n_revbin width */
   while ((UWORD(1)<<depth) < n2) depth++;

   /* convolutions on the relevant rows of the block starting at entry 2n */
   for (s = 0; s < trunc2; s++)
   {
      i = n_revbin(s, depth);
      fft_mfa_inner_row(ii + 2*n, jj + 2*n, i, n, w, n1, n2, limbs, t1, t2, tt);
   }

   /* convolutions on every row of the first 2n entries */
   for (i = 0; i < n2; i++)
   {
      fft_mfa_inner_row(ii, jj, i, n, w, n1, n2, limbs, t1, t2, tt);
   }
}
/* Example #2 (0)
   File: convolution.c -- Project: goens/flint2 */
/*
 * Convolution driver: forward transform, pointwise product, inverse transform.
 *
 * The result is left in ii. When jj aliases ii (presumably squaring), jj is
 * not transformed or normalised separately. The *_2expp1 helpers presumably
 * work modulo 2^(limbs*FLINT_BITS) + 1 -- confirm against their definitions.
 *
 * ii, jj   - coefficient pointer arrays
 * depth    - log2 of the transform parameter n = 2^depth
 * limbs    - limbs per coefficient; w = limbs*FLINT_BITS / n
 * trunc    - truncation length. It is rounded up to a multiple of 2
 *            (depth <= 6) or of 2 * 2^(depth/2) (depth > 6).
 * t1, t2   - scratch passed to the transforms
 * s1       - extra scratch passed to the transforms
 * tt       - scratch passed to the pointwise multiplication
 */
void fft_convolution(mp_limb_t ** ii, mp_limb_t ** jj, long depth, 
                              long limbs, long trunc, mp_limb_t ** t1, 
                          mp_limb_t ** t2, mp_limb_t ** s1, mp_limb_t * tt)
{
   const long n = 1L << depth;
   const long w = (limbs * FLINT_BITS) / n;
   const int squaring = (ii == jj);
   long k;

   if (depth > 6)
   {
      /* large transforms: "mfa" variant with rows of length 2^(depth/2) */
      const long row_len = 1L << (depth / 2);
      const long block = 2 * row_len;

      trunc = block * ((trunc + block - 1) / block);

      fft_mfa_truncate_sqrt2_outer(ii, n, w, t1, t2, s1, row_len, trunc);
      if (!squaring)
         fft_mfa_truncate_sqrt2_outer(jj, n, w, t1, t2, s1, row_len, trunc);

      fft_mfa_truncate_sqrt2_inner(ii, jj, n, w, t1, t2, s1, row_len, trunc, tt);

      ifft_mfa_truncate_sqrt2_outer(ii, n, w, t1, t2, s1, row_len, trunc);
      return;
   }

   /* small transforms: plain truncated sqrt2 FFT, trunc rounded up to even */
   trunc = 2 * ((trunc + 1) / 2);

   fft_truncate_sqrt2(ii, n, w, t1, t2, s1, trunc);
   if (!squaring)
      fft_truncate_sqrt2(jj, n, w, t1, t2, s1, trunc);

   for (k = 0; k < trunc; k++)
   {
      mpn_normmod_2expp1(ii[k], limbs);
      if (!squaring)
         mpn_normmod_2expp1(jj[k], limbs);

      fft_mulmod_2expp1(ii[k], ii[k], jj[k], n, w, tt);
   }

   ifft_truncate_sqrt2(ii, n, w, t1, t2, s1, trunc);

   /* divide each output by 2^(depth + 2) (presumably removing the transform
      scaling), then renormalise */
   for (k = 0; k < trunc; k++)
   {
      mpn_div_2expmod_2expp1(ii[k], ii[k], limbs, depth + 2);
      mpn_normmod_2expp1(ii[k], limbs);
   }
}
/* Example #3 (0) */
/*
 * Report a mismatching limb and abort the test.
 * Values are widened to long long / unsigned long long so the printf
 * conversion specifiers match even where mp_limb_t is not unsigned long.
 */
static void
mulmod_2expp1_fail(mp_size_t j, mp_limb_t expected, mp_limb_t got)
{
    printf("error in limb %lld, %llx != %llx\n",
           (long long) j, (unsigned long long) expected,
           (unsigned long long) got);
    abort();
}

/*
 * Compare the reference product r1 (limbs 0..int_limbs-1 plus carry c) with
 * the FFT product r2 (limbs 0..int_limbs). Abort on the first mismatch.
 */
static void
mulmod_2expp1_check(const mp_limb_t * r1, const mp_limb_t * r2,
                    mp_limb_t c, mp_size_t int_limbs)
{
    mp_size_t j;

    for (j = 0; j < int_limbs; j++)
    {
        if (r1[j] != r2[j])
            mulmod_2expp1_fail(j, r1[j], r2[j]);
    }

    if (c != r2[int_limbs])
        mulmod_2expp1_fail(int_limbs, c, r2[int_limbs]);
}

/*
 * One trial: compare fft_mulmod_2expp1 against mpn_mulmod_2expp1 on random
 * normalised inputs with n = 2^depth and bits = n*w.
 *
 * When square is nonzero, both operands are the same buffer. In that case
 * one fewer buffer is allocated and only one random input is drawn.
 */
static void
mulmod_2expp1_trial(flint_rand_t state, mp_bitcnt_t depth, mp_bitcnt_t w,
                    int square)
{
    mp_size_t n = (1UL<<depth);
    mp_bitcnt_t bits = n*w;
    mp_size_t int_limbs = bits/FLINT_BITS;
    mp_size_t nbufs = square ? 5 : 6;
    mp_limb_t c, * i1, * i2, * r1, * r2, * tt;

    /* layout: i1 [i2] r1 r2 tt, tt taking the last two (int_limbs+1) slots */
    i1 = flint_malloc(nbufs*(int_limbs+1)*sizeof(mp_limb_t));
    i2 = square ? i1 : i1 + int_limbs + 1;
    r1 = i2 + int_limbs + 1;
    r2 = r1 + int_limbs + 1;
    tt = r2 + int_limbs + 1;

    random_fermat(i1, state, int_limbs);
    if (!square)
        random_fermat(i2, state, int_limbs);
    mpn_normmod_2expp1(i1, int_limbs);
    if (!square)
        mpn_normmod_2expp1(i2, int_limbs);

    fft_mulmod_2expp1(r2, i1, i2, n, w, tt);
    c = i1[int_limbs] + 2*i2[int_limbs];
    c = mpn_mulmod_2expp1(r1, i1, i2, c, int_limbs*FLINT_BITS, tt);

    mulmod_2expp1_check(r1, r2, c, int_limbs);

    flint_free(i1);
}

/* Test driver: general products first, then squaring, 100 rounds each. */
int
main(void)
{
    mp_bitcnt_t depth, w;
    int iters, square;

    flint_rand_t state;

    printf("mulmod_2expp1....");
    fflush(stdout);

    flint_randinit(state);
    _flint_rand_init_gmp(state);

    /* square == 0: distinct operands; square == 1: test squaring */
    for (square = 0; square <= 1; square++)
    {
        for (iters = 0; iters < 100; iters++)
        {
            for (depth = 6; depth <= 18; depth++)
            {
                for (w = 1; w <= 2; w++)
                {
                    mulmod_2expp1_trial(state, depth, w, square);
                }
            }
        }
    }

    flint_randclear(state);

    printf("PASS\n");
    return 0;
}