예제 #1
0
// Computes op(lhs) * op(rhs) with CBLAS, where op() is chosen by
// optLhs / optRhs (converted with toCblasTranspose).
//
// Only the first two dimensions of each input are used; the result is an
// M x N array with dims 2 and 3 equal to 1:
//   M = rows of op(lhs), K = columns of op(lhs), N = columns of op(rhs).
// The BLAS call itself is deferred by enqueueing it on getQueue().
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
                af_mat_prop optLhs, af_mat_prop optRhs)
{
    lhs.eval();
    rhs.eval();

    CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which stored dimension acts as rows / columns once op() is applied.
    int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    dim4 lDims = lhs.dims();
    dim4 rDims = rhs.dims();
    int M = lDims[aRowDim];
    int N = rDims[bColDim];
    int K = lDims[aColDim];

    using BT  =       typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
    auto func = [=] (Array<T> output, const Array<T> left, const Array<T> right) {
        auto alpha = getScale<T, 1>();
        auto beta  = getScale<T, 0>();

        dim4 lStrides = left.strides();
        dim4 rStrides = right.strides();

        if(rDims[bColDim] == 1) {
            // op(rhs) is a single column, so a matrix-vector product suffices.
            // Untransposed, the vector runs along dim 0 (step rStrides[0]).
            // Transposed, rhs is a 1 x K row stored along dim 1, so consecutive
            // elements are rStrides[1] apart.
            // NOTE(review): gemv cannot conjugate x, so a CblasConjTrans rhs
            // vector is used unconjugated here — confirm callers never pass it.
            const auto incx = (rOpts == CblasNoTrans) ? rStrides[0] : rStrides[1];
            gemv_func<T>()(
                CblasColMajor, lOpts,
                lDims[0], lDims[1],
                alpha,
                reinterpret_cast<CBT*>(left.get()), lStrides[1],
                reinterpret_cast<CBT*>(right.get()), incx,
                beta,
                reinterpret_cast<BT*>(output.get()), 1);
        } else {
            gemm_func<T>()(
                CblasColMajor, lOpts, rOpts,
                M, N, K,
                alpha,
                reinterpret_cast<CBT*>(left.get()), lStrides[1],
                reinterpret_cast<CBT*>(right.get()), rStrides[1],
                beta,
                reinterpret_cast<BT*>(output.get()), output.dims()[0]);
        }
    };
    getQueue().enqueue(func, out, lhs, rhs);

    return out;
}
예제 #2
0
파일: blas.cpp 프로젝트: mlloreda/arrayfire
// Batched matrix product: for every 2D slice (z, w) along dims 2 and 3,
// computes op(lhs) * op(rhs) with CBLAS, where op() is chosen by
// optLhs / optRhs.
//
// The output has dims (M, N, max(lDims[2], rDims[2]), max(lDims[3], rDims[3])).
// An input whose dim 2 / dim 3 differs from the output's is not advanced
// along that dimension, i.e. it is broadcast.
// NOTE(review): this presumes such a dimension is 1 — confirm the caller
// validates it.
// The per-slice BLAS calls are deferred by enqueueing them on getQueue().
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
                af_mat_prop optLhs, af_mat_prop optRhs)
{
    lhs.eval();
    rhs.eval();

    CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which stored dimension acts as rows / columns once op() is applied.
    int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    auto lDims = lhs.dims();
    auto rDims = rhs.dims();
    int M = lDims[aRowDim];
    int N = rDims[bColDim];
    int K = lDims[aColDim];

    using BT  =       typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    dim_t d2 = std::max(lDims[2], rDims[2]);
    dim_t d3 = std::max(lDims[3], rDims[3]);
    const dim4 oDims(M, N, d2, d3);
    Array<T> out = createEmptyArray<T>(oDims);

    auto func = [=] (Param<T> output, CParam<T> left, CParam<T> right) {
        auto alpha = getScale<T, 1>();
        auto beta  = getScale<T, 0>();

        dim4 lStrides = left.strides();
        dim4 rStrides = right.strides();
        dim4 oStrides = output.strides();

        dim_t batchSize = oDims[2] * oDims[3];

        // An input is stepped along dim 2/3 only when it spans the full
        // output extent there; otherwise its offset contribution is zero.
        bool is_l_d2_batched = oDims[2] == lDims[2];
        bool is_l_d3_batched = oDims[3] == lDims[3];
        bool is_r_d2_batched = oDims[2] == rDims[2];
        bool is_r_d3_batched = oDims[3] == rDims[3];

        for (dim_t n = 0; n < batchSize; n++) {
            // Split the linear batch index into (z, w) over the output dims.
            dim_t w = n / oDims[2];
            dim_t z = n - w * oDims[2];

            // Offsets are kept in dim_t: stride * index can exceed int range.
            dim_t loff = z * (is_l_d2_batched * lStrides[2]) + w * (is_l_d3_batched * lStrides[3]);
            dim_t roff = z * (is_r_d2_batched * rStrides[2]) + w * (is_r_d3_batched * rStrides[3]);
            dim_t ooff = z * oStrides[2] + w * oStrides[3];

            CBT *lptr = reinterpret_cast<CBT*>(left.get() + loff);
            CBT *rptr = reinterpret_cast<CBT*>(right.get() + roff);
            BT *optr = reinterpret_cast<BT*>(output.get() + ooff);

            if(rDims[bColDim] == 1) {
                // op(rhs) is a single column. A transposed rhs is a row vector
                // stored along dim 1, hence the stride choice for incx.
                dim_t incr = (optRhs == AF_MAT_NONE) ? rStrides[0] : rStrides[1];
                gemv_func<T>()(
                    CblasColMajor, lOpts,
                    lDims[0], lDims[1],
                    alpha,
                    lptr, lStrides[1],
                    rptr, incr,
                    beta,
                    optr, 1);
            } else {
                gemm_func<T>()(
                    CblasColMajor, lOpts, rOpts,
                    M, N, K,
                    alpha,
                    lptr, lStrides[1],
                    rptr, rStrides[1],
                    beta,
                    optr, output.dims(0));
            }
        }
    };
    getQueue().enqueue(func, out, lhs, rhs);

    return out;
}
예제 #3
0
// Batched matrix product on host-mapped memory: for every 2D slice (z, w)
// along dims 2 and 3, computes op(lhs) * op(rhs) with CBLAS, where op() is
// chosen by optLhs / optRhs.
//
// The output has dims (M, N, max(lDims[2], rDims[2]), max(lDims[3], rDims[3])).
// An input whose dim 2 / dim 3 differs from the output's is not advanced
// along that dimension, i.e. it is broadcast.
// NOTE(review): this presumes such a dimension is 1 — confirm the caller
// validates it.
// The BLAS calls are made directly in this function on pointers obtained
// from getMappedPtr().
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
                af_mat_prop optLhs, af_mat_prop optRhs)
{
    CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
    CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);

    // Which stored dimension acts as rows / columns once op() is applied.
    int aRowDim = (lOpts == CblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == CblasNoTrans) ? 1 : 0;
    int bColDim = (rOpts == CblasNoTrans) ? 1 : 0;

    dim4 lDims = lhs.dims();
    dim4 rDims = rhs.dims();
    int M = lDims[aRowDim];
    int N = rDims[bColDim];
    int K = lDims[aColDim];
    dim_t d2 = std::max(lDims[2], rDims[2]);
    dim_t d3 = std::max(lDims[3], rDims[3]);
    dim4 oDims = af::dim4(M, N, d2, d3);

    //FIXME: Leaks on errors.
    Array<T> out = createValueArray<T>(oDims, scalar<T>(0));
    auto alpha = getScale<T, 1>();
    auto beta  = getScale<T, 0>();

    dim4 lStrides = lhs.strides();
    dim4 rStrides = rhs.strides();
    dim4 oStrides = out.strides();

    using BT  =       typename blas_base<T>::type;
    using CBT = const typename blas_base<T>::type;

    dim_t batchSize = oDims[2] * oDims[3];

    // An input is stepped along dim 2/3 only when it spans the full output
    // extent there; otherwise its offset contribution is zero.
    bool is_l_d2_batched = (oDims[2] == lDims[2]);
    bool is_l_d3_batched = (oDims[3] == lDims[3]);
    bool is_r_d2_batched = (oDims[2] == rDims[2]);
    bool is_r_d3_batched = (oDims[3] == rDims[3]);

    {
        // Get host pointers from mapped memory once for the whole batch loop,
        // not once per slice. The handles are scoped to this block so any
        // cleanup they perform on destruction happens before `out` is returned.
        auto lPtr = lhs.getMappedPtr();
        auto rPtr = rhs.getMappedPtr();
        auto oPtr = out.getMappedPtr();

        for(dim_t n = 0; n < batchSize; ++n) {
            // Split the linear batch index over the OUTPUT dims. Using an input's
            // dim 2 here goes wrong when only the other input is batched.
            dim_t w = n / oDims[2];
            dim_t z = n - w * oDims[2];

            // Offsets are kept in dim_t: stride * index can exceed int range.
            dim_t loff = z * (is_l_d2_batched * lStrides[2]) + w * (is_l_d3_batched * lStrides[3]);
            dim_t roff = z * (is_r_d2_batched * rStrides[2]) + w * (is_r_d3_batched * rStrides[3]);
            dim_t ooff = z * oStrides[2] + w * oStrides[3];

            CBT *lptr = reinterpret_cast<CBT*>(lPtr.get() + loff);
            CBT *rptr = reinterpret_cast<CBT*>(rPtr.get() + roff);
            BT  *optr = reinterpret_cast<BT*>(oPtr.get() + ooff);

            if(rDims[bColDim] == 1) {
                // op(rhs) is a single column. A transposed rhs is a row vector
                // stored along dim 1, hence the stride choice for incx.
                dim_t incr = (rOpts == CblasNoTrans) ? rStrides[0] : rStrides[1];
                gemv_func<T>()(
                    CblasColMajor, lOpts,
                    lDims[0], lDims[1],
                    alpha,
                    lptr, lStrides[1],
                    rptr, incr,
                    beta,
                    optr, 1);
            } else {
                gemm_func<T>()(
                    CblasColMajor, lOpts, rOpts,
                    M, N, K,
                    alpha,
                    lptr, lStrides[1],
                    rptr, rStrides[1],
                    beta,
                    optr, out.dims()[0]);
            }
        }
    }

    return out;
}