** HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,    **
** SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED  **
** TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR    **
** PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF    **
** LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING      **
** NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS        **
** SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.              **
******************************************************************************/
/* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
   Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
******************************************************************************/

LIBXSMM_VLA_DECL(2, libxsmm_bgemm_lock, locks, handle->locks, handle->nb);
/* TODO: pad thread-local buffer members by the size of a cache-line in order to avoid "Ping-Pong" */
LIBXSMM_VLA_DECL(2, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, l_out, (LIBXSMM_BGEMM_TEMPLATE_TYPE_C*)(((char*)handle->buffer) +
  tid * LIBXSMM_UP2(handle->bm * handle->bn * sizeof(LIBXSMM_BGEMM_TEMPLATE_TYPE_C), LIBXSMM_CACHELINE)), handle->bm);
LIBXSMM_VLA_DECL(4, const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, real_a, (const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB*)a, handle->kb, handle->bk, handle->bm);
LIBXSMM_VLA_DECL(4, const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, real_b, (const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB*)b, handle->kb, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, real_c, (LIBXSMM_BGEMM_TEMPLATE_TYPE_C*)c, handle->mb, handle->bn, handle->bm);
const LIBXSMM_MMFUNCTION_TYPE2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C) kernel =
  handle->kernel.LIBXSMM_TPREFIX2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, mm);
const LIBXSMM_MMFUNCTION_TYPE2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C) kernel_pf =
  handle->kernel_pf.LIBXSMM_TPREFIX2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, mm);
const libxsmm_blasint b_m1 = handle->b_m1;
const libxsmm_blasint b_n1 = handle->b_n1;
const libxsmm_blasint b_k1 = handle->b_k1;
const libxsmm_blasint b_k2 = handle->b_k2;
const libxsmm_blasint mm = handle->m / b_m1;
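/* A minimal illustrative sketch (kept under "#if 0", so it is not compiled) of the
   cache-line rounding used for the per-thread l_out tile above, assuming that
   LIBXSMM_UP2(size, npot) rounds size up to the next multiple of the power-of-two
   npot: every thread's tile then begins on its own cache line, which avoids the
   false-sharing "Ping-Pong" the TODO refers to. UP2_SKETCH is a hypothetical
   stand-in for LIBXSMM_UP2, not part of the library API. */
#if 0
#include <stddef.h>
#define UP2_SKETCH(N, NPOT) (((size_t)(N) + ((size_t)(NPOT) - 1)) & ~((size_t)(NPOT) - 1))
/* example: a bm*bn tile of 100 4-byte C-elements with 64-byte cache lines
   occupies UP2_SKETCH(100 * 4, 64) == 448 bytes, so thread tid starts at
   buffer + tid * 448 and no two threads share a cache line */
#endif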
#endif
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK)
/* Weight and transpose_weight tensor declaration */
LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data,
  handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)handle->scratch1,
  handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);
/* define weight pointer which has the correct format */
element_filter_type* weight_base = 0;

/* padding via thread-local scratch buffers */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
const int size_tls1 = padded_h * padded_w * handle->ifmblock;
element_input_type *const del_input_scratch_padding = (element_input_type*)(((char*)handle->scratch5) +
  ltid * LIBXSMM_UP2(size_tls1 * sizeof(element_input_type), LIBXSMM_CACHELINE));
LIBXSMM_ASSERT(size_tls1 * sizeof(element_input_type) * handle->desc.threads <= handle->max_scratch5_size);
for (ii = 0; ii < size_tls1; ++ii) { del_input_scratch_padding[ii] = (element_input_type)0; }

/* transpose filters, if requested */
if ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0) {
  weight_base = (element_filter_type*)handle->reg_filter_tr->data;
}
else {
  /* lazy barrier init */
  libxsmm_barrier_init(handle->barrier, ltid);
  for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
    ofm1 = ifm1ofm1 / handle->blocksifm;
    ifm1 = ifm1ofm1 % handle->blocksifm;
    for (kj = 0; kj < handle->desc.R; kj++) {
      for (ki = 0; ki < handle->desc.S; ki++) {
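/* A minimal illustrative sketch (under "#if 0") of the transformation the transpose
   loop above prepares for the backward pass: the input/output channel dimensions of
   the filters are swapped and the R x S window is rotated by 180 degrees (hence the
   R-1-kj / S-1-ki style of indexing in such transposes). Shown on plain dense
   arrays; K, C, R, S, w, and tr_w are hypothetical names, not the template's
   blocked layout. */
#if 0
void transpose_filters_sketch(int K, int C, int R, int S,
                              const float* w,  /* KCRS layout: w[k][c][r][s] */
                              float* tr_w)     /* CKRS layout, rotated window */
{
  int k, c, r, s;
  for (k = 0; k < K; ++k) {
    for (c = 0; c < C; ++c) {
      for (r = 0; r < R; ++r) {
        for (s = 0; s < S; ++s) {
          tr_w[((c * K + k) * R + (R - 1 - r)) * S + (S - 1 - s)] =
            w[((k * C + c) * R + r) * S + s];
        }
      }
    }
  }
}
#endif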
#endif
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK)
/* Weight and transpose_weight tensor declaration */
LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data,
  handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)handle->scratch1,
  handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);
/* define weight pointer which has the correct format */
element_filter_type* weight_base = 0;

/* padding via scratch or stack-allocated buffers */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
const int scratch7_size = padded_h * padded_w * handle->ifmblock;
#if defined(LIBXSMM_SCRATCH7)
element_input_type *const del_input_scratch_padding = (element_input_type*)(((char*)handle->scratch7) +
  ltid * LIBXSMM_UP2(scratch7_size * sizeof(element_input_type), LIBXSMM_CACHELINE));
#else
element_input_type del_input_scratch_padding_array[scratch7_size]; /* C99 VLA */
element_input_type *const del_input_scratch_padding = del_input_scratch_padding_array;
#endif
LIBXSMM_ASSERT(scratch7_size * sizeof(element_input_type) * handle->desc.threads <= handle->scratch7_size);
for (ii = 0; ii < scratch7_size; ++ii) { del_input_scratch_padding[ii] = (element_input_type)0; }

/* transpose filters, if requested */
if ((handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0) {
  weight_base = (element_filter_type*)handle->reg_filter_tr->data;
}
else {
  /* lazy barrier init */
  libxsmm_barrier_init(handle->barrier, ltid);
  for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
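/* A minimal illustrative sketch (under "#if 0") of the sizing contract the
   assertion above checks: the shared scratch area must hold one per-thread padded
   input buffer for every thread; rounding each buffer to a cache-line multiple (as
   the offset computation does) gives the conservative bound. UP2 and the function
   name are hypothetical, and 64 is an assumed cache-line size. */
#if 0
#include <stddef.h>
#define UP2(N, NPOT) (((size_t)(N) + ((size_t)(NPOT) - 1)) & ~((size_t)(NPOT) - 1))
static size_t scratch_bytes_needed(size_t per_thread_elems, size_t elem_size, int nthreads)
{
  return (size_t)nthreads * UP2(per_thread_elems * elem_size, 64);
}
#endif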
/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
const int work = handle->blocksifm * handle->blocksofm;
/* compute chunk size */
const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* transpose + padding via thread-local scratch buffers for the input */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
const int size_tls1 = padded_h * padded_w * handle->ifmblock;
element_input_type *const input_scratch = (element_input_type*)(((char*)handle->scratch5) +
  ltid * LIBXSMM_UP2(size_tls1 * sizeof(element_input_type), LIBXSMM_CACHELINE));

/* transpose via thread-local scratch buffers for output and weights to control the strided-GEMM issue;
   idea: we transpose grad_output, and transpose the filters when done */
const int scratch6_size = handle->ofhp * handle->ofwp * handle->ofmblock;
const int scratch7_size = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock;
element_output_type *const output_scratch = (element_output_type*)(((char*)handle->scratch6) +
  ltid * LIBXSMM_UP2(scratch6_size * sizeof(element_output_type), LIBXSMM_CACHELINE));
element_filter_type *const filter_scratch = (element_filter_type*)(((char*)handle->scratch7) +
  ltid * LIBXSMM_UP2(scratch7_size * sizeof(element_filter_type), LIBXSMM_CACHELINE));

element_output_type *const out = (element_output_type*)handle->grad_output->data +
  ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm * handle->ofmblock;
LIBXSMM_VLA_DECL(5, const element_output_type, output, out,
  handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock);
LIBXSMM_VLA_DECL(5, const element_output_type, output_padded, (const element_output_type*)handle->grad_output->data,
  handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock);
LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type*)handle->reg_input->data,
  handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM)
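/* A minimal illustrative sketch (under "#if 0") of the static work partitioning used
   above: a ceil-division chunk size, with begin/end clamped to work so trailing
   threads simply receive an empty range when the work does not divide evenly. The
   function and parameter names are hypothetical. */
#if 0
static void partition_work(int work, int nthreads, int ltid, int* begin, int* end)
{
  const int chunksize = (work % nthreads == 0) ? (work / nthreads) : ((work / nthreads) + 1);
  *begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  *end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
}
/* usage: thread 5 of 8 over 18 tasks gets [15,18); threads 6 and 7 get empty ranges */
#endif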