Exemple #1
0
	Code()
	{
		Xbyak::Label label;
		cmpss(xmm0, ptr[rip + label], 0);
		test(dword[rip + label], 33);
		bt(dword[rip + label ], 3);
		vblendpd(xmm0, dword[rip + label], 3);
		vpalignr(xmm0, qword[rip + label], 4);
		vextractf128(dword[rip + label], ymm3, 12);
		vperm2i128(ymm0, ymm1, qword[rip + label], 13);
		vcvtps2ph(ptr[rip + label], xmm2, 44);
		mov(dword[rip + label], 0x1234);
		shl(dword[rip + label], 3);
		shr(dword[rip + label], 1);
		shld(qword[rip + label], rax, 3);
		imul(rax, qword[rip + label], 21);
		rorx(rax, qword[rip + label], 21);
		test(dword[rip + label], 5);
		pextrq(ptr[rip + label], xmm0, 3);
		pinsrq(xmm2, ptr[rip + label], 5);
		pextrw(ptr[rip + label], xmm1, 4);
		adc(dword[rip + label], 0x12345);
		bt(byte[rip + label], 0x34);
		btc(word[rip + label], 0x34);
		btr(dword[rip + label], 0x34);
		rcl(dword[rip + label], 4);
		shld(qword[rip + label], rax, 4);
		palignr(mm0, ptr[rip + label], 4);
		aeskeygenassist(xmm3, ptr[rip + label], 4);
		vpcmpestrm(xmm2, ptr[rip + label], 7);
		ret();
	L(label);
		dq(0x123456789abcdef0ull);
	};
    void forward_avx2() {
        xor_(reg_soff, reg_soff);
        Label mb_sp_loop;
        L(mb_sp_loop); {

            channel_loop([=](size_t unroll) {
                        // Load 32 channels (two C16_blocks) in ymm, then
                        // split the work in half, each half splits in two
                        // regs with 8 channels per. When down converting,
                        // put the result in a temp register for the 1st
                        // iteration, combine the result at 2nd iteration
                        // and store ymm with 32 channels.
                        // If 16 channels, do just one half and store the
                        // result with mask.
                        Vmm v0 = Vmm(0);
                        Vmm v1 = Vmm(1);
                        Vmm vscale0 = Vmm(2);
                        Vmm vshift0 = Vmm(3);
                        Vmm vmean0 = Vmm(4);
                        Vmm vsqrtvar0 = Vmm(5);
                        Vmm vscale1 = Vmm(6);
                        Vmm vshift1 = Vmm(7);
                        Vmm vmean1 = Vmm(8);
                        Vmm vsqrtvar1 = Vmm(9);
                        Vmm tmp = Vmm(10);

                        for (size_t i = 0; i < unroll; i++) {
                            compute_vscaleshift(vscale0, vshift0, vmean0,
                                    vsqrtvar0, i * c_in_xmm_ * sizeof(float));
                            compute_vscaleshift(vscale1, vshift1, vmean1,
                                    vsqrtvar1, i * c_in_xmm_ * sizeof(float)
                                    + simd_w_ * sizeof(float));

                            vpmovsxbd(v0, src_ptr(i*c_in_xmm_));
                            vpmovsxbd(v1, src_ptr(i*c_in_xmm_ + simd_w_));
                            vcvtdq2ps(v0, v0);
                            vcvtdq2ps(v1, v1);

                            uni_vfmadd213ps(v0, vscale0, vshift0);
                            uni_vfmadd213ps(v1, vscale1, vshift1);
                            if (with_relu_) {
                                uni_vmaxps(v0, v0, vzero);
                                uni_vmaxps(v1, v1, vzero);
                            }

                            vcvtps2dq(v0, v0); // BA
                            vcvtps2dq(v1, v1); // DC
                            vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                            vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                            vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                            vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA
                            if (i == 0 && unroll != 1)
                                uni_vmovups(tmp, v0);
                            else if (i == 1) {
                                // badcDCBA + fehgHGFE -> HGFEDCBA
                                vperm2i128(v0, v0, tmp, 0x2);
                            }
                        }

                        if (unroll == 1)
                            vmaskmovps(dst_ptr(), vbody_mask, v0);
                        else
                            uni_vmovups(dst_ptr(), v0);
                    },
                    [=]() {
                        // handle first 8 channels. If tail is bigger,
                        // handle second part separately. There is no way
                        // to get performance as one has to work with bytes
                        // via xmm. vzeroupper kills all the perf.
                        Xmm x0 = Xmm(0);
                        Vmm v0 = Vmm(0);
                        Vmm vscale0 = Vmm(1);
                        Vmm vshift0 = Vmm(2);
                        Vmm vmean0 = Vmm(3);
                        Vmm vsqrtvar0 = Vmm(4);

                        size_t tail = nstl::min(c_tail_, simd_w_);
                        size_t num_iters = c_tail_ > simd_w_ ? 2 : 1;

                        for (size_t i = 0; i < num_iters; i++) {
                            if (i > 0)
                                tail = c_tail_ - simd_w_;

                            for (size_t tl = 0; tl < tail; tl++)
                                vpinsrb(x0, x0, src_ptr(8*i + tl), tl);

                            if (tail == simd_w_)
                                compute_vscaleshift(vscale0, vshift0, vmean0,
                                        vsqrtvar0, 32*i);
                            else
                                compute_vscaleshift(vscale0, vshift0, vmean0,
                                        vsqrtvar0, 32*i, true);

                            vpmovsxbd(v0, x0);
                            vcvtdq2ps(v0, v0);
                            uni_vfmadd213ps(v0, vscale0, vshift0);
                            if (with_relu_)
                                uni_vmaxps(v0, v0, vzero);
                            vcvtps2dq(v0, v0);
                            vpackssdw(v0, v0, vzero);
                            vpermq(v0, v0, 0xD8);
                            vpacksswb(v0, v0, vzero);

                            for (size_t tl = 0; tl < tail; tl++)
                                vpextrb(dst_ptr(8*i + tl), x0, tl);
                        }
                    });

            add(reg_soff, reg_coff_max);
            cmp(reg_soff, reg_soff_max);
            jl(mb_sp_loop);
        }
    }