/*
 * Set up the distributed 2-D complex FFT: pick padded sizes, query the
 * per-rank slab decomposition, allocate the work buffers and create the
 * forward and backward FFTW MPI plans.
 *
 * Outputs:
 *   nx2, ny2  - padded sizes chosen by kiss_fft_next_fast_size()
 *   n_local   - local_n0 reported by fftwf_mpi_local_size_2d(), cast to int
 *   o_local   - local_0_start reported by fftwf_mpi_local_size_2d(), cast to int
 *
 * Returns nk*n2 (the product of the two padded sizes).
 *
 * NOTE(review): threads_ok, nk, n1, n2, alloc_local, local_n0, local_0_start,
 * cc, dd, cfg, icfg and wt are not declared in this function; they are
 * presumably file-scope state shared with the other cfft2 routines -- confirm.
 */
int cfft2_init(int pad1           /* padding on the first axis */,
               int nx,   int ny   /* input data size */,
               int *nx2, int *ny2 /* padded data size */,
               int *n_local, int *o_local /* local size & start */,
               MPI_Comm comm)
/*< initialize >*/
{
    /* Thread support is initialized before fftwf_mpi_init(), and the thread
       count is set before any plan is created. */
    if (threads_ok) threads_ok = fftwf_init_threads();

    fftwf_mpi_init();

    /* Diagnostic intentionally disabled. */
    if (false)
        sf_warning("Using threaded FFTW3! \n");

    if (threads_ok) fftwf_plan_with_nthreads(omp_get_max_threads());

    /* Padded transform sizes; only the first axis honours pad1. */
    nk = n1 = kiss_fft_next_fast_size(nx*pad1);
    n2 = kiss_fft_next_fast_size(ny);

    /* Number of complex elements this rank must allocate, plus the extent
       and offset of its slab along the first (n2) dimension. */
    alloc_local = fftwf_mpi_local_size_2d(n2, n1, comm, &local_n0, &local_0_start);

    cc = sf_complexalloc(alloc_local);
    dd = sf_complexalloc(alloc_local);

    /* Out-of-place plans: cc -> dd forward, dd -> cc backward. */
    cfg = fftwf_mpi_plan_dft_2d(n2, n1,
                                (fftwf_complex *) cc,
                                (fftwf_complex *) dd,
                                comm, FFTW_FORWARD, FFTW_MEASURE);

    icfg = fftwf_mpi_plan_dft_2d(n2, n1,
                                 (fftwf_complex *) dd,
                                 (fftwf_complex *) cc,
                                 comm, FFTW_BACKWARD, FFTW_MEASURE);

    if (NULL == cfg || NULL == icfg) sf_error("FFTW failure.");

    *nx2 = n1;
    *ny2 = n2;
    *n_local = (int) local_n0;
    *o_local = (int) local_0_start;

    /* Form the product in double: n1*n2 evaluated in int overflows
       (undefined behaviour) for large padded grids. */
    wt = 1.0 / ((double) n1 * (double) n2);

    /* NOTE(review): this product is still computed in int because the return
       type is int; very large grids can overflow here. */
    return (nk*n2);
}
int main(int argc, char **argv) { fftwf_plan plan; fftwf_complex *data; ptrdiff_t alloc_local, local_n0, local_0_start, i, j; if (argc != 2) { printf("usage: ./fft_mpi MATRIX_SIZE\n"); exit(1); } const ptrdiff_t N0 = atoi(argv[1]); const ptrdiff_t N1 = N0; int id; double startTime, totalTime; totalTime = 0; MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &id); fftwf_mpi_init(); /* get local data size and allocate */ alloc_local = fftwf_mpi_local_size_2d(N0, N1, MPI_COMM_WORLD, &local_n0, &local_0_start); data = fftwf_alloc_complex(alloc_local);//(fftwf_complex *) fftwf_malloc(sizeof(fftw_complex) * alloc_local); /* create plan for in-place forward DFT */ plan = fftwf_mpi_plan_dft_2d(N0, N1, data, data, MPI_COMM_WORLD, FFTW_FORWARD, FFTW_ESTIMATE); /* initialize data to some function my_function(x,y) */ for (i = 0; i < local_n0; ++i) for (j = 0; j < N1; ++j){ data[i*N1 + j][0] = local_0_start;;//my_function(local_0_start + i, j); data[i*N1 + j][1]=i; } /* compute transforms, in-place, as many times as desired */ MPI_Barrier(MPI_COMM_WORLD); if (id == 0) { startTime = getTime(); } fftwf_execute(plan); MPI_Barrier(MPI_COMM_WORLD); if (id == 0) { totalTime += getTime() - startTime; } fftwf_destroy_plan(plan); fftwf_mpi_cleanup(); if (id == 0) { printf("%.5f\n", totalTime); } MPI_Finalize(); return 0; }