void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator()(
        dst_data_t *dst, const acc_data_t *acc, const char *bias,
        const float *scales, float nslope, float sum_scale, float signed_scale,
        int g, size_t start, size_t end) {
    // Post-processes the flattened element range [start, end), where an
    // element index decomposes as os * OC_ + oc. Each element of acc is
    // converted to float and, depending on the configured flags, multiplied
    // by signed_scale, offset by the bias, multiplied by the output scale,
    // summed with sum_scale * existing dst, scaled by nslope when negative,
    // and finally converted into dst via qz_a1b0. When a generated kernel
    // (ker_) is available the whole range is forwarded to it instead.
    using math::get_bias;

    // Empty range: nothing to do.
    if (!(start < end)) return;

    // (os, oc) coordinates of the first element of the range.
    const size_t oc_begin = start % OC_;
    const size_t os_begin = start / OC_;
    // Offset of group g in the bias / scale tables.
    const auto g_oc = g * jcp_.oc;

    if (ker_) {
        // JIT path: describe the range and hand it to the generated kernel.
        ker_args args;
        args.acc = acc + start;
        args.dst = dst + os_begin * dst_os_stride_ + oc_begin;
        args.bias = bias + (g_oc + oc_begin) * bias_data_type_size_;
        args.scales = scales + scale_idx_mult_ * (g_oc + oc_begin);
        args.nslope = nslope;
        args.sum_scale = sum_scale;
        args.signed_scale = signed_scale;
        args.len = end - start;
        args.oc_offset = oc_begin;
        ker_(&args);
        return;
    }

    // Reference path: visit every (os, oc) pair of the range in order. Only
    // the first and last rows can be partial; inner rows span all of OC_.
    const size_t oc_final = (end - 1) % OC_;
    const size_t os_final = (end - 1) / OC_;
    for (size_t os = os_begin; os <= os_final; ++os) {
        const size_t oc_lo = (os == os_begin) ? oc_begin : 0;
        const size_t oc_hi = (os == os_final) ? oc_final : OC_ - 1;
        for (size_t oc = oc_lo; oc <= oc_hi; ++oc) {
            const size_t src_idx = os * jcp_.oc + oc;
            const size_t out_idx = os * dst_os_stride_ + oc;
            const auto chan = g_oc + oc;

            float val = static_cast<float>(acc[src_idx]);
            if (jcp_.signed_input) val *= signed_scale;
            if (do_bias_) val += get_bias(bias, chan, bias_data_type_);
            val *= scales[chan * scale_idx_mult_];
            if (do_sum_) val += sum_scale * dst[out_idx];
            if (do_relu_ && val < 0) val *= nslope;
            dst[out_idx] = qz_a1b0<float, dst_data_t>()(val);
        }
    }
}
void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data_thr(
        const int ithr, const int nthr, const diff_dst_data_t *diff_dst_base,
        const wei_data_t *wei_base, const char *bia_base,
        diff_src_data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const {
    // Per-thread backward-data driver. The (mb x ngroups) work items are
    // split with balance211 and this thread processes items [start, end).
    // For each (n, g) item:
    //   1) gemm_s8x8s32 multiplies the group's weights (transposed, "T") by
    //      the group's diff_dst. It writes into the per-thread col buffer, or
    //      straight into acc when jcp.im2col_sz == 0.
    //   2) If col was used, col2im_s32 converts it into acc.
    //   3) Every acc element gets the optional bias and the output scale,
    //      then is converted into diff_src via qz_a1b0.
    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    // Element strides taken from the memory descriptors.
    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
    const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
    const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
    const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
    const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
    const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);

    /* scale_idx_mult = 1 when the output-scales mask is (1 << 1), i.e. one
     * scale per channel (indexed by g * ic + ic below), and 0 otherwise so
     * that every element reads scales[0]. */
    const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
    const float *scales = pd()->attr()->output_scales_.scales_;
    // Invariant for the whole call: query once instead of once per element.
    const auto bias_dt = pd()->desc()->bias_desc.data_type;
    const size_t work_amount = jcp.ngroups * jcp.mb;

    // Per-thread slices of the scratchpad buffers (offset by ithr).
    auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;

    // GEMM arguments do not depend on (n, g); set them up once.
    const int M = jcp.ks * jcp.ic;
    const int N = jcp.os;
    const int K = jcp.oc;
    const int8_t off_a = 0, off_b = 0;
    const int32_t off_c = 0;
    const float onef = 1.0f, zerof = 0.0f;
    // Leading dimension shared by weights and diff_dst: oc of all groups.
    const int LD = K * jcp.ngroups;

    int n{0}, g{0};
    size_t start = 0, end = 0;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *wei = wei_base + g * wei_g_stride;
        diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
                + g * diff_src_g_stride;

        gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, wei, &LD, &off_a,
                diff_dst, &LD, &off_b, &zerof, jcp.im2col_sz ? col : acc, &M,
                &off_c);

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);

        // Channel base of group g in the bias / scale tables.
        const int g_ic = g * jcp.ic;
        // NOTE(review): this parallel_nd runs inside a per-thread routine;
        // whether it actually fans out depends on the threading runtime's
        // nested-parallelism settings. Confirm it is intended.
        parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
            // Compute the offset in size_t: is * jcp.ic in int can overflow
            // for large spatial x channel sizes (diff_src_off is size_t too).
            const size_t acc_off = static_cast<size_t>(is) * jcp.ic + ic;
            float d = static_cast<float>(acc[acc_off]);
            if (jcp.with_bias) d += get_bias(bia_base, g_ic + ic, bias_dt);
            d *= scales[(g_ic + ic) * scale_idx_mult];
            const size_t diff_src_off = is * diff_src_os_stride + ic;
            diff_src[diff_src_off] = qz_a1b0<float, diff_src_data_t>()(d);
        });
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
}