void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
    (dst_data_t *dst, const acc_data_t *acc, const char *bias,
        const float *scales, float nslope, float sum_scale, float signed_scale,
        int g, size_t start, size_t end)
{
    /* Post-processes the accumulators of group g over the flattened index
     * range [start, end), where index i maps to (os, oc) = (i / OC_, i % OC_).
     * For each element:
     *   d = acc                              (acc rows are jcp_.oc apart)
     *   d *= signed_scale                    if jcp_.signed_input
     *   d += bias[g * jcp_.oc + oc]          if do_bias_
     *   d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]
     *   d += sum_scale * dst                 if do_sum_
     *   d *= nslope                          if do_relu_ and d < 0
     *   dst = qz_a1b0<float, dst_data_t>(d)  (dst rows are dst_os_stride_ apart)
     * The JIT kernel ker_ is used when present; otherwise a scalar fallback
     * performs the same steps. An empty range is a no-op. */
    using math::get_bias;

    if (end <= start)
        return;

    if (ker_) {
        // JIT
        ker_args args;
        const size_t oc_offset = start % OC_;
        const size_t os_offset = start / OC_;
        args.acc = acc + start;
        args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
        // Offset the bias pointer only when bias is actually applied: without
        // bias the pointer may be null, and adding a non-zero offset to a null
        // pointer is undefined behavior. The fallback path below likewise
        // reads bias only under do_bias_.
        args.bias = do_bias_
            ? bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_
            : bias;
        // scale_idx_mult_ selects per-channel (advance) vs. common scale.
        args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
        args.nslope = nslope;
        args.sum_scale = sum_scale;
        args.signed_scale = signed_scale;
        args.len = end - start;
        args.oc_offset = oc_offset;
        ker_(&args);
    }
    else {
        // Fallback: walk the (os, oc) pairs covered by [start, end); the first
        // and last rows may be partial, so their oc bounds are clipped.
        const size_t first_oc = start % OC_;
        const size_t last_oc = (end - 1) % OC_;
        const size_t first_os = start / OC_;
        const size_t last_os = (end - 1) / OC_;
        for (size_t os = first_os; os <= last_os; os++) {
            const size_t start_oc = (os == first_os) ? first_oc : 0;
            const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
            for (size_t oc = start_oc; oc <= end_oc; oc++) {
                const size_t acc_off = os * jcp_.oc + oc;
                const size_t dst_off = os * dst_os_stride_ + oc;

                float d = (float)(acc[acc_off]);
                if (jcp_.signed_input)
                    d *= signed_scale;

                if (do_bias_)
                    d += get_bias(bias, g * jcp_.oc + oc,
                        bias_data_type_);

                d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
                if (do_sum_)
                    d += sum_scale * dst[dst_off];
                if (do_relu_ && d < 0)
                    d *= nslope;
                dst[dst_off] = qz_a1b0<float, dst_data_t>()(d);
            }
        }
    }
}
void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
execute_backward_data_thr(const int ithr, const int nthr,
        const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
        const char *bia_base, diff_src_data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const
{
    /* Computes diff_src for thread ithr's share (via balance211) of the
     * jcp.mb * jcp.ngroups (n, g) work items. For each item:
     *   1. GEMM: wei^T * diff_dst into this thread's col buffer, or directly
     *      into acc when jcp.im2col_sz == 0;
     *   2. col2im_s32 folds col into acc when im2col is used;
     *   3. each acc element gets the bias added (jcp.with_bias), is multiplied
     *      by its output scale, and is converted to diff_src_data_t. */
    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
    const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
    const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
    const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
    const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
    const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);

    /* scale_idx_mult = 1 for per-channel scales (mask == 1 << 1) and 0,
     * otherwise (a single common scale at scales[0]) */
    const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
    const float *scales = pd()->attr()->output_scales_.scales_;
    // Loop-invariant: read once instead of per element inside the lambda.
    const auto bias_dt = pd()->desc()->bias_desc.data_type;
    const size_t work_amount = jcp.ngroups * jcp.mb;

    // Per-thread slices of the scratchpad buffers.
    auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
        + (ptrdiff_t)ithr * jcp.im2col_sz;
    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
        + (ptrdiff_t)ithr * jcp.is * jcp.ic;

    // GEMM parameters do not depend on (n, g); set them up once.
    const int M = jcp.ks * jcp.ic;
    const int N = jcp.os;
    const int K = jcp.oc;
    const int8_t off_a = 0, off_b = 0;
    const int32_t off_c = 0;
    const float onef = 1.0, zerof = 0.0;
    // wei and diff_dst share the leading dimension oc * ngroups.
    const int LD = K * jcp.ngroups;

    int n{0}, g{0};
    size_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *diff_dst = diff_dst_base
            + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *wei = wei_base + g * wei_g_stride;
        diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
            + g * diff_src_g_stride;

        gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef,
                wei, &LD, &off_a, diff_dst, &LD, &off_b,
                &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);

        parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
            // Index in size_t: is * jcp.ic can exceed INT_MAX for large
            // spatial sizes, and signed int overflow is undefined behavior.
            const size_t acc_off = (size_t)is * jcp.ic + ic;
            float d = (float)acc[acc_off];
            if (jcp.with_bias)
                d += get_bias(bia_base, g * jcp.ic + ic, bias_dt);
            d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
            const size_t diff_src_off = is * diff_src_os_stride + ic;
            diff_src[diff_src_off] =
                qz_a1b0<float, diff_src_data_t>()(d);
        });
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
}