/// Multiplies two matrices (or a matrix and a vector) through CBLAS.
///
/// Both operands are evaluated up front; the BLAS call itself is enqueued on
/// the queue returned by getQueue(), so the returned array is filled
/// asynchronously.
///
/// @param lhs     left operand
/// @param rhs     right operand
/// @param optLhs  transpose option for lhs, mapped with toCblasTranspose()
/// @param optRhs  transpose option for rhs, mapped with toCblasTranspose()
/// @return        an M x N x 1 x 1 array, where M and N are the row count of
///                op(lhs) and the column count of op(rhs)
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs, af_mat_prop optLhs,
                af_mat_prop optRhs) {
    lhs.eval();
    rhs.eval();

    CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which physical dimension plays the role of rows/columns after op().
    int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    dim4 lDims = lhs.dims();
    dim4 rDims = rhs.dims();
    int M      = lDims[aRowDim];
    int N      = rDims[bColDim];
    int K      = lDims[aColDim];

    using BT  = typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));

    auto func = [=](Array<T> output, const Array<T> left,
                    const Array<T> right) {
        // NOTE(review): presumably alpha == 1 and beta == 0 - confirm against
        // getScale.
        auto alpha = getScale<T, 1>();
        auto beta  = getScale<T, 0>();

        dim4 lStrides = left.strides();
        dim4 rStrides = right.strides();

        if (rDims[bColDim] == 1) {
            // op(rhs) has a single column: use gemv. The vector's elements
            // lie along dim 0 when rhs is not transposed and along dim 1
            // when it is, so the increment must follow the transpose flag.
            dim_t incr = (rOpts == CblasNoTrans) ? rStrides[0] : rStrides[1];
            gemv_func<T>()(CblasColMajor, lOpts, lDims[0], lDims[1], alpha,
                           reinterpret_cast<CBT *>(left.get()), lStrides[1],
                           reinterpret_cast<CBT *>(right.get()), incr, beta,
                           reinterpret_cast<BT *>(output.get()), 1);
        } else {
            gemm_func<T>()(CblasColMajor, lOpts, rOpts, M, N, K, alpha,
                           reinterpret_cast<CBT *>(left.get()), lStrides[1],
                           reinterpret_cast<CBT *>(right.get()), rStrides[1],
                           beta, reinterpret_cast<BT *>(output.get()),
                           output.dims()[0]);
        }
    };
    getQueue().enqueue(func, out, lhs, rhs);

    return out;
}
/// Batched matrix multiplication through CBLAS.
///
/// Both operands are evaluated up front; the BLAS calls are enqueued on the
/// queue returned by getQueue(), so the returned array is filled
/// asynchronously.
///
/// Dimensions 2 and 3 are batch dimensions. The output extent along each is
/// the larger of the two operands' extents. An operand whose extent matches
/// the output's advances by its own stride for every batch entry; otherwise
/// its offset along that dimension stays 0 and the same slice is reused.
///
/// @param lhs     left operand
/// @param rhs     right operand
/// @param optLhs  transpose option for lhs, mapped with toCblasTranspose()
/// @param optRhs  transpose option for rhs, mapped with toCblasTranspose()
/// @return        an M x N x d2 x d3 array, where M and N are the row count
///                of op(lhs) and the column count of op(rhs)
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs, af_mat_prop optLhs,
                af_mat_prop optRhs) {
    lhs.eval();
    rhs.eval();

    CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which physical dimension plays the role of rows/columns after op().
    int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    auto lDims = lhs.dims();
    auto rDims = rhs.dims();
    int M      = lDims[aRowDim];
    int N      = rDims[bColDim];
    int K      = lDims[aColDim];

    using BT  = typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    dim_t d2 = std::max(lDims[2], rDims[2]);
    dim_t d3 = std::max(lDims[3], rDims[3]);
    const dim4 oDims(M, N, d2, d3);
    Array<T> out = createEmptyArray<T>(oDims);

    auto func = [=](Param<T> output, CParam<T> left, CParam<T> right) {
        // NOTE(review): presumably alpha == 1 and beta == 0 - confirm against
        // getScale.
        auto alpha = getScale<T, 1>();
        auto beta  = getScale<T, 0>();

        dim4 lStrides = left.strides();
        dim4 rStrides = right.strides();
        dim4 oStrides = output.strides();

        // All index/offset arithmetic stays in dim_t: strides are 64-bit and
        // narrowing the products to int would truncate offsets of large
        // operands.
        const dim_t batchSize = oDims[2] * oDims[3];

        const bool is_l_d2_batched = oDims[2] == lDims[2];
        const bool is_l_d3_batched = oDims[3] == lDims[3];
        const bool is_r_d2_batched = oDims[2] == rDims[2];
        const bool is_r_d3_batched = oDims[3] == rDims[3];

        for (dim_t n = 0; n < batchSize; n++) {
            // Split the linear batch index into (z, w) = (dim 2, dim 3).
            const dim_t w = n / oDims[2];
            const dim_t z = n - w * oDims[2];

            const dim_t loff = z * (is_l_d2_batched ? lStrides[2] : 0) +
                               w * (is_l_d3_batched ? lStrides[3] : 0);
            const dim_t roff = z * (is_r_d2_batched ? rStrides[2] : 0) +
                               w * (is_r_d3_batched ? rStrides[3] : 0);
            const dim_t ooff = z * oStrides[2] + w * oStrides[3];

            CBT *lptr = reinterpret_cast<CBT *>(left.get() + loff);
            CBT *rptr = reinterpret_cast<CBT *>(right.get() + roff);
            BT *optr  = reinterpret_cast<BT *>(output.get() + ooff);

            if (rDims[bColDim] == 1) {
                // op(rhs) has a single column: use gemv. The vector runs
                // along dim 0 when rhs is not transposed, dim 1 otherwise.
                dim_t incr =
                    (optRhs == AF_MAT_NONE) ? rStrides[0] : rStrides[1];
                gemv_func<T>()(CblasColMajor, lOpts, lDims[0], lDims[1], alpha,
                               lptr, lStrides[1], rptr, incr, beta, optr, 1);
            } else {
                gemm_func<T>()(CblasColMajor, lOpts, rOpts, M, N, K, alpha,
                               lptr, lStrides[1], rptr, rStrides[1], beta,
                               optr, output.dims(0));
            }
        }
    };
    getQueue().enqueue(func, out, lhs, rhs);

    return out;
}
/// Batched matrix multiplication through CBLAS on host-mapped buffers.
///
/// Dimensions 2 and 3 are batch dimensions. The output extent along each is
/// the larger of the two operands' extents. An operand whose extent matches
/// the output's advances by its own stride for every batch entry; otherwise
/// its offset along that dimension stays 0 and the same slice is reused.
///
/// @param lhs     left operand
/// @param rhs     right operand
/// @param optLhs  transpose option for lhs, mapped with toCblasTranspose()
/// @param optRhs  transpose option for rhs, mapped with toCblasTranspose()
/// @return        an M x N x d2 x d3 array, where M and N are the row count
///                of op(lhs) and the column count of op(rhs)
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs, af_mat_prop optLhs,
                af_mat_prop optRhs) {
    const CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    const CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which physical dimension plays the role of rows/columns after op().
    const int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    const int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    const int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    const dim4 lDims = lhs.dims();
    const dim4 rDims = rhs.dims();
    const int M      = lDims[aRowDim];
    const int N      = rDims[bColDim];
    const int K      = lDims[aColDim];

    const dim_t d2   = std::max(lDims[2], rDims[2]);
    const dim_t d3   = std::max(lDims[3], rDims[3]);
    const dim4 oDims = af::dim4(M, N, d2, d3);

    // FIXME: Leaks on errors.
    Array<T> out = createValueArray<T>(oDims, scalar<T>(0));

    // NOTE(review): presumably alpha == 1 and beta == 0 - confirm against
    // getScale.
    auto alpha = getScale<T, 1>();
    auto beta  = getScale<T, 0>();

    const dim4 lStrides = lhs.strides();
    const dim4 rStrides = rhs.strides();
    const dim4 oStrides = out.strides();

    using BT  = typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    // All index/offset arithmetic stays in dim_t: strides are 64-bit and
    // narrowing the products to int would truncate offsets of large operands.
    const dim_t batchSize = oDims[2] * oDims[3];

    const bool is_l_d2_batched = (oDims[2] == lDims[2]);
    const bool is_l_d3_batched = (oDims[3] == lDims[3]);
    const bool is_r_d2_batched = (oDims[2] == rDims[2]);
    const bool is_r_d3_batched = (oDims[3] == rDims[3]);

    for (dim_t n = 0; n < batchSize; ++n) {
        // Split the linear batch index into (z, w) = (dim 2, dim 3) using
        // the OUTPUT extent. Dividing by rDims[2] is wrong whenever rhs is
        // broadcast along dim 2 (rDims[2] == 1): w would run up to
        // batchSize - 1 and index lhs/out out of bounds.
        const dim_t w = n / oDims[2];
        const dim_t z = n - w * oDims[2];

        const dim_t loff = z * (is_l_d2_batched ? lStrides[2] : 0) +
                           w * (is_l_d3_batched ? lStrides[3] : 0);
        const dim_t roff = z * (is_r_d2_batched ? rStrides[2] : 0) +
                           w * (is_r_d3_batched ? rStrides[3] : 0);
        const dim_t ooff = z * oStrides[2] + w * oStrides[3];

        // get host pointers from mapped memory
        // NOTE(review): the buffers are re-mapped on every batch iteration;
        // mapping once before the loop looks possible - confirm
        // getMappedPtr semantics before hoisting.
        auto lPtr = lhs.getMappedPtr();
        auto rPtr = rhs.getMappedPtr();
        auto oPtr = out.getMappedPtr();

        CBT *lptr = reinterpret_cast<CBT *>(lPtr.get() + loff);
        CBT *rptr = reinterpret_cast<CBT *>(rPtr.get() + roff);
        BT *optr  = reinterpret_cast<BT *>(oPtr.get() + ooff);

        if (rDims[bColDim] == 1) {
            // op(rhs) has a single column: use gemv. The vector runs along
            // dim 0 when rhs is not transposed, dim 1 otherwise.
            const dim_t incr =
                (rOpts == CblasNoTrans) ? rStrides[0] : rStrides[1];
            gemv_func<T>()(CblasColMajor, lOpts, lDims[0], lDims[1], alpha,
                           lptr, lStrides[1], rptr, incr, beta, optr, 1);
        } else {
            gemm_func<T>()(CblasColMajor, lOpts, rOpts, M, N, K, alpha, lptr,
                           lStrides[1], rptr, rStrides[1], beta, optr,
                           out.dims()[0]);
        }
    }

    return out;
}