static void customcl_transpose_matrix(const int ctx_id, const Dtype* source, int width, int height,
    Dtype* target, int output_width, int output_height) {

  cl_kernel kernel = transpose_exec.handle().get();

  err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &target);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 1, sizeof(int), &output_width);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 2, sizeof(int), &output_height);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 3, sizeof(cl_mem), &source);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 4, sizeof(int), &width);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 5, sizeof(int), &height);
  SAMPLE_CHECK_ERRORS(err);
  
  size_t local_size[2] = {16, 16};
  size_t global_size[2] = {(output_width + 16 - 1) / 16 * 16, 
                           (output_height + 16 - 1) / 16 * 16};

  auto queue = viennacl::ocl::get_context(ctx_id).get_queue().handle().get();

  err = clEnqueueNDRangeKernel(
      queue,
      kernel,
      2,
      0,
      global_size,
      local_size,
      0, 0, NULL
  );
  SAMPLE_CHECK_ERRORS(err);

  err = clFinish(queue);
  SAMPLE_CHECK_ERRORS(err);

}
// Run the custom OpenCL GEMM on context `ctx_id`.
//
// Layouts, as read by the copy/transpose helpers below (row-major):
//   A : M rows of K values
//   B : N rows of K values (transposed on the device into a K x N layout)
//   C : M rows of N values
// Presumably C[m][n] = sum_k A[m][k] * B[n][k] (i.e. A * B^T in row-major
// terms). The gemm kernel itself is not visible here; confirm against it.
//
// Every dimension is zero-padded up to a multiple of `align` in the scratch
// buffers `copy_ptr` (A, oK x oM), `transpose_ptr` (B^T, oN x oK) and
// `result_ptr` (oN x oM). The padded result is then copied back into C.
//
// Throws a `const char*` if any scratch buffer would exceed MAX_BUFFER_DIM
// bytes. Returns immediately when M or N is not positive (C is empty).
// When CUSTOM_GEMM_VERIFICATION is enabled, the padded copy and the transpose
// are read back and checked on the host.
static void customcl_gpu_gemm(const int ctx_id, const int M,
                       const int N, const int K,
                       const Dtype* A, const Dtype* B, Dtype* C) {
  if(!customcl_is_setup) {
    caffe::customcl_setup();
    customcl_is_setup = true;
  }

  // An empty output has nothing to compute. Returning here also avoids
  // enqueuing zero-sized NDRanges, which OpenCL 1.x rejects.
  // NOTE(review): K <= 0 is not special-cased and still reaches the kernels.
  if(M <= 0 || N <= 0) {
    return;
  }

  auto queue = viennacl::ocl::get_context(ctx_id).get_queue().handle().get();

  const int align = 32;
  // The gemm launch below uses a global size of (oM / 2, oN / 2) with 16x16
  // work-groups. That global size is only divisible by 16 when `align` is a
  // multiple of 32.
  static_assert(align % (2 * 16) == 0,
                "align must be a multiple of 32 for the 16x16 gemm launch");
  const int oK = (K + align - 1) / align * align;
  const int oM = (M + align - 1) / align * align;
  const int oN = (N + align - 1) / align * align;

  // All three scratch buffers are bounded by MAX_BUFFER_DIM bytes. The
  // result buffer receives oM x oN elements from the kernel, so it is checked
  // as well.
  // NOTE(review): assumes result_ptr is allocated with the same
  // MAX_BUFFER_DIM capacity as copy_ptr / transpose_ptr; confirm in
  // customcl_setup.
  const size_t elem_bytes = sizeof(Dtype);
  if(elem_bytes * oK * oM > MAX_BUFFER_DIM) {
    throw "customcl_gpu_gemm: maximum buffer size exceeded.";
  }
  if(elem_bytes * oN * oK > MAX_BUFFER_DIM) {
    throw "customcl_gpu_gemm: maximum buffer size exceeded.";
  }
  if(elem_bytes * oM * oN > MAX_BUFFER_DIM) {
    throw "customcl_gpu_gemm: maximum buffer size exceeded.";
  }

  // Stage A (K wide, M high) into the zero-padded oK x oM copy buffer.
  Dtype* copy_buffer = (Dtype*)copy_ptr;
  customcl_copy_matrix(ctx_id,
      A, K, M,
      copy_buffer, oK, oM);

#if CUSTOM_GEMM_VERIFICATION == true
  {
    // Map the whole padded buffer, and read through the pointer the map
    // returns.
    Dtype* mapped_copy = static_cast<Dtype*>(clEnqueueMapBuffer(
        queue,
        (cl_mem) copy_buffer,
        CL_TRUE,    // blocking map
        CL_MAP_READ,
        0,
        sizeof(Dtype) * oK * oM,
        0, 0, 0,
        &err
    ));
    SAMPLE_CHECK_ERRORS(err);

    std::cout << "[verify copy] " << std::endl;
    for(int i = 0; i < oK; i++) {
      for(int j = 0; j < oM; j++) {
        if(i < K && j < M) {
          assertEq(mapped_copy[j * oK + i], A[j * K + i]);
        }else{
          assertEq(mapped_copy[j * oK + i], (Dtype)0.);
        }
      }
    }

    // The buffer must not stay mapped while later kernels write to it.
    err = clEnqueueUnmapMemObject(queue, (cl_mem) copy_buffer, mapped_copy, 0, NULL, NULL);
    SAMPLE_CHECK_ERRORS(err);
  }
#endif

  // Transpose B (K wide, N high) into the zero-padded oN x oK buffer.
  Dtype* trans_buffer = (Dtype*)transpose_ptr;
  customcl_transpose_matrix(ctx_id,
      B, K, N,
      trans_buffer, oN, oK);

#if CUSTOM_GEMM_VERIFICATION == true
  {
    // Map the whole padded oN x oK buffer. The check reads up to index
    // (K - 1) * oN + (N - 1), which lies beyond N * K whenever N is not a
    // multiple of `align`.
    Dtype* mapped_trans = static_cast<Dtype*>(clEnqueueMapBuffer(
        queue,
        (cl_mem) trans_buffer,
        CL_TRUE,    // blocking map
        CL_MAP_READ,
        0,
        sizeof(Dtype) * oN * oK,
        0, 0, 0,
        &err
    ));
    SAMPLE_CHECK_ERRORS(err);

    std::cout << "[verifying] " << std::endl;
    bool mismatch = false;
    for(int i = 0; i < N && !mismatch; i++) {
      for(int j = 0; j < K; j++) {
        if(B[i * K + j] != mapped_trans[j * oN + i]) {
          mismatch = true;
          break;
        }
      }
    }

    // Unmap before reporting, so a failure does not leave the buffer mapped.
    err = clEnqueueUnmapMemObject(queue, (cl_mem) trans_buffer, mapped_trans, 0, NULL, NULL);
    SAMPLE_CHECK_ERRORS(err);

    if(mismatch) {
      throw "verifcation failed";
    }
  }
#endif

  cl_kernel kernel = gemm_exec.handle().get();

  // Kernel argument layout: (A_padded, Bt_padded, result, oM, oK, oN).
  err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &copy_buffer);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans_buffer);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 2, sizeof(cl_mem), &result_ptr);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 3, sizeof(int), &oM);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 4, sizeof(int), &oK);
  SAMPLE_CHECK_ERRORS(err);

  err = clSetKernelArg(kernel, 5, sizeof(int), &oN);
  SAMPLE_CHECK_ERRORS(err);

  // Half of each padded output dimension. Presumably each work-item
  // produces a 2x2 block of outputs; confirm against the gemm kernel.
  size_t local_size[2] = {16, 16};
  size_t global_size[2] = {static_cast<size_t>(oM / 2), static_cast<size_t>(oN / 2)};

  err = clEnqueueNDRangeKernel(
      queue,
      kernel,
      2,
      NULL,       // no global offset
      global_size,
      local_size,
      0, NULL, NULL
  );
  SAMPLE_CHECK_ERRORS(err);

  err = clFinish(queue);
  SAMPLE_CHECK_ERRORS(err);

  // Copy the unpadded N x M region of the result into C.
  customcl_copy_matrix(ctx_id,
      (Dtype*)result_ptr, oN, oM,
      C, N, M);
}