コード例 #1
0
    void forward_avx2() {
        xor_(reg_soff, reg_soff);
        Label mb_sp_loop;
        L(mb_sp_loop); {

            channel_loop([=](size_t unroll) {
                        // Load 32 channels (two C16_blocks) in ymm, then
                        // split the work in half, each half splits in two
                        // regs with 8 channels per. When down converting,
                        // put the result in a temp register for the 1st
                        // iteration, combine the result at 2nd iteration
                        // and store ymm with 32 channels.
                        // If 16 channels, do just one half and store the
                        // result with mask.
                        Vmm v0 = Vmm(0);
                        Vmm v1 = Vmm(1);
                        Vmm vscale0 = Vmm(2);
                        Vmm vshift0 = Vmm(3);
                        Vmm vmean0 = Vmm(4);
                        Vmm vsqrtvar0 = Vmm(5);
                        Vmm vscale1 = Vmm(6);
                        Vmm vshift1 = Vmm(7);
                        Vmm vmean1 = Vmm(8);
                        Vmm vsqrtvar1 = Vmm(9);
                        Vmm tmp = Vmm(10);

                        for (size_t i = 0; i < unroll; i++) {
                            compute_vscaleshift(vscale0, vshift0, vmean0,
                                    vsqrtvar0, i * c_in_xmm_ * sizeof(float));
                            compute_vscaleshift(vscale1, vshift1, vmean1,
                                    vsqrtvar1, i * c_in_xmm_ * sizeof(float)
                                    + simd_w_ * sizeof(float));

                            vpmovsxbd(v0, src_ptr(i*c_in_xmm_));
                            vpmovsxbd(v1, src_ptr(i*c_in_xmm_ + simd_w_));
                            vcvtdq2ps(v0, v0);
                            vcvtdq2ps(v1, v1);

                            uni_vfmadd213ps(v0, vscale0, vshift0);
                            uni_vfmadd213ps(v1, vscale1, vshift1);
                            if (with_relu_) {
                                uni_vmaxps(v0, v0, vzero);
                                uni_vmaxps(v1, v1, vzero);
                            }

                            vcvtps2dq(v0, v0); // BA
                            vcvtps2dq(v1, v1); // DC
                            vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                            vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                            vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                            vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA
                            if (i == 0 && unroll != 1)
                                uni_vmovups(tmp, v0);
                            else if (i == 1) {
                                // badcDCBA + fehgHGFE -> HGFEDCBA
                                vperm2i128(v0, v0, tmp, 0x2);
                            }
                        }

                        if (unroll == 1)
                            vmaskmovps(dst_ptr(), vbody_mask, v0);
                        else
                            uni_vmovups(dst_ptr(), v0);
                    },
                    [=]() {
                        // handle first 8 channels. If tail is bigger,
                        // handle second part separately. There is no way
                        // to get performance as one has to work with bytes
                        // via xmm. vzeroupper kills all the perf.
                        Xmm x0 = Xmm(0);
                        Vmm v0 = Vmm(0);
                        Vmm vscale0 = Vmm(1);
                        Vmm vshift0 = Vmm(2);
                        Vmm vmean0 = Vmm(3);
                        Vmm vsqrtvar0 = Vmm(4);

                        size_t tail = nstl::min(c_tail_, simd_w_);
                        size_t num_iters = c_tail_ > simd_w_ ? 2 : 1;

                        for (size_t i = 0; i < num_iters; i++) {
                            if (i > 0)
                                tail = c_tail_ - simd_w_;

                            for (size_t tl = 0; tl < tail; tl++)
                                vpinsrb(x0, x0, src_ptr(8*i + tl), tl);

                            if (tail == simd_w_)
                                compute_vscaleshift(vscale0, vshift0, vmean0,
                                        vsqrtvar0, 32*i);
                            else
                                compute_vscaleshift(vscale0, vshift0, vmean0,
                                        vsqrtvar0, 32*i, true);

                            vpmovsxbd(v0, x0);
                            vcvtdq2ps(v0, v0);
                            uni_vfmadd213ps(v0, vscale0, vshift0);
                            if (with_relu_)
                                uni_vmaxps(v0, v0, vzero);
                            vcvtps2dq(v0, v0);
                            vpackssdw(v0, v0, vzero);
                            vpermq(v0, v0, 0xD8);
                            vpacksswb(v0, v0, vzero);

                            for (size_t tl = 0; tl < tail; tl++)
                                vpextrb(dst_ptr(8*i + tl), x0, tl);
                        }
                    });

            add(reg_soff, reg_coff_max);
            cmp(reg_soff, reg_soff_max);
            jl(mb_sp_loop);
        }
    }
コード例 #2
0
/// Emits code that clamps the 16-bit lanes of `a` to the unsigned byte range
/// in place: the words are packed to bytes with unsigned saturation and then
/// widened back to words. `temp` is not used by this implementation.
void GSDrawScanlineCodeGenerator::clamp16(const Ymm& a, const Ymm& temp)
{
    // Qword order 0, 2, 1, 3: brings the packed bytes of both 128-bit lanes
    // together in the low half of the register.
    const int gather_packed_qwords = _MM_SHUFFLE(3, 1, 2, 0);

    // words -> bytes, saturating to 0..255 (done per 128-bit lane)
    vpackuswb(a, a);
    // cross-lane fixup; costly, but needed after the per-lane pack
    vpermq(a, a, gather_packed_qwords);
    // bytes -> zero-extended words
    vpmovzxbw(a, a);
}