// Horizontal add of sixteen 16-lane float batches into one batch.
// NOTE(review): by the usual haddp contract, lane i of the result should hold
// the sum of all 16 lanes of row[i] — confirm against callers; the shuffle
// ordering below (rows paired as 0/4, 2/6, 1/5, 3/7, ...) encodes that layout.
// row[i]() presumably yields the underlying __m512 register of batch i.
static batch_type haddp(const simd_batch<batch_type>* row)
{
    // Step 1: fold each pair of input registers once by adding their low and
    // high 256-bit halves together (quarter-lane shuffles + vertical add).
    // The following folds over the vector once:
    // tmp1 = [a0..8, b0..8]
    // tmp2 = [a8..f, b8..f]
#define XSIMD_AVX512_HADDP_STEP1(I, a, b)                                   \
    batch<float, 16> res ## I;                                              \
    {                                                                       \
        auto tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(1, 0, 1, 0));    \
        auto tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 2, 3, 2));    \
        res ## I = tmp1 + tmp2;                                             \
    }                                                                       \

    XSIMD_AVX512_HADDP_STEP1(1, row[0](), row[4]());
    XSIMD_AVX512_HADDP_STEP1(2, row[2](), row[6]());
    XSIMD_AVX512_HADDP_STEP1(3, row[1](), row[5]());
    XSIMD_AVX512_HADDP_STEP1(4, row[3](), row[7]());
    XSIMD_AVX512_HADDP_STEP1(5, row[8](), row[12]());
    XSIMD_AVX512_HADDP_STEP1(6, row[10](), row[14]());
    XSIMD_AVX512_HADDP_STEP1(7, row[9](), row[13]());
    XSIMD_AVX512_HADDP_STEP1(8, row[11](), row[15]());

#undef XSIMD_AVX512_HADDP_STEP1

    // Step 2: keep folding — two more 128-bit-block folds, a pair of
    // element-level shuffles, then a final 256-bit hadd_ps to finish the
    // horizontal reduction of four partial results into one 8-lane half.
    // The following folds the code and shuffles so that hadd_ps produces the correct result
    // tmp1 = [a0..4, a8..12, b0..4, b8..12] (same for tmp3)
    // tmp2 = [a5..8, a12..16, b5..8, b12..16] (same for tmp4)
    // tmp5 = [r1[0], r1[2], r2[0], r2[2], r1[4], r1[6] ...
#define XSIMD_AVX512_HADDP_STEP2(I, a, b, c, d)                                         \
    batch<float, 8> halfx ## I;                                                         \
    {                                                                                   \
        batch<float, 16> tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(2, 0, 2, 0));    \
        batch<float, 16> tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 1, 3, 1));    \
                                                                                        \
        batch<float, 16> resx1 = tmp1 + tmp2;                                           \
                                                                                        \
        batch<float, 16> tmp3 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(2, 0, 2, 0));    \
        batch<float, 16> tmp4 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(3, 1, 3, 1));    \
                                                                                        \
        batch<float, 16> resx2 = tmp3 + tmp4;                                           \
                                                                                        \
        batch<float, 16> tmp5 = _mm512_shuffle_ps(resx1, resx2, 0b00000000);            \
        batch<float, 16> tmp6 = _mm512_shuffle_ps(resx1, resx2, 0b11111111);            \
                                                                                        \
        batch<float, 16> resx3 = tmp5 + tmp6;                                           \
                                                                                        \
        halfx ## I = _mm256_hadd_ps(_mm512_extractf32x8_ps(resx3, 0),                   \
                                    _mm512_extractf32x8_ps(resx3, 1));                  \
    }                                                                                   \

    XSIMD_AVX512_HADDP_STEP2(0, res1, res2, res3, res4);
    XSIMD_AVX512_HADDP_STEP2(1, res5, res6, res7, res8);

#undef XSIMD_AVX512_HADDP_STEP2

    // Concatenate the two 256-bit halves into the final 512-bit result:
    // halfx0 becomes lanes 0..7, halfx1 is inserted as lanes 8..15.
    auto concat = _mm512_castps256_ps512(halfx0);
    concat = _mm512_insertf32x8(concat, halfx1, 1);
    return concat;
}
// FileCheck regression test: compiling this function and piping the IR through
// FileCheck verifies that _mm512_insertf32x8 lowers to the
// llvm.x86.avx512.mask.insertf32x8 intrinsic. The comment lines beginning with
// the directive prefix below are matched by FileCheck and must not be edited.
__m512 test_mm512_insertf32x8(__m512 __A, __m256 __B) {
  // CHECK-LABEL: @test_mm512_insertf32x8
  // CHECK: @llvm.x86.avx512.mask.insertf32x8
  // Insert __B into the upper 256 bits (imm = 1) of __A.
  return _mm512_insertf32x8(__A, __B, 1);
}