// Multiplies two 4x4 float matrices: vout = v1 * v2.
// Each matrix is stored row-major as two __m256 registers, two rows of four
// floats per register (v[0] = rows 0,1; v[1] = rows 2,3).
//
// Per iteration r (producing result rows 2r and 2r+1):
//   a0 = { A[2r][0] x4 | A[2r][1] x4 }, a1 = { A[2r][2] x4 | A[2r][3] x4 }
//   a2/a3 likewise for row 2r+1 (elements 4..7 of v1[r]).
//   c0.low  = A[2r][0]*B.row0 + A[2r][2]*B.row2
//   c0.high = A[2r][1]*B.row1 + A[2r][3]*B.row3
//   (c1 is the same for row 2r+1)
//   d0/d1 regroup the 128-bit halves so matching partial sums line up;
//   d0 + d1 = { result row 2r | result row 2r+1 }.
static void NOINLINE mulX8( const __m256 *v1, const __m256 *v2, __m256 *vout )
{
    // Index vectors for _mm256_permutevar8x32_ps: broadcast element k across
    // the low 128-bit lane and element k+1 across the high lane.
    static const int ALIGN32 p1[ 8 ] = { 0, 0, 0, 0, 1, 1, 1, 1 };
    static const int ALIGN32 p2[ 8 ] = { 2, 2, 2, 2, 3, 3, 3, 3 };
    static const int ALIGN32 p3[ 8 ] = { 4, 4, 4, 4, 5, 5, 5, 5 };
    static const int ALIGN32 p4[ 8 ] = { 6, 6, 6, 6, 7, 7, 7, 7 };

    const __m256i perm1 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p1 ) );
    const __m256i perm2 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p2 ) );
    const __m256i perm3 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p3 ) );
    const __m256i perm4 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p4 ) );

    for( int r = 0; r < 2; r++ )
    {
        // Broadcast the four scalar elements of each of the two A-rows held
        // in v1[r] (one element per 128-bit lane).
        __m256 a0 = _mm256_permutevar8x32_ps( v1[ r ], perm1 );
        __m256 a1 = _mm256_permutevar8x32_ps( v1[ r ], perm2 );
        __m256 a2 = _mm256_permutevar8x32_ps( v1[ r ], perm3 );
        __m256 a3 = _mm256_permutevar8x32_ps( v1[ r ], perm4 );

        // Partial products against B's row pairs (v2[0] = rows 0,1;
        // v2[1] = rows 2,3).
        __m256 b0 = _mm256_mul_ps( a0, v2[ 0 ] );
        __m256 b1 = _mm256_mul_ps( a1, v2[ 1 ] );
        __m256 b2 = _mm256_mul_ps( a2, v2[ 0 ] );
        __m256 b3 = _mm256_mul_ps( a3, v2[ 1 ] );

        __m256 c0 = _mm256_add_ps( b0, b1 );
        __m256 c1 = _mm256_add_ps( b2, b3 );

        // d0 = { c0.low | c1.low }, d1 = { c0.high | c1.high }: pair the two
        // halves of each row's dot products, then add to finish both rows.
        __m256 d0 = _mm256_permute2f128_ps( c0, c1, _MM_SHUFFLE( 0, 2, 0, 0 ) );
        __m256 d1 = _mm256_permute2f128_ps( c0, c1, _MM_SHUFFLE( 0, 3, 0, 1 ) );
        vout[ r ] = _mm256_add_ps( d0, d1 );
    }
}
// Transposes a 4x4 float matrix held in two __m256 registers (two rows of
// four floats per register): vout = transpose(v1).
static void NOINLINE transposeX8( const __m256 *v1, __m256 *vout )
{
#if 0
    // AVX1 path (disabled): classic unpack/permute 4x4 transpose.
    __m256 a0 = _mm256_unpacklo_ps( v1[ 0 ], v1[ 1 ] );
    __m256 a1 = _mm256_unpackhi_ps( v1[ 0 ], v1[ 1 ] );
    __m256 b0 = _mm256_permute2f128_ps( a0, a1, _MM_SHUFFLE( 0, 2, 0, 0 ) );
    __m256 b1 = _mm256_permute2f128_ps( a0, a1, _MM_SHUFFLE( 0, 3, 0, 1 ) );
    __m256 c0 = _mm256_unpacklo_ps( b0, b1 );
    __m256 c1 = _mm256_unpackhi_ps( b0, b1 );
    vout[ 0 ] = _mm256_permute2f128_ps( c0, c1, _MM_SHUFFLE( 0, 2, 0, 0 ) );
    vout[ 1 ] = _mm256_permute2f128_ps( c0, c1, _MM_SHUFFLE( 0, 3, 0, 1 ) );
#else
    // AVX2 path: one cross-lane permute per source register, then a blend
    // and a shuffle assemble the four transposed rows. With m[r][c] denoting
    // the source:
    //   a0 = { m00 m10 m02 m12 | m01 m11 m03 m13 }
    //   a1 = { m22 m32 m20 m30 | m23 m33 m21 m31 }
    static const int ALIGN32 p1[ 8 ] = { 0, 4, 2, 6, 1, 5, 3, 7 };
    static const int ALIGN32 p2[ 8 ] = { 2, 6, 0, 4, 3, 7, 1, 5 };
    const __m256i perm1 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p1 ) );
    const __m256i perm2 = _mm256_load_si256( reinterpret_cast< const __m256i* >( p2 ) );
    __m256 a0 = _mm256_permutevar8x32_ps( v1[ 0 ], perm1 );
    __m256 a1 = _mm256_permutevar8x32_ps( v1[ 1 ], perm2 );
    // 0xCC keeps elements 0,1 / 4,5 from a0 and 2,3 / 6,7 from a1:
    //   vout[0] = { m00 m10 m20 m30 | m01 m11 m21 m31 } (rows 0,1).
    vout[ 0 ] = _mm256_blend_ps( a0, a1, 0xCC );
    // 0x4E picks the complementary pairs (a0 elems 2,3 then a1 elems 0,1 in
    // each lane): vout[1] = { m02 m12 m22 m32 | m03 m13 m23 m33 } (rows 2,3).
    vout[ 1 ] = _mm256_shuffle_ps( a0, a1, 0x4E );
#endif
}
// Latency-path softmax for int32 fixed-point input producing float output.
//
// Steps:
//   1. Dequantize: copy the int32 input into a 64-byte-aligned temporary
//      float buffer, scaling by 2^-input_fraction.
//   2. Accumulate: run softmax_compute_block/_subsimd over the output view,
//      accumulating sums into acc_sum (vector) and subsimd_sum (scalar tail).
//   3. Reduce: horizontally fold acc_sum, add the scalar tail sum, and take
//      reciprocals.
//   4. Normalize: run softmax_finalize_block/_subsimd to scale the output.
//
// The view is partitioned into full C_data_stride-wide blocks, one partial
// block of whole SIMD vectors (dispatched by compile-time width), and a
// scalar sub-SIMD tail.
//
// NOTE(review): the softmax_compute_*/softmax_finalize_* helpers presumably
// take the buffer pointers by reference and advance them (the loops below
// reuse the same pointer variables) — confirm against their definitions.
void run_softmax_int32_float_work_item_latency(nn_workload_item *const work_item)
{
    nn_workload_data_t *input_view = work_item->input[0]->output;
    const auto &arguments = work_item->arguments.forward_softmax_fixedpoint;

    // Total number of input elements (z * p dims of the parent buffer).
    const auto input_width = input_view->parent->lengths.t[NN_DATA_COORD_z] * input_view->parent->lengths.t[NN_DATA_COORD_p];
    // Number of output elements in this view (view_end is inclusive).
    const auto output_width = work_item->output->view_end.t[NN_DATA_COORD_x] - work_item->output->view_begin.t[NN_DATA_COORD_x] + 1;

    // Work partitioning (see header comment).
    const auto num_full_blocks = output_width / C_data_stride;
    const auto partial_block_size = (output_width / C_simd_width) % C_max_acc;
    const auto subsimd_block_size = output_width % C_simd_width;

    const auto output_view_start = work_item->output->view_begin.t[NN_DATA_COORD_x];
    const auto input_view_start = input_view->view_begin.t[NN_DATA_COORD_z] * input_view->parent->lengths.t[NN_DATA_COORD_p];

    const auto out_fraction = arguments.input_fraction;

    // Temporary dequantized input, 64-byte aligned for AVX loads.
    // NOTE(review): _mm_malloc return value is not checked for nullptr.
    float * input_f = (float*)_mm_malloc(input_width * sizeof(float), 64);

    auto input_buffer = &static_cast<int32_t*>(input_view->parent->data_buffer)[input_view_start];

    // Dequantize: int32 -> float scaled by 2^-shift.
    // NOTE(review): if input_fraction's type is unsigned, the shift < 0
    // branch is dead code — confirm. Also assumes |shift| < 31 so the
    // (1 << shift) expressions do not overflow.
    auto shift = out_fraction;
    if (shift > 0)
    {
        for (uint32_t i = 0; i < input_width; i++)
            input_f[i] = (float)(input_buffer[i]) / (1 << shift);
    }
    else if (shift < 0)
    {
        for (uint32_t i = 0; i < input_width; i++)
            input_f[i] = (float)(input_buffer[i]) * (1 << -shift);
    }
    else
    {
        for (uint32_t i = 0; i < input_width; i++)
            input_f[i] = (float)(input_buffer[i]);
    }

    __m256 acc_sum = _mm256_setzero_ps(); // vector accumulator for the sum
    float subsimd_sum = 0.0f;             // scalar accumulator for the tail

    // Pass 1: compute exponentials into the output and accumulate the sums.
    {
        auto input_buffer = input_f; // shadows the int32 pointer above
        auto output_buffer = &static_cast<float*>(work_item->output->parent->data_buffer)[output_view_start];

        for (auto block = 0u; block < num_full_blocks; ++block)
        {
            // Run computation.
            softmax_compute_block<C_max_acc>(input_buffer, output_buffer, acc_sum);
        }

        // Partial block: dispatch to the matching compile-time width.
        switch (partial_block_size)
        {
        case  0: break;
        case  1: softmax_compute_block< 1>(input_buffer, output_buffer, acc_sum); break;
        case  2: softmax_compute_block< 2>(input_buffer, output_buffer, acc_sum); break;
        case  3: softmax_compute_block< 3>(input_buffer, output_buffer, acc_sum); break;
        case  4: softmax_compute_block< 4>(input_buffer, output_buffer, acc_sum); break;
        case  5: softmax_compute_block< 5>(input_buffer, output_buffer, acc_sum); break;
        case  6: softmax_compute_block< 6>(input_buffer, output_buffer, acc_sum); break;
        case  7: softmax_compute_block< 7>(input_buffer, output_buffer, acc_sum); break;
        case  8: softmax_compute_block< 8>(input_buffer, output_buffer, acc_sum); break;
        case  9: softmax_compute_block< 9>(input_buffer, output_buffer, acc_sum); break;
        case 10: softmax_compute_block<10>(input_buffer, output_buffer, acc_sum); break;
        case 11: softmax_compute_block<11>(input_buffer, output_buffer, acc_sum); break;
        case 12: softmax_compute_block<12>(input_buffer, output_buffer, acc_sum); break;
        case 13: softmax_compute_block<13>(input_buffer, output_buffer, acc_sum); break;
        case 14: softmax_compute_block<14>(input_buffer, output_buffer, acc_sum); break;
        default: NN_UNREACHABLE_CODE;
        }

        // Scalar sub-SIMD tail.
        switch (subsimd_block_size)
        {
        case 0: break;
        case 1: softmax_compute_subsimd<1>(input_buffer, output_buffer, subsimd_sum); break;
        case 2: softmax_compute_subsimd<2>(input_buffer, output_buffer, subsimd_sum); break;
        case 3: softmax_compute_subsimd<3>(input_buffer, output_buffer, subsimd_sum); break;
        case 4: softmax_compute_subsimd<4>(input_buffer, output_buffer, subsimd_sum); break;
        case 5: softmax_compute_subsimd<5>(input_buffer, output_buffer, subsimd_sum); break;
        case 6: softmax_compute_subsimd<6>(input_buffer, output_buffer, subsimd_sum); break;
        case 7: softmax_compute_subsimd<7>(input_buffer, output_buffer, subsimd_sum); break;
        default: NN_UNREACHABLE_CODE;
        }
    }

    // Reduce: fold acc_sum so every lane holds the full horizontal sum, add
    // the scalar tail sum, then take reciprocals for the normalization pass.
    {
        __m256 intermediate_sum = _mm256_hadd_ps(acc_sum, acc_sum);
        // Cross-lane permute so the second hadd pair adds values originating
        // from opposite 128-bit lanes.
        intermediate_sum = _mm256_permutevar8x32_ps(intermediate_sum, _mm256_set_epi32(0, 1, 4, 5, 2, 3, 6, 7));
        intermediate_sum = _mm256_hadd_ps(intermediate_sum, intermediate_sum);
        intermediate_sum = _mm256_hadd_ps(intermediate_sum, intermediate_sum);

        acc_sum = _mm256_add_ps(intermediate_sum, _mm256_set1_ps(subsimd_sum));
        subsimd_sum = _mm_cvtss_f32(_mm256_extractf128_ps(acc_sum, 0));

        acc_sum = _mm256_div_ps(_mm256_set1_ps(1.0f), acc_sum); // 1/sum (vector)
        subsimd_sum = 1.0f / subsimd_sum;                       // 1/sum (scalar)
    }

    // Pass 2: multiply the stored exponentials by 1/sum.
    {
        auto output_buffer = &static_cast<float*>(work_item->output->parent->data_buffer)[output_view_start];

        for (auto block = 0u; block < num_full_blocks; ++block)
        {
            // Run computation.
            softmax_finalize_block<C_max_acc>(output_buffer, acc_sum);
        }

        // Partial block: same dispatch pattern as pass 1.
        switch (partial_block_size)
        {
        case  0: break;
        case  1: softmax_finalize_block< 1>(output_buffer, acc_sum); break;
        case  2: softmax_finalize_block< 2>(output_buffer, acc_sum); break;
        case  3: softmax_finalize_block< 3>(output_buffer, acc_sum); break;
        case  4: softmax_finalize_block< 4>(output_buffer, acc_sum); break;
        case  5: softmax_finalize_block< 5>(output_buffer, acc_sum); break;
        case  6: softmax_finalize_block< 6>(output_buffer, acc_sum); break;
        case  7: softmax_finalize_block< 7>(output_buffer, acc_sum); break;
        case  8: softmax_finalize_block< 8>(output_buffer, acc_sum); break;
        case  9: softmax_finalize_block< 9>(output_buffer, acc_sum); break;
        case 10: softmax_finalize_block<10>(output_buffer, acc_sum); break;
        case 11: softmax_finalize_block<11>(output_buffer, acc_sum); break;
        case 12: softmax_finalize_block<12>(output_buffer, acc_sum); break;
        case 13: softmax_finalize_block<13>(output_buffer, acc_sum); break;
        case 14: softmax_finalize_block<14>(output_buffer, acc_sum); break;
        default: NN_UNREACHABLE_CODE;
        }

        // Scalar sub-SIMD tail.
        switch (subsimd_block_size)
        {
        case 0: break;
        case 1: softmax_finalize_subsimd<1>(output_buffer, subsimd_sum); break;
        case 2: softmax_finalize_subsimd<2>(output_buffer, subsimd_sum); break;
        case 3: softmax_finalize_subsimd<3>(output_buffer, subsimd_sum); break;
        case 4: softmax_finalize_subsimd<4>(output_buffer, subsimd_sum); break;
        case 5: softmax_finalize_subsimd<5>(output_buffer, subsimd_sum); break;
        case 6: softmax_finalize_subsimd<6>(output_buffer, subsimd_sum); break;
        case 7: softmax_finalize_subsimd<7>(output_buffer, subsimd_sum); break;
        default: NN_UNREACHABLE_CODE;
        }
    }

    _mm_free(input_f);
}
// Codegen test: _mm256_permutevar8x32_ps must lower to the
// llvm.x86.avx2.permps intrinsic.
// Fix: the intrinsic's second operand is the 32-bit index vector and is
// declared __m256i (not __m256) in the Intel intrinsics reference; passing a
// __m256 relies on implicit vector conversion and mismatches the intrinsic's
// prototype.
__m256 test_mm256_permutevar8x32_ps(__m256 a, __m256i b) {
  // CHECK: @llvm.x86.avx2.permps
  return _mm256_permutevar8x32_ps(a, b);
}
// Codegen test: _mm256_permutevar8x32_ps must lower to a call to the
// llvm.x86.avx2.permps intrinsic with a <8 x float> data operand and a
// <8 x i32> index operand. (CHECK lines are FileCheck patterns — do not
// reword them.)
__m256 test_mm256_permutevar8x32_ps(__m256 a, __m256i b) {
  // CHECK-LABEL: test_mm256_permutevar8x32_ps
  // CHECK: call <8 x float> @llvm.x86.avx2.permps(<8 x float> %{{.*}}, <8 x i32> %{{.*}})
  return _mm256_permutevar8x32_ps(a, b);
}