/**
 * Generates an x86 kernel for C += A * B where A and C are dense in SOA
 * (structure-of-arrays) layout and B is sparse in CSC format.
 *
 * @param io_generated_code buffer receiving the emitted instructions.
 * @param i_xgemm_desc      GEMM shape/flags descriptor (m, n, k, lda/ldb/ldc, datatype, prefetch).
 * @param i_arch            target architecture string ("knl", "knm", "skx", "clx", "cpx", "hsw", ...).
 * @param i_row_idx         CSC row indices of B's non-zeros.
 * @param i_column_idx      CSC column pointers of B (length n+1).
 * @param i_values          unused here: B's values are read at runtime via gp_reg_b.
 */
LIBXSMM_API_INTERN void libxsmm_generator_spgemm_csc_bsparse_soa_avx256_512( libxsmm_generated_code*        io_generated_code,
                                                                             const libxsmm_gemm_descriptor* i_xgemm_desc,
                                                                             const char*                    i_arch,
                                                                             const unsigned int*            i_row_idx,
                                                                             const unsigned int*            i_column_idx,
                                                                             const void*                    i_values ) {
  unsigned int l_n = 0;
  unsigned int l_k = 0;
  unsigned int l_soa_width = 0;       /* SIMD lanes per SOA element (vector width in elements) */
  unsigned int l_max_cols = 0;        /* highest B/C column that actually has non-zeros */
  unsigned int l_n_processed = 0;
  unsigned int l_n_limit = 0;
  unsigned int l_n_chunks = 0;
  unsigned int l_n_chunksize = 0;
  unsigned int l_found_mul = 0;
  unsigned int l_max_reg_block = 0;   /* number of vector registers usable as C accumulators */

  libxsmm_micro_kernel_config l_micro_kernel_config;
  libxsmm_loop_label_tracker l_loop_label_tracker;
  libxsmm_gp_reg_mapping l_gp_reg_mapping;

  LIBXSMM_UNUSED(i_values);

  /* select soa width: AVX-512 targets get 512-bit vectors and 28 accumulators,
   * everything else 256-bit vectors and 14 accumulators (regs 14/15 stay scratch) */
  if ( LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_INP( i_xgemm_desc->datatype ) ) {
    if ( strcmp(i_arch, "knl") == 0 ||
         strcmp(i_arch, "knm") == 0 ||
         strcmp(i_arch, "skx") == 0 ||
         strcmp(i_arch, "clx") == 0 ||
         strcmp(i_arch, "cpx") == 0 ) {
      l_soa_width = 8;
      l_max_reg_block = 28;
    } else {
      l_soa_width = 4;
      l_max_reg_block = 14;
    }
  } else {
    if ( strcmp(i_arch, "knl") == 0 ||
         strcmp(i_arch, "knm") == 0 ||
         strcmp(i_arch, "skx") == 0 ||
         strcmp(i_arch, "clx") == 0 ||
         strcmp(i_arch, "cpx") == 0 ) {
      l_soa_width = 16;
      l_max_reg_block = 28;
    } else {
      l_soa_width = 8;
      l_max_reg_block = 14;
    }
  }

  /* define gp register mapping */
  libxsmm_reset_x86_gp_reg_mapping( &l_gp_reg_mapping );
  /* matching calling convention on Linux */
#if defined(_WIN32) || defined(__CYGWIN__)
  l_gp_reg_mapping.gp_reg_a = LIBXSMM_X86_GP_REG_RCX;
  l_gp_reg_mapping.gp_reg_b = LIBXSMM_X86_GP_REG_RDX;
  l_gp_reg_mapping.gp_reg_c = LIBXSMM_X86_GP_REG_R8;
  /* TODO: full support for Windows calling convention */
  l_gp_reg_mapping.gp_reg_a_prefetch = LIBXSMM_X86_GP_REG_RDI;
  l_gp_reg_mapping.gp_reg_b_prefetch = LIBXSMM_X86_GP_REG_RSI;
#else /* match calling convention on Linux */
  l_gp_reg_mapping.gp_reg_a = LIBXSMM_X86_GP_REG_RDI;
  l_gp_reg_mapping.gp_reg_b = LIBXSMM_X86_GP_REG_RSI;
  l_gp_reg_mapping.gp_reg_c = LIBXSMM_X86_GP_REG_RDX;
  l_gp_reg_mapping.gp_reg_a_prefetch = LIBXSMM_X86_GP_REG_RCX;
  l_gp_reg_mapping.gp_reg_b_prefetch = LIBXSMM_X86_GP_REG_R8;
#endif
  l_gp_reg_mapping.gp_reg_c_prefetch = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_mloop = LIBXSMM_X86_GP_REG_R12;
  l_gp_reg_mapping.gp_reg_nloop = LIBXSMM_X86_GP_REG_R13;
  l_gp_reg_mapping.gp_reg_kloop = LIBXSMM_X86_GP_REG_R14;
  l_gp_reg_mapping.gp_reg_help_0 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_1 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_2 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_3 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_4 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_5 = LIBXSMM_X86_GP_REG_UNDEF;

  /* define loop_label_tracker */
  libxsmm_reset_loop_label_tracker( &l_loop_label_tracker );

  /* define the micro kernel code gen properties */
  libxsmm_generator_gemm_init_micro_kernel_config_fullvector( &l_micro_kernel_config, i_xgemm_desc, i_arch, 0 );

  /* get max column in C: trailing all-zero columns (whose column pointer equals
   * the total nnz count) need no code at all */
  l_max_cols = i_xgemm_desc->n;
  for ( l_n = 0; l_n < i_xgemm_desc->n; l_n++ ) {
    if ( i_column_idx[l_n] == i_column_idx[i_xgemm_desc->n] ) {
      l_max_cols = l_n+1;
    }
  }

  /* calculate the chunk size of current columns to work on; columns are split
   * so that each chunk fits into the accumulator register block */
  l_n_chunks = ( (l_max_cols % l_max_reg_block) == 0 ) ? (l_max_cols / l_max_reg_block) : (l_max_cols / l_max_reg_block) + 1;
  assert(0 != l_n_chunks); /* mute static analysis (division-by-zero); such invalid input must be caught upfront */
  l_n_chunksize = ( (l_max_cols % l_n_chunks) == 0 ) ? (l_max_cols / l_n_chunks) : (l_max_cols / l_n_chunks) + 1;

  /* open asm */
  libxsmm_x86_instruction_open_stream( io_generated_code, &l_gp_reg_mapping, i_arch, i_xgemm_desc->prefetch );

  /* m loop */
  libxsmm_x86_instruction_register_jump_back_label( io_generated_code, &l_loop_label_tracker );
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_mloop, 1 );

  /* loop over n-blocks */
  l_n_processed = 0;
  l_n_limit = l_n_chunksize;
  while ( l_n_processed < l_max_cols ) {
#if 0
    printf("l_max_cols: %i, l_n_processed: %i, l_n_limit: %i\n", l_max_cols, l_n_processed, l_n_limit);
#endif
    /* load C accumulator: one vector register per column of the current chunk */
    for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
      if (0 != (LIBXSMM_GEMM_FLAG_BETA_0 & i_xgemm_desc->flags)) { /* Beta=0 */
        libxsmm_x86_instruction_vec_compute_reg( io_generated_code,
                                                 l_micro_kernel_config.instruction_set,
                                                 l_micro_kernel_config.vxor_instruction,
                                                 l_micro_kernel_config.vector_name,
                                                 l_n, l_n, l_n );
      } else {
        libxsmm_x86_instruction_vec_move( io_generated_code,
                                          l_micro_kernel_config.instruction_set,
                                          l_micro_kernel_config.c_vmove_instruction,
                                          l_gp_reg_mapping.gp_reg_c,
                                          LIBXSMM_X86_GP_REG_UNDEF, 0,
                                          (l_n_processed + l_n)*l_soa_width*l_micro_kernel_config.datatype_size,
                                          l_micro_kernel_config.vector_name,
                                          l_n, 0, 1, 0 );
      }
    }

    /* do dense soa times sparse multiplication */
    for ( l_k = 0; l_k < (unsigned int)i_xgemm_desc->k; l_k++ ) {
      unsigned int l_found_qmadd = 0;
      unsigned int l_col_k = 0;
      unsigned int l_column_active[28];  /* per column: 0 = no nnz at this k, 1 = FMA, 2 = QFMA */
      int l_nnz_idx[28][4];              /* nnz positions in B for up to 4 consecutive k values */

      /* reset helpers */
      for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
        l_column_active[l_n] = 0;
        l_nnz_idx[l_n][0] = -1; l_nnz_idx[l_n][1] = -1; l_nnz_idx[l_n][2] = -1; l_nnz_idx[l_n][3] = -1;
      }
      l_found_mul = 0;

      /* let's figure out if we can apply qmadd (4-way fused V4FMADDPS) when in an
       * F32 setting and on KNM; requires 4 consecutive k values to be in range
       * NOTE(review): with k < 4 the unsigned expression (k - 3) wraps around and
       * the bound check is vacuously true — confirm such descriptors are rejected upfront */
      if ( (l_k < ((unsigned int)i_xgemm_desc->k - 3)) &&
           (l_micro_kernel_config.instruction_set == LIBXSMM_X86_AVX512_KNM) &&
           (LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_INP( i_xgemm_desc->datatype ) ) ) {
        /* loop over the columns of B/C */
        for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
          unsigned int l_found = 0;
          unsigned int l_acol_k = 0;
          unsigned int l_col_elements = i_column_idx[l_n_processed+l_n+1] - i_column_idx[l_n_processed+l_n];
          unsigned int l_cur_column = i_column_idx[l_n_processed+l_n];

          /* collect the nnz positions matching k .. k+3 (row indices are sorted,
           * so l_found lets the inner scan resume where it left off) */
          for ( l_col_k = 0; l_col_k < l_col_elements; l_col_k++ ) {
            for ( l_acol_k = l_found; l_acol_k < 4; l_acol_k++ ) {
              if ( (l_k + l_acol_k) == i_row_idx[l_cur_column + l_col_k] ) {
                l_nnz_idx[l_n][l_acol_k] = l_cur_column + l_col_k;
                l_found = l_acol_k+1;
              }
              if (l_found == 4) {
                l_col_k = l_col_elements;
              }
            }
          }

          /* let's check if we can apply qmadd in col l_n */
          if ( (l_nnz_idx[l_n][0] != -1) && (l_nnz_idx[l_n][1] != -1) && (l_nnz_idx[l_n][2] != -1) && (l_nnz_idx[l_n][3] != -1) ) {
            l_column_active[l_n] = 2;
            l_found_qmadd = 1;
            l_found_mul = 1;
          } else {
            /* let's check if we have at least one entry in the column that matches one of the four entries */
            if ( (l_nnz_idx[l_n][0] != -1) || (l_nnz_idx[l_n][1] != -1) || (l_nnz_idx[l_n][2] != -1) || (l_nnz_idx[l_n][3] != -1) ) {
              l_column_active[l_n] = 1;
              l_found_mul = 1;
            } else {
              l_column_active[l_n] = 0;
            }
          }
        }
      }

      /* no qmadd possible: fall back to a single-k scan of each column */
      if ( l_found_qmadd == 0 ) {
        /* loop over the columns of B/C */
        for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
          unsigned int l_col_elements = i_column_idx[l_n_processed+l_n+1] - i_column_idx[l_n_processed+l_n];
          unsigned int l_cur_column = i_column_idx[l_n_processed+l_n];

          /* search for entries matching that k */
          for ( l_col_k = 0; l_col_k < l_col_elements; l_col_k++ ) {
            if ( l_k == i_row_idx[l_cur_column + l_col_k] ) {
              l_nnz_idx[l_n][0] = l_cur_column + l_col_k;
              l_col_k = l_col_elements;
            }
          }

          /* let's check if we have an entry in the column that matches the k from A */
          if ( (l_nnz_idx[l_n][0] != -1) ) {
            l_column_active[l_n] = 1;
            l_found_mul = 1;
          } else {
            l_column_active[l_n] = 0;
          }
        }
      }

      /* First case: we can use qmadd */
      if ( l_found_qmadd != 0 ) {
        unsigned int l_lcl_k = 0;

        /* broadcast-load 4 consecutive A columns into the scratch registers
         * above the accumulator block */
        for ( l_lcl_k = 0; l_lcl_k < 4; l_lcl_k++ ) {
          libxsmm_x86_instruction_vec_move( io_generated_code,
                                            l_micro_kernel_config.instruction_set,
                                            l_micro_kernel_config.a_vmove_instruction,
                                            l_gp_reg_mapping.gp_reg_a,
                                            LIBXSMM_X86_GP_REG_UNDEF, 0,
                                            (l_k+l_lcl_k)*l_soa_width*l_micro_kernel_config.datatype_size,
                                            l_micro_kernel_config.vector_name,
                                            l_max_reg_block+l_lcl_k, 0, 1, 0 );
        }

        /* loop over the columns of B/C */
        for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
          /* issue a qmadd */
          if ( l_column_active[l_n] == 2 ) {
            libxsmm_x86_instruction_vec_compute_qfma( io_generated_code,
                                                      l_micro_kernel_config.instruction_set,
                                                      LIBXSMM_X86_INSTR_V4FMADDPS,
                                                      l_gp_reg_mapping.gp_reg_b,
                                                      LIBXSMM_X86_GP_REG_UNDEF, 0,
                                                      l_nnz_idx[l_n][0] * l_micro_kernel_config.datatype_size,
                                                      l_micro_kernel_config.vector_name,
                                                      l_max_reg_block,
                                                      l_n );
          } else if ( l_column_active[l_n] == 1 ) {
            /* partial match: emit individual FMAs for the present entries only */
            for ( l_lcl_k = 0; l_lcl_k < 4; l_lcl_k++ ) {
              if ( l_nnz_idx[l_n][l_lcl_k] != -1 ) {
                libxsmm_x86_instruction_vec_compute_mem( io_generated_code,
                                                         l_micro_kernel_config.instruction_set,
                                                         l_micro_kernel_config.vmul_instruction,
                                                         1,
                                                         l_gp_reg_mapping.gp_reg_b,
                                                         LIBXSMM_X86_GP_REG_UNDEF, 0,
                                                         l_nnz_idx[l_n][l_lcl_k] * l_micro_kernel_config.datatype_size,
                                                         l_micro_kernel_config.vector_name,
                                                         l_max_reg_block+l_lcl_k,
                                                         l_n );
              }
            }
          }
        }
        /* increment by additional 3 columns (qmadd consumed k .. k+3) */
        l_k += 3;
      } else if ( l_found_mul != 0 ) {
        /* load the single A column for this k into the first scratch register */
        libxsmm_x86_instruction_vec_move( io_generated_code,
                                          l_micro_kernel_config.instruction_set,
                                          l_micro_kernel_config.a_vmove_instruction,
                                          l_gp_reg_mapping.gp_reg_a,
                                          LIBXSMM_X86_GP_REG_UNDEF, 0,
                                          l_k*l_soa_width*l_micro_kernel_config.datatype_size,
                                          l_micro_kernel_config.vector_name,
                                          l_max_reg_block, 0, 1, 0 );

        /* loop over the columns of B/C */
        for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
          if ( l_nnz_idx[l_n][0] != -1 ) {
            if ( strcmp(i_arch, "knl") == 0 ||
                 strcmp(i_arch, "knm") == 0 ||
                 strcmp(i_arch, "skx") == 0 ||
                 strcmp(i_arch, "clx") == 0 ||
                 strcmp(i_arch, "cpx") == 0 ) {
              /* AVX-512: FMA with a broadcast memory operand for the B value */
              libxsmm_x86_instruction_vec_compute_mem( io_generated_code,
                                                       l_micro_kernel_config.instruction_set,
                                                       l_micro_kernel_config.vmul_instruction,
                                                       1,
                                                       l_gp_reg_mapping.gp_reg_b,
                                                       LIBXSMM_X86_GP_REG_UNDEF, 0,
                                                       l_nnz_idx[l_n][0] * l_micro_kernel_config.datatype_size,
                                                       l_micro_kernel_config.vector_name,
                                                       l_max_reg_block,
                                                       l_n );
            } else if ( strcmp(i_arch, "hsw") == 0 ) {
              /* AVX2: load B value into scratch reg 15, then register FMA */
              libxsmm_x86_instruction_vec_move( io_generated_code,
                                                l_micro_kernel_config.instruction_set,
                                                l_micro_kernel_config.b_vmove_instruction,
                                                l_gp_reg_mapping.gp_reg_b,
                                                LIBXSMM_X86_GP_REG_UNDEF, 0,
                                                l_nnz_idx[l_n][0] * l_micro_kernel_config.datatype_size,
                                                l_micro_kernel_config.vector_name,
                                                15, 0, 1, 0 );
              libxsmm_x86_instruction_vec_compute_reg( io_generated_code,
                                                       l_micro_kernel_config.instruction_set,
                                                       l_micro_kernel_config.vmul_instruction,
                                                       l_micro_kernel_config.vector_name,
                                                       l_max_reg_block,
                                                       15,
                                                       l_n );
            } else {
              /* no FMA available: separate multiply into reg 15, then add into the accumulator */
              libxsmm_x86_instruction_vec_move( io_generated_code,
                                                l_micro_kernel_config.instruction_set,
                                                l_micro_kernel_config.b_vmove_instruction,
                                                l_gp_reg_mapping.gp_reg_b,
                                                LIBXSMM_X86_GP_REG_UNDEF, 0,
                                                l_nnz_idx[l_n][0] * l_micro_kernel_config.datatype_size,
                                                l_micro_kernel_config.vector_name,
                                                15, 0, 1, 0 );
              libxsmm_x86_instruction_vec_compute_reg( io_generated_code,
                                                       l_micro_kernel_config.instruction_set,
                                                       l_micro_kernel_config.vmul_instruction,
                                                       l_micro_kernel_config.vector_name,
                                                       l_max_reg_block,
                                                       15,
                                                       15 );
              libxsmm_x86_instruction_vec_compute_reg( io_generated_code,
                                                       l_micro_kernel_config.instruction_set,
                                                       l_micro_kernel_config.vadd_instruction,
                                                       l_micro_kernel_config.vector_name,
                                                       15,
                                                       l_n,
                                                       l_n );
            }
          }
        }
      } else {
        /* shouldn't happen */
      }
    }

    /* store C accumulator */
    for ( l_n = 0; l_n < l_n_limit - l_n_processed; l_n++ ) {
      libxsmm_x86_instruction_vec_move( io_generated_code,
                                        l_micro_kernel_config.instruction_set,
                                        l_micro_kernel_config.c_vmove_instruction,
                                        l_gp_reg_mapping.gp_reg_c,
                                        LIBXSMM_X86_GP_REG_UNDEF, 0,
                                        (l_n_processed + l_n)*l_soa_width*l_micro_kernel_config.datatype_size,
                                        l_micro_kernel_config.vector_name,
                                        l_n, 0, 0, 1 );
    }

    /* adjust n progression */
    l_n_processed += l_n_chunksize;
    l_n_limit = LIBXSMM_MIN(l_n_processed + l_n_chunksize, l_max_cols);
  }

  /* advance C pointer */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction,
                                   l_gp_reg_mapping.gp_reg_c, l_micro_kernel_config.datatype_size*l_soa_width*i_xgemm_desc->ldc);

  /* advance A pointer */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction,
                                   l_gp_reg_mapping.gp_reg_a, l_micro_kernel_config.datatype_size*l_soa_width*i_xgemm_desc->lda);

  /* close m loop */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_cmp_instruction, l_gp_reg_mapping.gp_reg_mloop, i_xgemm_desc->m );
  libxsmm_x86_instruction_jump_back_to_label( io_generated_code, l_micro_kernel_config.alu_jmp_instruction, &l_loop_label_tracker );

  /* close asm */
  libxsmm_x86_instruction_close_stream( io_generated_code, &l_gp_reg_mapping, i_arch, i_xgemm_desc->prefetch );
}
/**
 * Fills the micro-kernel configuration with scalar (one element per "vector")
 * instruction selections for the given architecture.
 *
 * @param io_micro_kernel_config configuration to populate.
 * @param i_xgemm_desc           GEMM descriptor; F32PREC flag cleared selects the FP64 path.
 * @param i_arch                 architecture string ("wsm", "snb", "hsw", "knc", "knl", "skx").
 * @param i_use_masking_a_c      whether masked loads/stores of A and C are requested.
 */
LIBXSMM_INTERNAL_API_DEFINITION void libxsmm_generator_gemm_init_micro_kernel_config_scalar( libxsmm_micro_kernel_config* io_micro_kernel_config,
                                                                                             const libxsmm_gemm_descriptor* i_xgemm_desc,
                                                                                             const char* i_arch,
                                                                                             const unsigned int i_use_masking_a_c ) {
  if( strcmp( i_arch, "wsm" ) == 0 ) {
    /* Westmere: SSE3 scalar instructions */
    io_micro_kernel_config->instruction_set = LIBXSMM_X86_SSE3;
    io_micro_kernel_config->vector_reg_count = 16;
    io_micro_kernel_config->use_masking_a_c = i_use_masking_a_c;
    io_micro_kernel_config->vector_name = 'x';
    if ( (LIBXSMM_GEMM_FLAG_F32PREC & i_xgemm_desc->flags) == 0 ) {
      /* FP64: scalar SSE double-precision */
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 8;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_MOVSD;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_MOVSD;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_MOVSD;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_XORPD;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_MULSD;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_ADDSD;
    } else {
      /* FP32: scalar SSE single-precision */
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 4;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_MOVSS;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_MOVSS;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_MOVSS;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_XORPS;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_MULSS;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_ADDSS;
    }
  } else if( strcmp( i_arch, "snb" ) == 0 ) {
    /* Sandy Bridge: AVX scalar instructions (VEX encoded, separate mul/add) */
    io_micro_kernel_config->instruction_set = LIBXSMM_X86_AVX;
    io_micro_kernel_config->vector_reg_count = 16;
    io_micro_kernel_config->use_masking_a_c = i_use_masking_a_c;
    io_micro_kernel_config->vector_name = 'x';
    if ( (LIBXSMM_GEMM_FLAG_F32PREC & i_xgemm_desc->flags) == 0 ) {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 8;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VXORPD;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VMULSD;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_VADDSD;
    } else {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 4;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VXORPS;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VMULSS;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_VADDSS;
    }
  } else if ( strcmp( i_arch, "hsw" ) == 0 ) {
    /* Haswell: AVX2 with scalar FMA (vadd folded into VFMADD231) */
    io_micro_kernel_config->instruction_set = LIBXSMM_X86_AVX2;
    io_micro_kernel_config->vector_reg_count = 16;
    io_micro_kernel_config->use_masking_a_c = i_use_masking_a_c;
    io_micro_kernel_config->vector_name = 'x';
    if ( (LIBXSMM_GEMM_FLAG_F32PREC & i_xgemm_desc->flags) == 0 ) {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 8;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VXORPD;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VFMADD231SD;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_UNDEF;
    } else {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 4;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VXORPS;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VFMADD231SS;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_UNDEF;
    }
  } else if ( (strcmp( i_arch, "knc" ) == 0) ||
              (strcmp( i_arch, "knl" ) == 0) ||
              (strcmp( i_arch, "skx" ) == 0) ) {
    if ((strcmp( i_arch, "knc" ) == 0)) {
      io_micro_kernel_config->instruction_set = LIBXSMM_X86_IMCI;
#if !defined(NDEBUG)
      fprintf(stderr, "LIBXSMM WARNING, libxsmm_generator_gemm_init_micro_kernel_config_scalar, IMCI redirecting to fullvector, please fix the generation code!!!\n");
#endif
      /* IMCI has no scalar path: delegate to the full-vector configuration.
       * NOTE(review): the common assignments below still run afterwards and
       * overwrite parts of that config — confirm this is intended for knc */
      libxsmm_generator_gemm_init_micro_kernel_config_fullvector( io_micro_kernel_config, i_xgemm_desc, i_arch, i_use_masking_a_c );
    } else if ((strcmp( i_arch, "knl" ) == 0)) {
      io_micro_kernel_config->instruction_set = LIBXSMM_X86_AVX512_MIC;
    } else {
      io_micro_kernel_config->instruction_set = LIBXSMM_X86_AVX512_CORE;
    }
    io_micro_kernel_config->vector_reg_count = 16;
    io_micro_kernel_config->use_masking_a_c = i_use_masking_a_c;
    io_micro_kernel_config->vector_name = 'x';
    if ( (LIBXSMM_GEMM_FLAG_F32PREC & i_xgemm_desc->flags) == 0 ) {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 8;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSD;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VPXORD;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VFMADD231SD;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_VADDSD;
    } else {
      io_micro_kernel_config->vector_length = 1;
      io_micro_kernel_config->datatype_size = 4;
      io_micro_kernel_config->a_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->b_shuff_instruction = LIBXSMM_X86_INSTR_UNDEF;
      io_micro_kernel_config->c_vmove_instruction = LIBXSMM_X86_INSTR_VMOVSS;
      io_micro_kernel_config->vxor_instruction = LIBXSMM_X86_INSTR_VPXORD;
      io_micro_kernel_config->vmul_instruction = LIBXSMM_X86_INSTR_VFMADD231SS;
      io_micro_kernel_config->vadd_instruction = LIBXSMM_X86_INSTR_VADDSS;
    }
  } else {
    /* unknown architecture: instruction selections are left untouched */
  }

  /* ALU/prefetch selections are architecture independent */
  io_micro_kernel_config->prefetch_instruction = LIBXSMM_X86_INSTR_PREFETCHT1;
  io_micro_kernel_config->alu_add_instruction = LIBXSMM_X86_INSTR_ADDQ;
  io_micro_kernel_config->alu_sub_instruction = LIBXSMM_X86_INSTR_SUBQ;
  io_micro_kernel_config->alu_cmp_instruction = LIBXSMM_X86_INSTR_CMPQ;
  io_micro_kernel_config->alu_jmp_instruction = LIBXSMM_X86_INSTR_JL;
  io_micro_kernel_config->alu_mov_instruction = LIBXSMM_X86_INSTR_MOVQ;
}
/**
 * Emits a complete IMCI/AVX-512 GEMM kernel. The N dimension is split into
 * chunks of at most 30 columns; ceil-sized ("big") chunks come first, followed
 * by floor-sized ("small") chunks, each wrapped in its own n-loop.
 *
 * @param io_generated_code buffer receiving the emitted instructions.
 * @param i_xgemm_desc      GEMM shape/flags descriptor.
 * @param i_arch            target architecture string.
 */
void libxsmm_generator_gemm_imci_avx512_kernel( libxsmm_generated_code*        io_generated_code,
                                                const libxsmm_gemm_descriptor* i_xgemm_desc,
                                                const char*                    i_arch ) {
  libxsmm_micro_kernel_config l_config;
  libxsmm_loop_label_tracker l_labels;
  libxsmm_gp_reg_mapping l_regs;
  /* smallest chunk count such that no chunk exceeds 30 columns */
  unsigned int l_chunks = 1 + ((i_xgemm_desc->n - 1) / 30);
  unsigned int l_remainder = i_xgemm_desc->n % l_chunks;
  unsigned int l_small_n = i_xgemm_desc->n / l_chunks; /* floor chunk width */
  unsigned int l_big_n = l_small_n + 1;                /* ceil chunk width */
  unsigned int l_total_small = 0;                      /* columns covered by small chunks */
  unsigned int l_total_big = 0;                        /* columns covered by big chunks */
  unsigned int l_i = 0;

  /* define gp register mapping (matching calling convention on Linux) */
  libxsmm_reset_x86_gp_reg_mapping( &l_regs );
  l_regs.gp_reg_a = LIBXSMM_X86_GP_REG_RDI;
  l_regs.gp_reg_b = LIBXSMM_X86_GP_REG_RSI;
  l_regs.gp_reg_c = LIBXSMM_X86_GP_REG_RDX;
  l_regs.gp_reg_a_prefetch = LIBXSMM_X86_GP_REG_RCX;
  l_regs.gp_reg_b_prefetch = LIBXSMM_X86_GP_REG_R8;
  l_regs.gp_reg_mloop = LIBXSMM_X86_GP_REG_R12;
  l_regs.gp_reg_nloop = LIBXSMM_X86_GP_REG_R13;
  l_regs.gp_reg_kloop = LIBXSMM_X86_GP_REG_R14;
  l_regs.gp_reg_help_0 = LIBXSMM_X86_GP_REG_R15; /* masking */
  l_regs.gp_reg_help_1 = LIBXSMM_X86_GP_REG_RAX; /* B stride helper */
  l_regs.gp_reg_help_2 = LIBXSMM_X86_GP_REG_RBX; /* B stride helper */
  l_regs.gp_reg_help_3 = LIBXSMM_X86_GP_REG_R9;  /* B stride helper */
  l_regs.gp_reg_help_4 = LIBXSMM_X86_GP_REG_R10; /* B stride helper */
  l_regs.gp_reg_help_5 = LIBXSMM_X86_GP_REG_R11; /* B stride helper */

  /* define loop_label_tracker */
  libxsmm_reset_loop_label_tracker( &l_labels );

  /* define the micro kernel code gen properties */
  libxsmm_generator_gemm_init_micro_kernel_config_fullvector( &l_config, i_xgemm_desc, i_arch, 0 );

  /* only possible when n divides evenly over the chunks (ceil width unused) */
  if (l_big_n > 30) {
    l_big_n = 30;
  }

  /* distribute columns: the first l_remainder chunks take the ceil width */
  for (l_i = 0; l_i < l_chunks; l_i++) {
    if (l_i < l_remainder) {
      l_total_big += l_big_n;
    } else {
      l_total_small += l_small_n;
    }
  }
  /* printf("N splitting of DP AVX512 Kernel: %i %i %i %i\n", l_N1, l_N2, l_n1, l_n2); */

  /* open asm */
  libxsmm_x86_instruction_open_stream( io_generated_code, &l_regs, i_arch, i_xgemm_desc->prefetch );

  if (l_chunks == 1) {
    /* single chunk: no n-loop wrapper needed */
    libxsmm_generator_gemm_imci_avx512_kernel_mloop( io_generated_code, &l_labels, &l_regs,
                                                     &l_config, i_xgemm_desc, i_arch, i_xgemm_desc->n );
  } else if ((l_total_small > 0) && (l_total_big > 0)) {
    /* mixed split: big chunks first, then the small ones */
    libxsmm_generator_gemm_header_nloop( io_generated_code, &l_labels, &l_regs, &l_config, l_big_n );
    libxsmm_generator_gemm_imci_avx512_kernel_mloop( io_generated_code, &l_labels, &l_regs,
                                                     &l_config, i_xgemm_desc, i_arch, l_big_n );
    libxsmm_generator_gemm_footer_nloop( io_generated_code, &l_labels, &l_regs, &l_config,
                                         i_xgemm_desc, l_big_n, l_total_big );
    libxsmm_generator_gemm_header_nloop( io_generated_code, &l_labels, &l_regs, &l_config, l_small_n );
    libxsmm_generator_gemm_imci_avx512_kernel_mloop( io_generated_code, &l_labels, &l_regs,
                                                     &l_config, i_xgemm_desc, i_arch, l_small_n );
    libxsmm_generator_gemm_footer_nloop( io_generated_code, &l_labels, &l_regs, &l_config,
                                         i_xgemm_desc, l_small_n, i_xgemm_desc->n );
  } else if ((l_total_small > 0) && (l_total_big == 0)) {
    /* uniform split: only floor-width chunks exist */
    libxsmm_generator_gemm_header_nloop( io_generated_code, &l_labels, &l_regs, &l_config, l_small_n );
    libxsmm_generator_gemm_imci_avx512_kernel_mloop( io_generated_code, &l_labels, &l_regs,
                                                     &l_config, i_xgemm_desc, i_arch, l_small_n );
    libxsmm_generator_gemm_footer_nloop( io_generated_code, &l_labels, &l_regs, &l_config,
                                         i_xgemm_desc, l_small_n, i_xgemm_desc->n );
  }

  /* close asm */
  libxsmm_x86_instruction_close_stream( io_generated_code, &l_regs, i_arch, i_xgemm_desc->prefetch );
}
/**
 * Emits the M-loop of a dense GEMM kernel for IMCI (KNC) or AVX-512 (KNL/SKX)
 * targets: first as many full vector-length iterations over M as possible,
 * then one masked pass (opmask k1) for the M remainder.
 *
 * @param io_generated_code     code stream the x86 instructions are emitted into
 * @param io_loop_label_tracker bookkeeping for the jump-back labels of generated loops
 * @param i_gp_reg_mapping      GP register assignment (A/B/C pointers, loop counters, prefetch regs)
 * @param i_micro_kernel_config micro-kernel properties (vector length, instruction selections)
 * @param i_xgemm_desc          GEMM descriptor (m/n/k, leading dimensions, prefetch strategy)
 * @param i_arch                target architecture string: "knl", "skx" or "knc" (anything else aborts)
 * @param i_n_blocking          number of N columns held in the register block
 */
LIBXSMM_INLINE void libxsmm_generator_gemm_imci_avx512_kernel_mloop( libxsmm_generated_code*            io_generated_code,
                                                                     libxsmm_loop_label_tracker*        io_loop_label_tracker,
                                                                     const libxsmm_gp_reg_mapping*      i_gp_reg_mapping,
                                                                     const libxsmm_micro_kernel_config* i_micro_kernel_config,
                                                                     const libxsmm_gemm_descriptor*     i_xgemm_desc,
                                                                     const char*                        i_arch,
                                                                     unsigned int                       i_n_blocking ) {
  /* set function pointers for AVX512 and IMCI */
  /* K-loop generator; its return value signals whether the K loop was emitted fully unrolled */
  unsigned int (*l_generator_microkernel_kloop)( libxsmm_generated_code*, libxsmm_loop_label_tracker*, const libxsmm_gp_reg_mapping*, const libxsmm_micro_kernel_config*, const libxsmm_gemm_descriptor*, const char*, unsigned int );
  /* generators that load/store the C accumulator register block */
  void (*l_generator_load)( libxsmm_generated_code*, const libxsmm_gp_reg_mapping*, const libxsmm_micro_kernel_config*, const libxsmm_gemm_descriptor*, const unsigned int, const unsigned int );
  void (*l_generator_store)( libxsmm_generated_code*, const libxsmm_gp_reg_mapping*, const libxsmm_micro_kernel_config*, const libxsmm_gemm_descriptor*, const unsigned int, const unsigned int );
  unsigned int l_k_unrolled;  /* result of the K-loop generator (non-zero: K loop unrolled) */
  unsigned int l_m_done;      /* number of M rows covered by full vector-length iterations */

  if ( (strcmp(i_arch, "knl") == 0) ) {
    l_generator_microkernel_kloop = libxsmm_generator_gemm_avx512_kernel_kloop;
    l_generator_load = libxsmm_generator_gemm_load_C;
    l_generator_store = libxsmm_generator_gemm_store_C;
  } else if ( (strcmp(i_arch, "skx") == 0) ) {
    /* SKX shares the AVX-512 code paths with KNL */
    l_generator_microkernel_kloop = libxsmm_generator_gemm_avx512_kernel_kloop;
    l_generator_load = libxsmm_generator_gemm_load_C;
    l_generator_store = libxsmm_generator_gemm_store_C;
  } else if ( (strcmp(i_arch, "knc") == 0) ) {
    /* KNC uses the IMCI-specific variants */
    l_generator_microkernel_kloop = libxsmm_generator_gemm_imci_kernel_kloop;
    l_generator_load = libxsmm_generator_gemm_load_C_imci;
    l_generator_store = libxsmm_generator_gemm_store_C_imci;
  } else {
    /* unsupported architecture: no microkernel can be selected, hard abort */
    fprintf(stderr, "LIBXSMM ERROR libxsmm_generator_gemm_imci_avx512_kernel_mloop, cannot select microkernel\n");
    exit(-1);
  }

  /* we proceed as much as we can in vector length steps, remainder is handled using masking */
  l_m_done = (i_xgemm_desc->m / i_micro_kernel_config->vector_length) * i_micro_kernel_config->vector_length; /* multiples of vector_length in M */

  if (l_m_done > 0) {
    libxsmm_generator_gemm_header_mloop( io_generated_code, io_loop_label_tracker, i_gp_reg_mapping, i_micro_kernel_config, i_micro_kernel_config->vector_length );
    l_generator_load( io_generated_code, i_gp_reg_mapping, i_micro_kernel_config, i_xgemm_desc, i_micro_kernel_config->vector_length, i_n_blocking );

    /* if i_n_blocking is greater 26 && we prefetch via C -> push prefetch gpr
       NOTE(review): the guard is arch != "knc" (i.e. KNL and SKX), although the
       original comment said KNL only — confirm intent */
    if ( (i_n_blocking > 26) && (strcmp(i_arch, "knc") != 0) && (i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST ) ) {
      libxsmm_x86_instruction_push_reg( io_generated_code, i_gp_reg_mapping->gp_reg_b_prefetch );
    }

    l_k_unrolled = l_generator_microkernel_kloop( io_generated_code, io_loop_label_tracker, i_gp_reg_mapping, i_micro_kernel_config, i_xgemm_desc, i_arch, i_n_blocking );

    /* mirror of the push above: restore the prefetch gpr under the same condition */
    if ( (i_n_blocking > 26) && (strcmp(i_arch, "knc") != 0) && (i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST ) ) {
      libxsmm_x86_instruction_pop_reg( io_generated_code, i_gp_reg_mapping->gp_reg_b_prefetch );
    }

    l_generator_store( io_generated_code, i_gp_reg_mapping, i_micro_kernel_config, i_xgemm_desc, i_micro_kernel_config->vector_length, i_n_blocking );
    libxsmm_generator_gemm_footer_mloop( io_generated_code, io_loop_label_tracker, i_gp_reg_mapping, i_micro_kernel_config, i_xgemm_desc, i_micro_kernel_config->vector_length, l_m_done, l_k_unrolled );
  }

  /* Remainder Handling using Masking, we are using M loop counter register as GP register for the mask */
  if ( l_m_done != i_xgemm_desc->m ) {
    /* request masking support, @TODO performance penalty here, as a new object is created */
    libxsmm_micro_kernel_config l_micro_kernel_config_mask;
    libxsmm_generator_gemm_init_micro_kernel_config_fullvector( &l_micro_kernel_config_mask, i_xgemm_desc, i_arch, 1 );

    /* initialize k1 register */
    libxsmm_generator_gemm_imci_avx512_kernel_initialize_mask( io_generated_code, i_gp_reg_mapping, &l_micro_kernel_config_mask, i_xgemm_desc, l_m_done );

    /* run masked micro kernel */
    l_generator_load( io_generated_code, i_gp_reg_mapping, &l_micro_kernel_config_mask, i_xgemm_desc, l_micro_kernel_config_mask.vector_length, i_n_blocking );

    /* same push/pop protocol around the K loop as in the unmasked pass above */
    if ( (i_n_blocking > 26) && (strcmp(i_arch, "knc") != 0) && (i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST ) ) {
      libxsmm_x86_instruction_push_reg( io_generated_code, i_gp_reg_mapping->gp_reg_b_prefetch );
    }

    l_k_unrolled = l_generator_microkernel_kloop( io_generated_code, io_loop_label_tracker, i_gp_reg_mapping, &l_micro_kernel_config_mask, i_xgemm_desc, i_arch, i_n_blocking );

    if ( (i_n_blocking > 26) && (strcmp(i_arch, "knc") != 0) && (i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_AHEAD ||
         i_xgemm_desc->prefetch == LIBXSMM_PREFETCH_AL2BL2_VIA_C_JPST ) ) {
      libxsmm_x86_instruction_pop_reg( io_generated_code, i_gp_reg_mapping->gp_reg_b_prefetch );
    }

    l_generator_store( io_generated_code, i_gp_reg_mapping, &l_micro_kernel_config_mask, i_xgemm_desc, l_micro_kernel_config_mask.vector_length, i_n_blocking );

    /* adjust pointers as we don't have a m-loop body */
    /* C: advance by the remaining (m - l_m_done) rows */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config_mask.alu_add_instruction, i_gp_reg_mapping->gp_reg_c,
                                     (i_xgemm_desc->m - l_m_done) * l_micro_kernel_config_mask.datatype_size );
    /* A: the adjustment depends on whether the K loop was unrolled */
    if (l_k_unrolled == 0) {
      /* subtract k*lda elements (NOTE(review): presumably the K-loop's A advance)
         minus the remainder-row advance — confirm against the kloop generator */
      libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config_mask.alu_sub_instruction, i_gp_reg_mapping->gp_reg_a,
                                       (i_xgemm_desc->k * l_micro_kernel_config_mask.datatype_size * i_xgemm_desc->lda) - ((i_xgemm_desc->m - l_m_done) * l_micro_kernel_config_mask.datatype_size) );
    } else {
      /* unrolled K loop: only advance A by the remainder rows */
      libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config_mask.alu_add_instruction, i_gp_reg_mapping->gp_reg_a,
                                       ((i_xgemm_desc->m - l_m_done) * l_micro_kernel_config_mask.datatype_size) );
    }
  }
}
/**
 * Generates a row-major SOA (structure-of-arrays) GEMM kernel where A and C
 * are stored SOA, targeting AVX2 (256-bit) or AVX-512 (512-bit) depending on
 * i_arch. Emits the full assembly stream: prologue, M loop with a blocked N
 * loop nest calling the K-loop generator, pointer advances, and epilogue.
 *
 * @param io_generated_code code stream the generated kernel is emitted into
 * @param i_xgemm_desc      GEMM descriptor (m/n/k, lda/ldb/ldc, datatype, prefetch)
 * @param i_arch            target architecture string; "knl"/"knm"/"skx"/"clx"/"cpx"
 *                          select 512-bit SOA widths, anything else the 256-bit widths
 */
LIBXSMM_API_INTERN void libxsmm_generator_gemm_rm_ac_soa_avx256_512( libxsmm_generated_code*         io_generated_code,
                                                                     const libxsmm_gemm_descriptor*  i_xgemm_desc,
                                                                     const char*                     i_arch ) {
  unsigned int l_soa_width = 0;      /* elements per SOA lane: vector width / datatype size */
  unsigned int l_max_reg_block = 0;  /* max number of accumulator registers for N blocking */
  unsigned int l_n1_range = 0;       /* N extent covered by blocks of size l_n1_block */
  unsigned int l_n2_range = 0;       /* N extent covered by blocks of size l_n2_block */
  unsigned int l_n1_block = 0;
  unsigned int l_n2_block = 0;
  libxsmm_micro_kernel_config l_micro_kernel_config;
  libxsmm_loop_label_tracker l_loop_label_tracker;
  libxsmm_gp_reg_mapping l_gp_reg_mapping;

  /* select soa width: 512-bit targets get 8 doubles/16 floats and 28 registers,
     256-bit targets get 4 doubles/8 floats and 14 registers */
  if ( LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_INP( i_xgemm_desc->datatype ) ) {
    if ( strcmp(i_arch, "knl") == 0 ||
         strcmp(i_arch, "knm") == 0 ||
         strcmp(i_arch, "skx") == 0 ||
         strcmp(i_arch, "clx") == 0 ||
         strcmp(i_arch, "cpx") == 0 ) {
      l_soa_width = 8;
      l_max_reg_block = 28;
    } else {
      l_soa_width = 4;
      l_max_reg_block = 14;
    }
  } else {
    if ( strcmp(i_arch, "knl") == 0 ||
         strcmp(i_arch, "knm") == 0 ||
         strcmp(i_arch, "skx") == 0 ||
         strcmp(i_arch, "clx") == 0 ||
         strcmp(i_arch, "cpx") == 0 ) {
      l_soa_width = 16;
      l_max_reg_block = 28;
    } else {
      l_soa_width = 8;
      l_max_reg_block = 14;
    }
  }

  /* define gp register mapping */
  libxsmm_reset_x86_gp_reg_mapping( &l_gp_reg_mapping );
#if defined(_WIN32) || defined(__CYGWIN__)
  /* Windows x64 calling convention: args in RCX, RDX, R8 */
  l_gp_reg_mapping.gp_reg_a = LIBXSMM_X86_GP_REG_RCX;
  l_gp_reg_mapping.gp_reg_b = LIBXSMM_X86_GP_REG_RDX;
  l_gp_reg_mapping.gp_reg_c = LIBXSMM_X86_GP_REG_R8;
  /* TODO: full support for Windows calling convention */
  l_gp_reg_mapping.gp_reg_a_prefetch = LIBXSMM_X86_GP_REG_RDI;
  l_gp_reg_mapping.gp_reg_b_prefetch = LIBXSMM_X86_GP_REG_RSI;
#else /* match calling convention on Linux (System V AMD64: RDI, RSI, RDX, RCX, R8) */
  l_gp_reg_mapping.gp_reg_a = LIBXSMM_X86_GP_REG_RDI;
  l_gp_reg_mapping.gp_reg_b = LIBXSMM_X86_GP_REG_RSI;
  l_gp_reg_mapping.gp_reg_c = LIBXSMM_X86_GP_REG_RDX;
  l_gp_reg_mapping.gp_reg_a_prefetch = LIBXSMM_X86_GP_REG_RCX;
  l_gp_reg_mapping.gp_reg_b_prefetch = LIBXSMM_X86_GP_REG_R8;
#endif
  l_gp_reg_mapping.gp_reg_c_prefetch = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_mloop = LIBXSMM_X86_GP_REG_R12;
  l_gp_reg_mapping.gp_reg_nloop = LIBXSMM_X86_GP_REG_R13;
  l_gp_reg_mapping.gp_reg_kloop = LIBXSMM_X86_GP_REG_R14;
  l_gp_reg_mapping.gp_reg_help_0 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_1 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_2 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_3 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_4 = LIBXSMM_X86_GP_REG_UNDEF;
  l_gp_reg_mapping.gp_reg_help_5 = LIBXSMM_X86_GP_REG_UNDEF;

  /* define loop_label_tracker */
  libxsmm_reset_loop_label_tracker( &l_loop_label_tracker );

  /* define the micro kernel code gen properties */
  libxsmm_generator_gemm_init_micro_kernel_config_fullvector( &l_micro_kernel_config, i_xgemm_desc, i_arch, 0 );

  /* calculate the chunk size of current columns to work on;
     splits N into (range1 of block1) + (range2 of block2), non-zero on failure */
  if ( libxsmm_compute_equalized_blocking( i_xgemm_desc->n, l_max_reg_block, &l_n1_range, &l_n1_block, &l_n2_range, &l_n2_block ) ) {
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_N_BLOCK );
    return;
  }

  /* open asm */
  libxsmm_x86_instruction_open_stream( io_generated_code, &l_gp_reg_mapping, i_arch, i_xgemm_desc->prefetch );

  /* m loop: label here, increment the M counter each iteration */
  libxsmm_x86_instruction_register_jump_back_label( io_generated_code, &l_loop_label_tracker );
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_mloop, 1 );

  /* loop over n-blocks */
  if ( l_n1_block == i_xgemm_desc->n ) {
    /* no N loop at all: a single block covers all of N */
    libxsmm_generator_gemm_rm_ac_soa_avx256_512_kloop( io_generated_code, &l_loop_label_tracker, &l_gp_reg_mapping, &l_micro_kernel_config,
                                                       i_xgemm_desc, i_arch, l_soa_width, i_xgemm_desc->n );
  } else if ( (l_n1_range > 0) && (l_n2_range > 0) ) {
    /* reset n loop */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_mov_instruction, l_gp_reg_mapping.gp_reg_nloop, 0 );

    /* we have two ranges */
    /* first range: blocks of l_n1_block until l_n1_range is reached */
    libxsmm_x86_instruction_register_jump_back_label( io_generated_code, &l_loop_label_tracker );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_nloop, l_n1_block );
    libxsmm_generator_gemm_rm_ac_soa_avx256_512_kloop( io_generated_code, &l_loop_label_tracker, &l_gp_reg_mapping, &l_micro_kernel_config,
                                                       i_xgemm_desc, i_arch, l_soa_width, l_n1_block );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_cmp_instruction, l_gp_reg_mapping.gp_reg_nloop, l_n1_range );
    libxsmm_x86_instruction_jump_back_to_label( io_generated_code, l_micro_kernel_config.alu_jmp_instruction, &l_loop_label_tracker );

    /* second range: blocks of l_n2_block until all of N is covered */
    libxsmm_x86_instruction_register_jump_back_label( io_generated_code, &l_loop_label_tracker );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_nloop, l_n2_block );
    libxsmm_generator_gemm_rm_ac_soa_avx256_512_kloop( io_generated_code, &l_loop_label_tracker, &l_gp_reg_mapping, &l_micro_kernel_config,
                                                       i_xgemm_desc, i_arch, l_soa_width, l_n2_block );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_cmp_instruction, l_gp_reg_mapping.gp_reg_nloop, i_xgemm_desc->n );
    libxsmm_x86_instruction_jump_back_to_label( io_generated_code, l_micro_kernel_config.alu_jmp_instruction, &l_loop_label_tracker );

    /* reset B pointer (advanced by n elements inside the N loop) */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_sub_instruction, l_gp_reg_mapping.gp_reg_b,
                                     i_xgemm_desc->n * l_micro_kernel_config.datatype_size );
    /* reset C pointer (advanced by n SOA columns inside the N loop) */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_sub_instruction, l_gp_reg_mapping.gp_reg_c,
                                     i_xgemm_desc->n * l_soa_width * l_micro_kernel_config.datatype_size );
  } else if ( (l_n1_range > 0) && (l_n2_range == 0) ) {
    /* reset n loop */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_mov_instruction, l_gp_reg_mapping.gp_reg_nloop, 0 );

    /* we have one range */
    libxsmm_x86_instruction_register_jump_back_label( io_generated_code, &l_loop_label_tracker );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_nloop, l_n1_block );
    libxsmm_generator_gemm_rm_ac_soa_avx256_512_kloop( io_generated_code, &l_loop_label_tracker, &l_gp_reg_mapping, &l_micro_kernel_config,
                                                       i_xgemm_desc, i_arch, l_soa_width, l_n1_block );
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_cmp_instruction, l_gp_reg_mapping.gp_reg_nloop, i_xgemm_desc->n );
    libxsmm_x86_instruction_jump_back_to_label( io_generated_code, l_micro_kernel_config.alu_jmp_instruction, &l_loop_label_tracker );

    /* reset B pointer */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_sub_instruction, l_gp_reg_mapping.gp_reg_b,
                                     i_xgemm_desc->n * l_micro_kernel_config.datatype_size );
    /* reset C pointer */
    libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_sub_instruction, l_gp_reg_mapping.gp_reg_c,
                                     i_xgemm_desc->n * l_soa_width * l_micro_kernel_config.datatype_size );
  } else {
    /* blocking produced no usable ranges */
    LIBXSMM_HANDLE_ERROR( io_generated_code, LIBXSMM_ERR_N_BLOCK );
    return;
  }

  /* advance A pointer to the next SOA row block */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_a,
                                   l_micro_kernel_config.datatype_size*l_soa_width*i_xgemm_desc->lda);
  /* advance C pointer to the next SOA row block */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_add_instruction, l_gp_reg_mapping.gp_reg_c,
                                   l_micro_kernel_config.datatype_size*l_soa_width*i_xgemm_desc->ldc);

  /* close m loop: compare counter against m and jump back while not done */
  libxsmm_x86_instruction_alu_imm( io_generated_code, l_micro_kernel_config.alu_cmp_instruction, l_gp_reg_mapping.gp_reg_mloop, i_xgemm_desc->m );
  libxsmm_x86_instruction_jump_back_to_label( io_generated_code, l_micro_kernel_config.alu_jmp_instruction, &l_loop_label_tracker );

  /* close asm */
  libxsmm_x86_instruction_close_stream( io_generated_code, &l_gp_reg_mapping, i_arch, i_xgemm_desc->prefetch );
}