int main(int argc, char* argv []) { /* check argument count for a valid range */ if ( argc != 14 ) { print_help(); return -1; } char* l_arch = NULL; char* l_precision = NULL; int l_m = 0; int l_n = 0; int l_k = 0; int l_lda = 0; int l_ldb = 0; int l_ldc = 0; int l_aligned_a = 0; int l_aligned_c = 0; int l_alpha = 0; int l_beta = 0; int l_single_precision = 0; int l_prefetch = 0; /* xgemm sizes */ l_m = atoi(argv[1]); l_n = atoi(argv[2]); l_k = atoi(argv[3]); l_lda = atoi(argv[4]); l_ldb = atoi(argv[5]); l_ldc = atoi(argv[6]); /* some sugar */ l_alpha = atoi(argv[7]); l_beta = atoi(argv[8]); l_aligned_a = atoi(argv[9]); l_aligned_c = atoi(argv[10]); /* arch specific stuff */ l_arch = argv[11]; l_precision = argv[13]; /* set value of prefetch flag */ if (strcmp("nopf", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_NONE; } else if (strcmp("pfsigonly", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_SIGNATURE; } else if (strcmp("BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_BL2_VIA_C; } else if (strcmp("curAL2", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2_AHEAD; } else if (strcmp("curAL2_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD; } else if (strcmp("AL2", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2; } else if (strcmp("AL2_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C; } else if (strcmp("AL2jpst", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2_JPST; } else if (strcmp("AL2jpst_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST; } else { print_help(); return -1; } /* check value of arch flag */ if ( (strcmp(l_arch, "snb") != 0) && (strcmp(l_arch, "hsw") != 0) && (strcmp(l_arch, "knl") != 0) && (strcmp(l_arch, "skx") != 0) ) { print_help(); return -1; } /* check and evaluate precison flag */ if ( strcmp(l_precision, "SP") == 0 ) { l_single_precision = 1; } else if ( strcmp(l_precision, "DP") == 0 ) { l_single_precision = 0; } else { print_help(); return -1; } /* check alpha */ if ((l_alpha != -1) && (l_alpha != 1)) { print_help(); return -1; } /* check beta */ if ((l_beta != 0) && (l_beta != 1)) { print_help(); return -1; } libxsmm_xgemm_descriptor l_xgemm_desc; if ( l_m < 0 ) { l_xgemm_desc.m = 0; } else { l_xgemm_desc.m = l_m; } if ( l_n < 0 ) { l_xgemm_desc.n = 0; } else { l_xgemm_desc.n = l_n; } if ( l_k < 0 ) { l_xgemm_desc.k = 0; } else { l_xgemm_desc.k = l_k; } if ( l_lda < 0 ) { l_xgemm_desc.lda = 0; } else { l_xgemm_desc.lda = l_lda; } if ( l_ldb < 0 ) { l_xgemm_desc.ldb = 0; } else { l_xgemm_desc.ldb = l_ldb; } if ( l_ldc < 0 ) { l_xgemm_desc.ldc = 0; } else { l_xgemm_desc.ldc = l_ldc; } l_xgemm_desc.alpha = l_alpha; l_xgemm_desc.beta = l_beta; l_xgemm_desc.trans_a = 'n'; l_xgemm_desc.trans_b = 'n'; if (l_aligned_a == 0) { l_xgemm_desc.aligned_a = 0; } else { l_xgemm_desc.aligned_a = 1; } if (l_aligned_c == 0) { l_xgemm_desc.aligned_c = 0; } else { l_xgemm_desc.aligned_c = 1; } l_xgemm_desc.single_precision = l_single_precision; l_xgemm_desc.prefetch = l_prefetch; /* init data structures */ double* l_a_d; double* l_b_d; double* l_c_d; double* l_c_gold_d; float* l_a_f; float* l_b_f; float* l_c_f; float* l_c_gold_f; if ( l_xgemm_desc.single_precision == 0 ) { l_a_d = (double*)_mm_malloc(l_xgemm_desc.lda * l_xgemm_desc.k * sizeof(double), 64); l_b_d = (double*)_mm_malloc(l_xgemm_desc.ldb * l_xgemm_desc.n * sizeof(double), 64); l_c_d = (double*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(double), 64); l_c_gold_d = (double*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(double), 64); init_double(l_a_d, l_b_d, l_c_d, l_c_gold_d, &l_xgemm_desc); } else { l_a_f = (float*)_mm_malloc(l_xgemm_desc.lda * l_xgemm_desc.k * sizeof(float), 64); l_b_f = (float*)_mm_malloc(l_xgemm_desc.ldb * l_xgemm_desc.n * sizeof(float), 64); l_c_f = (float*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(float), 64); l_c_gold_f = (float*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(float), 64); init_float(l_a_f, l_b_f, l_c_f, l_c_gold_f, &l_xgemm_desc); } /* print some output... */ printf("------------------------------------------------\n"); printf("RUNNING (%ix%i) X (%ix%i) = (%ix%i)", l_xgemm_desc.m, l_xgemm_desc.k, l_xgemm_desc.k, l_xgemm_desc.n, l_xgemm_desc.m, l_xgemm_desc.n); if ( l_xgemm_desc.single_precision == 0 ) { printf(", DP\n"); } else { printf(", SP\n"); } printf("------------------------------------------------\n"); /* run C */ if ( l_xgemm_desc.single_precision == 0 ) { run_gold_double( l_a_d, l_b_d, l_c_gold_d, &l_xgemm_desc ); } else { run_gold_float( l_a_f, l_b_f, l_c_gold_f, &l_xgemm_desc ); } /* run jit */ if ( l_xgemm_desc.single_precision == 0 ) { run_jit_double( l_a_d, l_b_d, l_c_d, &l_xgemm_desc, l_arch ); } else { run_jit_float( l_a_f, l_b_f, l_c_f, &l_xgemm_desc, l_arch ); } /* test result */ if ( l_xgemm_desc.single_precision == 0 ) { max_error_double( l_c_d, l_c_gold_d, &l_xgemm_desc ); } else { max_error_float( l_c_f, l_c_gold_f, &l_xgemm_desc ); } /* free */ if ( l_xgemm_desc.single_precision == 0 ) { _mm_free(l_a_d); _mm_free(l_b_d); _mm_free(l_c_d); _mm_free(l_c_gold_d); } else { _mm_free(l_a_f); _mm_free(l_b_f); _mm_free(l_c_f); _mm_free(l_c_gold_f); } printf("------------------------------------------------\n"); return 0; }
int main(int argc, char* argv []) { char* l_arch = NULL; char* l_precision = NULL; int l_m = 0; int l_n = 0; int l_k = 0; int l_lda = 0; int l_ldb = 0; int l_ldc = 0; int l_aligned_a = 0; int l_aligned_c = 0; int l_alpha = 0; int l_beta = 0; int l_single_precision = 0; libxsmm_prefetch_type l_prefetch = 0; libxsmm_gemm_descriptor l_xgemm_desc; /* init data structures */ double* l_a_d; double* l_b_d; double* l_c_d; double* l_c_gold_d; float* l_a_f; float* l_b_f; float* l_c_f; float* l_c_gold_f; /* check argument count for a valid range */ if ( argc != 15 ) { print_help(); return -1; } /* xgemm sizes */ l_m = atoi(argv[1]); l_n = atoi(argv[2]); l_k = atoi(argv[3]); l_lda = atoi(argv[4]); l_ldb = atoi(argv[5]); l_ldc = atoi(argv[6]); /* some sugar */ l_alpha = atoi(argv[7]); l_beta = atoi(argv[8]); l_aligned_a = atoi(argv[9]); l_aligned_c = atoi(argv[10]); /* arch specific stuff */ l_arch = argv[11]; l_precision = argv[13]; g_jit_code_reps = atoi(argv[14]); /* set value of prefetch flag */ if (strcmp("nopf", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_NONE; } else if (strcmp("pfsigonly", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_SIGONLY; } else if (strcmp("BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_BL2_VIA_C; } else if (strcmp("curAL2", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2_AHEAD; } else if (strcmp("curAL2_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD; } else if (strcmp("AL2", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2; } else if (strcmp("AL2_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C; } else if (strcmp("AL2jpst", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2_JPST; } else if (strcmp("AL2jpst_BL2viaC", argv[12]) == 0) { l_prefetch = LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST; } else { print_help(); return -1; } /* check value of arch flag */ if ( (strcmp(l_arch, "snb") != 0) && (strcmp(l_arch, "hsw") != 0) && (strcmp(l_arch, "knl") != 0) && (strcmp(l_arch, "skx") != 0) ) { print_help(); return -1; } /* check and evaluate precison flag */ if ( strcmp(l_precision, "SP") == 0 ) { l_single_precision = 1; } else if ( strcmp(l_precision, "DP") == 0 ) { l_single_precision = 0; } else { print_help(); return -1; } /* check alpha */ if ((l_alpha != 1)) { print_help(); return -1; } /* check beta */ if ((l_beta != 0) && (l_beta != 1)) { print_help(); return -1; } LIBXSMM_GEMM_DESCRIPTOR(l_xgemm_desc, 1, (0 == l_single_precision ? 0 : LIBXSMM_GEMM_FLAG_F32PREC) | (0 != l_aligned_a ? LIBXSMM_GEMM_FLAG_ALIGN_A : 0) | (0 != l_aligned_c ? LIBXSMM_GEMM_FLAG_ALIGN_C : 0), l_m, l_n, l_k, l_lda, l_ldb, l_ldc, l_alpha, l_beta, l_prefetch); if ( l_single_precision == 0 ) { l_a_d = (double*)_mm_malloc(l_xgemm_desc.lda * l_xgemm_desc.k * sizeof(double), 64); l_b_d = (double*)_mm_malloc(l_xgemm_desc.ldb * l_xgemm_desc.n * sizeof(double), 64); l_c_d = (double*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(double), 64); l_c_gold_d = (double*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(double), 64); init_double(l_a_d, l_b_d, l_c_d, l_c_gold_d, &l_xgemm_desc); } else { l_a_f = (float*)_mm_malloc(l_xgemm_desc.lda * l_xgemm_desc.k * sizeof(float), 64); l_b_f = (float*)_mm_malloc(l_xgemm_desc.ldb * l_xgemm_desc.n * sizeof(float), 64); l_c_f = (float*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(float), 64); l_c_gold_f = (float*)_mm_malloc(l_xgemm_desc.ldc * l_xgemm_desc.n * sizeof(float), 64); init_float(l_a_f, l_b_f, l_c_f, l_c_gold_f, &l_xgemm_desc); } /* print some output... */ printf("------------------------------------------------\n"); printf("RUNNING (%ix%i) X (%ix%i) = (%ix%i)", l_xgemm_desc.m, l_xgemm_desc.k, l_xgemm_desc.k, l_xgemm_desc.n, l_xgemm_desc.m, l_xgemm_desc.n); if ( l_single_precision == 0 ) { printf(", DP\n"); } else { printf(", SP\n"); } printf("------------------------------------------------\n"); /* run C */ if ( l_single_precision == 0 ) { run_gold_double( l_a_d, l_b_d, l_c_gold_d, &l_xgemm_desc ); } else { run_gold_float( l_a_f, l_b_f, l_c_gold_f, &l_xgemm_desc ); } /* run jit */ if ( l_single_precision == 0 ) { run_jit_double( l_a_d, l_b_d, l_c_d, l_m, l_n, l_k, l_prefetch, l_arch ); } else { run_jit_float( l_a_f, l_b_f, l_c_f, l_m, l_n, l_k, l_prefetch, l_arch ); } /* test result */ if ( l_single_precision == 0 ) { max_error_double( l_c_d, l_c_gold_d, &l_xgemm_desc ); } else { max_error_float( l_c_f, l_c_gold_f, &l_xgemm_desc ); } /* free */ if ( l_single_precision == 0 ) { _mm_free(l_a_d); _mm_free(l_b_d); _mm_free(l_c_d); _mm_free(l_c_gold_d); } else { _mm_free(l_a_f); _mm_free(l_b_f); _mm_free(l_c_f); _mm_free(l_c_gold_f); } printf("------------------------------------------------\n"); return 0; }