示例#1
0
// Matrix product op(lhs) * op(rhs) computed with clBLAS, where op() applies
// the transpose option requested for each operand.
//
// The result is a newly created M x N array:
//   M = rows of op(lhs), N = columns of op(rhs), K = columns of op(lhs).
// When op(rhs) has a single column, the product is computed with gemv.
// Otherwise it is computed with gemm.
//
// Throws runtime_error("CLBLAS error: <status>") if the clBLAS call returns
// a non-zero status.
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
                af_blas_transpose optLhs, af_blas_transpose optRhs)
{
    initBlas();
    clblasTranspose lOpts = toClblasTranspose(optLhs);
    clblasTranspose rOpts = toClblasTranspose(optRhs);

    // Stored dimension of each operand that plays the row / column role
    // once the requested transpose is applied.
    int aRowDim = (lOpts == clblasNoTrans) ? 0 : 1;
    int aColDim = (lOpts == clblasNoTrans) ? 1 : 0;
    int bRowDim = (rOpts == clblasNoTrans) ? 0 : 1;
    int bColDim = (rOpts == clblasNoTrans) ? 1 : 0;

    dim4 lDims = lhs.dims();
    dim4 rDims = rhs.dims();
    int M = lDims[aRowDim];
    int N = rDims[bColDim];
    int K = lDims[aColDim];

    //FIXME: Leaks on errors.
    Array<T> out = createEmptyArray<T>(af::dim4(M, N, 1, 1));
    auto alpha = scalar<T>(1);
    auto beta  = scalar<T>(0);

    dim4 lStrides = lhs.strides();
    dim4 rStrides = rhs.strides();
    clblasStatus err;
    cl::Event event;
    if(rDims[bColDim] == 1) {
        // op(rhs) is a single column vector of length K. Its elements lie
        // along the *stored* row dimension of rhs:
        //   - dim 0 for a K x 1 column (no transpose),
        //   - dim 1 for a 1 x K row (transposed).
        // So step by that dimension's stride. Using rStrides[0]
        // unconditionally is wrong for a transposed, strided rhs.
        // gemv takes the dimensions of lhs as stored and applies lOpts itself.
        // NOTE(review): if rOpts can be a conjugate transpose, gemv does not
        // conjugate x, so complex results would differ from gemm — confirm.
        gemv_func<T> gemv;
        err = gemv(
            clblasColumnMajor, lOpts,
            lDims[0], lDims[1],
            alpha,
            (*lhs.get())(),    lhs.getOffset(),   lStrides[1],
            (*rhs.get())(),    rhs.getOffset(),   rStrides[bRowDim],
            beta ,
            (*out.get())(),   out.getOffset(),             1,
            1, &getQueue()(), 0, nullptr, &event());
    } else {
        // General case: C (M x N) = op(A) (M x K) * op(B) (K x N).
        // The leading dimensions come from the column strides of the inputs
        // and from the row count of the freshly created output.
        gemm_func<T> gemm;
        err = gemm(
                clblasColumnMajor, lOpts, rOpts,
                M, N, K,
                alpha,
                (*lhs.get())(),    lhs.getOffset(),   lStrides[1],
                (*rhs.get())(),    rhs.getOffset(),   rStrides[1],
                beta,
                (*out.get())(),   out.getOffset(),  out.dims()[0],
                1, &getQueue()(), 0, nullptr, &event());
    }
    if(err) {
        throw runtime_error(std::string("CLBLAS error: ") + std::to_string(err));
    }

    return out;
}
示例#2
0
// Solves the linear system a * x = b and returns x.
//
// The solver is selected from `options` and the shape of `a`:
//   * AF_MAT_UPPER or AF_MAT_LOWER set -> triangleSolve (options forwarded)
//   * a is not square                  -> leastSquares
//   * a is square                      -> generalSolve
// Any cl::Error raised along the way is handed to CL_TO_AF_ERROR.
Array<T> solve(const Array<T> &a, const Array<T> &b, const af_mat_prop options)
{
    try {
        initBlas();

        const bool triangular = (options & AF_MAT_UPPER) ||
                                (options & AF_MAT_LOWER);
        if (triangular) {
            return triangleSolve<T>(a, b, options);
        }

        // Non-square systems have no unique direct solution; use least squares.
        if (a.dims()[0] != a.dims()[1]) {
            return leastSquares<T>(a, b);
        }
        return generalSolve<T>(a, b);
    } catch (cl::Error &err) {
        CL_TO_AF_ERROR(err);
    }
}
示例#3
0
// Cholesky-factorizes `in` in place by calling magma_potrf_gpu on its device
// buffer.
//
// is_upper selects MagmaUpper, otherwise MagmaLower.
// The factorization order is taken from dims()[0]; the leading dimension is
// strides()[1].
// Returns the `info` value written by magma_potrf_gpu (initialized to 0).
// NOTE(review): presumably LAPACK-style (0 = success) — confirm.
// Any cl::Error is handed to CL_TO_AF_ERROR.
int cholesky_inplace(Array<T> &in, const bool is_upper)
{
    try {
        initBlas();

        const magma_uplo_t triangle = is_upper ? MagmaUpper : MagmaLower;
        const int order = in.dims()[0];

        cl::Buffer *buffer = in.get();
        int status = 0;
        magma_potrf_gpu<T>(triangle, order,
                           (*buffer)(), in.getOffset(), in.strides()[1],
                           getQueue()(), &status);
        return status;
    } catch (cl::Error &err) {
        CL_TO_AF_ERROR(err);
    }
}
示例#4
0
// Dot product of lhs and rhs computed with clBLAS.
//
// Uses the first dims()[0] elements of each operand, stepping by each
// operand's strides()[0]. Returns the result in a newly created
// single-element array.
// Throws runtime_error("CLBLAS error: <status>") on a non-zero clBLAS status.
// NOTE(review): optLhs / optRhs are accepted but not used here — confirm
// callers do not expect them to have an effect (e.g. conjugation).
Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
             af_blas_transpose optLhs, af_blas_transpose optRhs)
{
    initBlas();

    int count = lhs.dims()[0];
    dot_func<T> dotFn;
    cl::Event event;
    auto out = createEmptyArray<T>(af::dim4(1));
    // Device work buffer of `count` elements of T, passed to clBLAS as scratch.
    cl::Buffer scratch(getContext(), CL_MEM_READ_WRITE, sizeof(T) * count);

    const clblasStatus status = dotFn(
        count,
        (*out.get())(), out.getOffset(),
        (*lhs.get())(), lhs.getOffset(), lhs.strides()[0],
        (*rhs.get())(), rhs.getOffset(), rhs.strides()[0],
        scratch(),
        1, &getQueue()(), 0, nullptr, &event());

    if (status) {
        throw runtime_error(std::string("CLBLAS error: ") + std::to_string(status));
    }
    return out;
}