/*
 * Create a plan for transposing an nx x ny array whose rows are distributed
 * in blocks over the processes of transpose_comm.
 *
 * The plan works on a private duplicate of transpose_comm (via MPI_Comm_dup),
 * so transpose traffic cannot interfere with the caller's own messages.  It
 * precomputes:
 *   - the local x/y extents of this process,
 *   - a pairwise communication schedule (exchange[]) over the processes that
 *     own any data,
 *   - the send/recv block sizes and offsets used by the out-of-place path,
 *   - the post-transpose block permutation (perm_block_dest[]), and
 *   - a scratch buffer (move) for the in-place transpose routine.
 *
 * Parameters:
 *   nx, ny          global array dimensions.
 *   transpose_comm  communicator over which the array is distributed.
 *
 * Returns the new plan, or 0 if an allocation (scratch buffers, schedule,
 * exchange list, or the plan itself) fails.  On failure, every resource
 * acquired here -- including the duplicated communicator -- is released.
 * Internal inconsistencies abort via fftw_mpi_die().
 *
 * NOTE(review): the plan owns comm, move, perm_block_dest, exchange and the
 * block size/offset arrays; presumably they are released by the matching
 * plan-destroy routine (not visible here).
 */
transpose_mpi_plan transpose_mpi_create_plan(int nx, int ny,
                                             MPI_Comm transpose_comm)
{
     transpose_mpi_plan p;
     int my_pe, n_pes, pe;
     int x_block_size, local_nx, local_x_start;
     int y_block_size, local_ny, local_y_start;
     transpose_mpi_exchange *exchange = 0;
     int step, send_block_size = 0, recv_block_size = 0, num_steps = 0;
     int **sched, sched_npes, sched_sortpe, sched_sort_ascending = 0;
     int *perm_block_dest = NULL;
     int num_perm_blocks = 0, perm_block_size = 0, perm_block;
     char *move = NULL;
     int move_size = 0;
     int *send_block_sizes = 0, *send_block_offsets = 0;
     int *recv_block_sizes = 0, *recv_block_offsets = 0;
     MPI_Comm comm;

     /* create a new "clone" communicator so that transpose
        communications do not interfere with caller communications. */
     MPI_Comm_dup(transpose_comm, &comm);

     MPI_Comm_rank(comm, &my_pe);
     MPI_Comm_size(comm, &n_pes);

     /* work space for in-place transpose routine: */
     move_size = (nx + ny) / 2;
     move = (char *) fftw_malloc(sizeof(char) * move_size);
     /* a zero-byte request may legitimately yield NULL, so only a failed
        non-empty request counts as an error */
     if (!move && move_size > 0)
          goto fail;

     x_block_size = transpose_mpi_get_block_size(nx, n_pes);
     transpose_mpi_get_local_size(nx, my_pe, n_pes,
                                  &local_nx, &local_x_start);
     y_block_size = transpose_mpi_get_block_size(ny, n_pes);
     transpose_mpi_get_local_size(ny, my_pe, n_pes,
                                  &local_ny, &local_y_start);

     /* allocate pre-computed post-transpose permutation: */
     perm_block_size = gcd(nx, x_block_size);
     num_perm_blocks = (nx / perm_block_size) * local_ny;
     perm_block_dest = (int *) fftw_malloc(sizeof(int) * num_perm_blocks);
     if (!perm_block_dest && num_perm_blocks > 0)
          goto fail;
     /* num_perm_blocks is an out-of-range sentinel meaning "no destination" */
     for (perm_block = 0; perm_block < num_perm_blocks; ++perm_block)
          perm_block_dest[perm_block] = num_perm_blocks;

     /* allocate block sizes/offsets arrays for out-of-place transpose: */
     send_block_sizes = (int *) fftw_malloc(n_pes * sizeof(int));
     send_block_offsets = (int *) fftw_malloc(n_pes * sizeof(int));
     recv_block_sizes = (int *) fftw_malloc(n_pes * sizeof(int));
     recv_block_offsets = (int *) fftw_malloc(n_pes * sizeof(int));
     if (!send_block_sizes || !send_block_offsets
         || !recv_block_sizes || !recv_block_offsets)
          goto fail;
     for (step = 0; step < n_pes; ++step)
          send_block_sizes[step] = send_block_offsets[step] =
               recv_block_sizes[step] = recv_block_offsets[step] = 0;

     if (local_nx > 0 || local_ny > 0) {
          /* Only PEs 0..sched_npes-1 own data; at most one of them may own
             an uneven share and needs its schedule row sorted. */
          sched_npes = n_pes;
          sched_sortpe = -1;
          for (pe = 0; pe < n_pes; ++pe) {
               int pe_nx, pe_x_start, pe_ny, pe_y_start;

               transpose_mpi_get_local_size(nx, pe, n_pes,
                                            &pe_nx, &pe_x_start);
               transpose_mpi_get_local_size(ny, pe, n_pes,
                                            &pe_ny, &pe_y_start);

               if (pe_nx == 0 && pe_ny == 0) {
                    sched_npes = pe;
                    break;
               }
               else if (pe_nx * y_block_size != pe_ny * x_block_size
                        && pe_nx != 0 && pe_ny != 0) {
                    if (sched_sortpe != -1)
                         fftw_mpi_die("BUG: More than one PE needs sorting!\n");
                    sched_sortpe = pe;
                    sched_sort_ascending =
                         pe_nx * y_block_size > pe_ny * x_block_size;
               }
          }

          sched = make_comm_schedule(sched_npes);
          if (!sched)
               goto fail;

          if (sched_sortpe != -1) {
               sort_comm_schedule(sched, sched_npes, sched_sortpe);
               if (!sched_sort_ascending)
                    invert_comm_schedule(sched, sched_npes);
          }

          send_block_size = local_nx * y_block_size;
          recv_block_size = local_ny * x_block_size;

          num_steps = sched_npes;

          exchange = (transpose_mpi_exchange *)
               fftw_malloc(num_steps * sizeof(transpose_mpi_exchange));
          if (!exchange) {
               free_comm_schedule(sched, sched_npes);
               goto fail;
          }

          for (step = 0; step < sched_npes; ++step) {
               int dest_pe;
               int dest_nx, dest_x_start;
               int dest_ny, dest_y_start;
               int num_perm_blocks_received, num_perm_rows_received;

               exchange[step].dest_pe = dest_pe =
                    exchange[step].block_num = sched[my_pe][step];

               if (exchange[step].block_num == -1)
                    fftw_mpi_die("BUG: schedule ended too early.\n");

               transpose_mpi_get_local_size(nx, dest_pe, n_pes,
                                            &dest_nx, &dest_x_start);
               transpose_mpi_get_local_size(ny, dest_pe, n_pes,
                                            &dest_ny, &dest_y_start);

               exchange[step].send_size = local_nx * dest_ny;
               exchange[step].recv_size = dest_nx * local_ny;

               send_block_sizes[dest_pe] = exchange[step].send_size;
               send_block_offsets[dest_pe] = dest_pe * send_block_size;
               recv_block_sizes[dest_pe] = exchange[step].recv_size;
               recv_block_offsets[dest_pe] = dest_pe * recv_block_size;

               /* Precompute the post-transpose permutation (ugh): */
               if (exchange[step].recv_size > 0) {
                    num_perm_blocks_received =
                         exchange[step].recv_size / perm_block_size;
                    num_perm_rows_received =
                         num_perm_blocks_received / local_ny;

                    for (perm_block = 0;
                         perm_block < num_perm_blocks_received;
                         ++perm_block) {
                         int old_block, new_block;

                         old_block = perm_block + exchange[step].block_num *
                              (recv_block_size / perm_block_size);
                         new_block = perm_block % num_perm_rows_received +
                              dest_x_start / perm_block_size +
                              (perm_block / num_perm_rows_received)
                              * (nx / perm_block_size);

                         if (old_block >= num_perm_blocks ||
                             new_block >= num_perm_blocks)
                              fftw_mpi_die("bad block index in permutation!");

                         perm_block_dest[old_block] = new_block;
                    }
               }
          }

          free_comm_schedule(sched, sched_npes);

     } /* if (local_nx > 0 || local_ny > 0) */

     p = (transpose_mpi_plan) fftw_malloc(sizeof(transpose_mpi_plan_struct));
     if (!p)
          goto fail;

     p->comm = comm;
     p->nx = nx;
     p->ny = ny;
     p->local_nx = local_nx;
     p->local_ny = local_ny;

     p->my_pe = my_pe;
     p->n_pes = n_pes;

     p->exchange = exchange;
     p->send_block_size = send_block_size;
     p->recv_block_size = recv_block_size;
     p->num_steps = num_steps;

     p->perm_block_dest = perm_block_dest;
     p->num_perm_blocks = num_perm_blocks;
     p->perm_block_size = perm_block_size;

     p->move = move;
     p->move_size = move_size;

     p->send_block_sizes = send_block_sizes;
     p->send_block_offsets = send_block_offsets;
     p->recv_block_sizes = recv_block_sizes;
     p->recv_block_offsets = recv_block_offsets;

     /* all_blocks_equal: every PE exchanges exactly one full-size block
        with every other PE */
     p->all_blocks_equal = send_block_size * n_pes * n_pes == nx * ny &&
          recv_block_size * n_pes * n_pes == nx * ny;
     if (p->all_blocks_equal)
          for (step = 0; step < n_pes; ++step)
               if (send_block_sizes[step] != send_block_size ||
                   recv_block_sizes[step] != recv_block_size) {
                    p->all_blocks_equal = 0;
                    break;
               }
     if (nx % n_pes == 0 && ny % n_pes == 0 && !p->all_blocks_equal)
          fftw_mpi_die("n_pes divided dimensions but blocks are unequal!");

     /* Set the type constant for passing to the MPI routines;
        here, we assume that TRANSPOSE_EL_TYPE is one of the
        floating-point types. */
     if (sizeof(TRANSPOSE_EL_TYPE) == sizeof(double))
          p->el_type = MPI_DOUBLE;
     else if (sizeof(TRANSPOSE_EL_TYPE) == sizeof(float))
          p->el_type = MPI_FLOAT;
     else
          fftw_mpi_die("Unknown TRANSPOSE_EL_TYPE!\n");

     return p;

fail:
     /* Release everything acquired so far; unallocated pointers are still
        NULL/0 from their initializers.  fftw_free's NULL handling is not
        visible here, so each call is guarded. */
     if (exchange)
          fftw_free(exchange);
     if (recv_block_offsets)
          fftw_free(recv_block_offsets);
     if (recv_block_sizes)
          fftw_free(recv_block_sizes);
     if (send_block_offsets)
          fftw_free(send_block_offsets);
     if (send_block_sizes)
          fftw_free(send_block_sizes);
     if (perm_block_dest)
          fftw_free(perm_block_dest);
     if (move)
          fftw_free(move);
     MPI_Comm_free(&comm);
     return 0;
}
/* Parse a complete decimal integer command-line argument.  Returns 1 and
   stores the value in *out on success; returns 0 (leaving *out untouched)
   if s has no digits, has trailing characters, or does not fit in an int. */
static int sched_test_parse_int(const char *s, int *out)
{
     char *end;
     long v = strtol(s, &end, 10);

     if (end == s || *end != '\0' || (long) (int) v != v)
          return 0;
     *out = (int) v;
     return 1;
}

/* Validate sched with check_comm_schedule, report the result, and print
   the schedule itself. */
static void sched_test_report(int **sched, int npes)
{
     int steps = check_comm_schedule(sched, npes);

     if (steps)
          printf("schedule OK (takes %d steps to complete).\n", steps);
     else
          printf("schedule not OK.\n");
     print_comm_schedule(sched, npes);
}

/*
 * Test driver for the communication-schedule routines.
 *
 *   prog npes [sortpe]  build, check and print the schedule for npes
 *                       processes; with sortpe, also rebuild that PE's row
 *                       with fill1_comm_sched, then sort and invert the
 *                       schedule, checking after each step.
 *   prog                loop over npes = 1, 2, ... forever, checking fill,
 *                       fill1, sort and invert for every sortpe.
 *
 * Exit codes: 1 bad arguments; 2/3/4 fill/sort/invert check failure;
 * 5/6 out of memory.
 */
int main(int argc, char **argv)
{
     int **sched;
     int npes = -1, sortpe = -1, i;

     if (argc >= 2) {
          if (!sched_test_parse_int(argv[1], &npes) || npes <= 0) {
               fprintf(stderr, "npes must be positive!\n");
               return 1;
          }
     }
     if (argc >= 3) {
          if (!sched_test_parse_int(argv[2], &sortpe)
              || sortpe < 0 || sortpe >= npes) {
               fprintf(stderr, "sortpe must be between 0 and npes-1.\n");
               return 1;
          }
     }

     if (npes != -1) {
          printf("Computing schedule for npes = %d:\n", npes);
          sched = make_comm_schedule(npes);
          if (!sched) {
               fprintf(stderr, "Out of memory!\n");
               return 6;
          }

          sched_test_report(sched, npes);

          if (sortpe != -1) {
               int *sched1;

               printf("\nRe-creating schedule for pe = %d...\n", sortpe);
               sched1 = (int *) malloc(sizeof(int) * npes);
               if (!sched1) {
                    fprintf(stderr, "Out of memory!\n");
                    free_comm_schedule(sched, npes);
                    return 6;
               }
               for (i = 0; i < npes; ++i)
                    sched1[i] = -1;
               fill1_comm_sched(sched1, sortpe, npes);
               printf("  =");
               for (i = 0; i < npes; ++i)
                    printf("  %*d", npes < 10 ? 1 : (npes < 100 ? 2 : 3),
                           sched1[i]);
               printf("\n");
               free(sched1);

               printf("\nSorting schedule for sortpe = %d...\n", sortpe);
               sort_comm_schedule(sched, npes, sortpe);
               sched_test_report(sched, npes);

               printf("\nInverting schedule...\n");
               invert_comm_schedule(sched, npes);
               sched_test_report(sched, npes);
          }

          /* freed on every path, not only when sortpe was given */
          free_comm_schedule(sched, npes);
     }
     else {
          printf("Doing infinite tests...\n");
          for (npes = 1; ; ++npes) {
               int *sched1 = (int *) malloc(sizeof(int) * npes);

               printf("npes = %d...", npes);
               sched = make_comm_schedule(npes);
               if (!sched || !sched1) {
                    fprintf(stderr, "Out of memory!\n");
                    if (sched)
                         free_comm_schedule(sched, npes);
                    free(sched1);
                    return 5;
               }
               for (sortpe = 0; sortpe < npes; ++sortpe) {
                    empty_comm_schedule(sched, npes);
                    fill_comm_schedule(sched, npes);
                    if (!check_comm_schedule(sched, npes)) {
                         fprintf(stderr,
                                 "\n -- fill error for sortpe = %d!\n", sortpe);
                         return 2;
                    }

                    for (i = 0; i < npes; ++i)
                         sched1[i] = -1;
                    fill1_comm_sched(sched1, sortpe, npes);
                    for (i = 0; i < npes; ++i)
                         if (sched1[i] != sched[sortpe][i])
                              fprintf(stderr,
                                      "\n -- fill1 error for pe = %d!\n",
                                      sortpe);

                    sort_comm_schedule(sched, npes, sortpe);
                    if (!check_comm_schedule(sched, npes)) {
                         fprintf(stderr,
                                 "\n -- sort error for sortpe = %d!\n", sortpe);
                         return 3;
                    }
                    invert_comm_schedule(sched, npes);
                    if (!check_comm_schedule(sched, npes)) {
                         fprintf(stderr,
                                 "\n -- invert error for sortpe = %d!\n",
                                 sortpe);
                         return 4;
                    }
               }
               free_comm_schedule(sched, npes);
               free(sched1);
               printf("OK\n");
               if (npes % 50 == 0)
                    printf("(...Hit Ctrl-C to stop...)\n");
          }
     }

     return 0;
}