Code example #1
** HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,    **
** SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED  **
** TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR    **
** PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF    **
** LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING      **
** NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS        **
** SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.              **
******************************************************************************/
/* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
   Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
******************************************************************************/

/* 2D view of the per-block lock array (inner extent handle->nb) guarding concurrent C-block updates. */
LIBXSMM_VLA_DECL(2, libxsmm_bgemm_lock, locks, handle->locks, handle->nb);
/* TODO: pad thread-local buffer members by the size of a cache-line in order to avoid "Ping-Pong" */
/* Per-thread bm x bn accumulation tile carved out of handle->buffer; each thread's
 * slab offset is rounded up to a cache-line multiple (LIBXSMM_UP2) so threads do not
 * share cache lines. */
LIBXSMM_VLA_DECL(2, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, l_out, (LIBXSMM_BGEMM_TEMPLATE_TYPE_C*)(((char*)handle->buffer) +
  tid * LIBXSMM_UP2(handle->bm * handle->bn * sizeof(LIBXSMM_BGEMM_TEMPLATE_TYPE_C), LIBXSMM_CACHELINE)), handle->bm);
/* Blocked 4D views of the A, B and C operands; inner extents are (block-count, block-size, block-size). */
LIBXSMM_VLA_DECL(4, const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, real_a, (const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB*)a, handle->kb, handle->bk, handle->bm);
LIBXSMM_VLA_DECL(4, const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, real_b, (const LIBXSMM_BGEMM_TEMPLATE_TYPE_AB*)b, handle->kb, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, real_c, (LIBXSMM_BGEMM_TEMPLATE_TYPE_C*)c, handle->mb, handle->bn, handle->bm);

/* Small-GEMM kernel function pointers (plain and prefetch variant), selected from the
 * handle's kernel union member named after the template element types via LIBXSMM_TPREFIX2. */
const LIBXSMM_MMFUNCTION_TYPE2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C) kernel =
        handle->kernel.LIBXSMM_TPREFIX2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, mm);
const LIBXSMM_MMFUNCTION_TYPE2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C) kernel_pf =
        handle->kernel_pf.LIBXSMM_TPREFIX2(LIBXSMM_BGEMM_TEMPLATE_TYPE_AB, LIBXSMM_BGEMM_TEMPLATE_TYPE_C, mm);

/* Additional user-requested blocking factors along M, N and K. */
const libxsmm_blasint b_m1 = handle->b_m1;
const libxsmm_blasint b_n1 = handle->b_n1;
const libxsmm_blasint b_k1 = handle->b_k1;
const libxsmm_blasint b_k2 = handle->b_k2;

/* M-extent after applying the b_m1 split; presumably handle->m is divisible by b_m1 — TODO confirm. */
const libxsmm_blasint mm = handle->m / b_m1;
/* NOTE(review): this #endif has no matching #if in the visible chunk — fragment boundary. */
#endif
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK)
/* Weight and transpose_weight tensor declaration */
/* 6D view of the regular filter in RSCK layout; inner extents: S, blocksifm, ifmblock, blocksofm, ofmblock. */
LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif

/* Transposed-filter view in scratch1; inner extents: blocksofm, R, S, ofmblock, ifmblock. */
LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)handle->scratch1, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);
/* define weight pointer which has the correct format */
element_filter_type* weight_base = 0;

/* padding via stack allocated buffers */
/* Spatial extents including logical zero-padding on both sides. */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
/* Per-thread element count of the padded input tile. */
const int size_tls1 = padded_h * padded_w * handle->ifmblock;
/* Per-thread slab inside scratch5, offset rounded up to a cache-line multiple to avoid false sharing. */
element_input_type *const del_input_scratch_padding = (element_input_type*)(((char*)handle->scratch5) +
  ltid * LIBXSMM_UP2(size_tls1 * sizeof(element_input_type), LIBXSMM_CACHELINE));
/* NOTE(review): the assert multiplies the un-rounded per-thread size, but the offsets above use the
 * LIBXSMM_UP2-rounded size — the check may under-estimate total scratch usage; TODO confirm. */
LIBXSMM_ASSERT(size_tls1 * sizeof(element_input_type) * handle->desc.threads <= handle->max_scratch5_size);
/* Zero-initialize this thread's padded scratch tile. */
for ( ii = 0; ii < size_tls1; ++ii ) { del_input_scratch_padding[ii] = (element_input_type)0; }

/* transpose filters, if requested */
if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) {
  weight_base = (element_filter_type*)handle->reg_filter_tr->data;
} else {
  /* lazy barrier init */
  libxsmm_barrier_init(handle->barrier, ltid);

  for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
    ofm1 = ifm1ofm1 / handle->blocksifm;
    ifm1 = ifm1ofm1 % handle->blocksifm;
    for (kj=0; kj < handle->desc.R; kj++) {
      for (ki=0; ki < handle->desc.S; ki++) {
#endif
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK)
/* Weight and transpose_weight tensor declaration */
/* 6D view of the regular filter in RSCK layout; inner extents: S, blocksifm, ifmblock, blocksofm, ofmblock. */
LIBXSMM_VLA_DECL(6, element_filter_type, wt, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif

/* Transposed-filter view in scratch1; inner extents: blocksofm, R, S, ofmblock, ifmblock. */
LIBXSMM_VLA_DECL(6, element_filter_type, tr_wt, (element_filter_type*)handle->scratch1, handle->blocksofm, handle->desc.R, handle->desc.S, handle->ofmblock, handle->ifmblock);
/* define weight pointer which has the correct format */
element_filter_type* weight_base = 0;

/* padding via stack allocated buffers */
/* Spatial extents including logical zero-padding on both sides. */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
/* Per-thread element count of the padded input tile.
 * NOTE(review): this local shadows the naming of handle->scratch7_size used in the assert below —
 * easy to confuse; consider renaming at the source. */
const int scratch7_size = padded_h * padded_w * handle->ifmblock;
#if defined(LIBXSMM_SCRATCH7)
/* Per-thread slab inside scratch7, offset rounded up to a cache-line multiple to avoid false sharing. */
element_input_type *const del_input_scratch_padding = (element_input_type*)(((char*)handle->scratch7) + ltid * LIBXSMM_UP2(scratch7_size * sizeof(element_input_type), LIBXSMM_CACHELINE));
#else
/* NOTE(review): VLA fallback on the stack; padded_h*padded_w*ifmblock can be large for big
 * images — stack-overflow risk, and VLAs are optional since C11. Confirm bounds at the caller. */
element_input_type del_input_scratch_padding_array[scratch7_size];
element_input_type *const del_input_scratch_padding = del_input_scratch_padding_array;
#endif
/* NOTE(review): the assert multiplies the un-rounded per-thread size while the offset above uses the
 * LIBXSMM_UP2-rounded size — the check may under-estimate total scratch usage; TODO confirm. */
LIBXSMM_ASSERT(scratch7_size * sizeof(element_input_type) * handle->desc.threads <= handle->scratch7_size);
/* Zero-initialize this thread's padded scratch tile. */
for ( ii = 0; ii < scratch7_size; ++ii ) { del_input_scratch_padding[ii] = (element_input_type)0; }

/* transpose filters, if requested */
if ( (handle->options & LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE) > 0 ) {
  weight_base = (element_filter_type*)handle->reg_filter_tr->data;
} else {
  /* lazy barrier init */
  libxsmm_barrier_init(handle->barrier, ltid);

  for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
/* One task per (input-feature-block, output-feature-block) pair. */
const int work = handle->blocksifm * handle->blocksofm;
/* compute chunk size */
/* Ceiling division of work over threads so every task is covered. */
const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
/* Half-open range [thr_begin, thr_end) of tasks for this thread, clamped to the total work. */
const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* transpose + padding via stack allocated buffers for input */
/* Spatial extents including logical zero-padding on both sides. */
const int padded_w = handle->desc.W + (2 * handle->desc.pad_w);
const int padded_h = handle->desc.H + (2 * handle->desc.pad_h);
const int size_tls1 = padded_h * padded_w * handle->ifmblock;
/* Per-thread input slab inside scratch5; offset rounded up to a cache-line multiple to avoid false sharing. */
element_input_type *const input_scratch = (element_input_type*)(((char*)handle->scratch5) +
  ltid * LIBXSMM_UP2(size_tls1 * sizeof(element_input_type), LIBXSMM_CACHELINE));

/* transpose via stack allocated buffers for output and weights to control stride-GEMM issue
   idea: we transpose grad_output and transpose filters when done */
/* Per-thread element counts: one padded output image block and one R*S filter block. */
const int scratch6_size = handle->ofhp * handle->ofwp * handle->ofmblock;
const int scratch7_size = handle->desc.R * handle->desc.S * handle->ifmblock * handle->ofmblock;
element_output_type *const output_scratch = (element_output_type*)(((char*)handle->scratch6) +
  ltid * LIBXSMM_UP2(scratch6_size * sizeof(element_output_type), LIBXSMM_CACHELINE));
element_filter_type *const filter_scratch = (element_filter_type*)(((char*)handle->scratch7) +
  ltid * LIBXSMM_UP2(scratch7_size * sizeof(element_filter_type), LIBXSMM_CACHELINE));

/* Advance past the physical output padding (pad_h_out rows, pad_w_out columns) so the `output`
 * view indexes the non-padded interior; `output_padded` views the full padded tensor. */
element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm*handle->ofmblock;
LIBXSMM_VLA_DECL(5, const element_output_type, output, out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock);
LIBXSMM_VLA_DECL(5, const element_output_type, output_padded, (const element_output_type*)handle->grad_output->data, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock);
/* 5D view of the (unmodified) forward input; inner extents: ifhp, ifwp, blocksifm, ifmblock. */
LIBXSMM_VLA_DECL(5, const element_input_type, input, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
#if defined(LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM)