// Retains the array handle `in` by dispatching on its storage kind and
// element type:
//   * sparse arrays -> retainSparseHandle<T>(in)  (f32, f64, c32, c64 only)
//   * dense arrays  -> retainHandle<T>(in)        (all listed dtypes)
// Any element type not listed for the array's storage kind raises
// TYPE_ERROR with argument index 1.
// NOTE(review): the two `false` flags passed to getInfo presumably disable
// its sparse/device validation so both kinds reach this dispatch — confirm.
af_array retain(const af_array in)
{
    ArrayInfo info = getInfo(in, false, false);
    const af_dtype dtype = info.getType();

    if (info.isSparse()) {
        // Every case either returns or throws, so control never falls
        // through to the dense dispatch below.
        switch (dtype) {
            case f32: return retainSparseHandle<float>(in);
            case f64: return retainSparseHandle<double>(in);
            case c32: return retainSparseHandle<detail::cfloat>(in);
            case c64: return retainSparseHandle<detail::cdouble>(in);
            default:  TYPE_ERROR(1, dtype);
        }
    }

    switch (dtype) {
        // Floating point and complex.
        case f32: return retainHandle<float>(in);
        case f64: return retainHandle<double>(in);
        case c32: return retainHandle<detail::cfloat>(in);
        case c64: return retainHandle<detail::cdouble>(in);
        // Boolean (stored as char).
        case b8:  return retainHandle<char>(in);
        // Signed integers.
        case s16: return retainHandle<short>(in);
        case s32: return retainHandle<int>(in);
        case s64: return retainHandle<intl>(in);
        // Unsigned integers.
        case u8:  return retainHandle<uchar>(in);
        case u16: return retainHandle<ushort>(in);
        case u32: return retainHandle<uint>(in);
        case u64: return retainHandle<uintl>(in);
        default:  TYPE_ERROR(1, dtype);
    }
}
// Matrix multiply: computes op(lhs) * op(rhs) into *out, where each op is
// selected by optLhs / optRhs.
//
// A sparse `lhs` is forwarded as-is to af_sparse_matmul. For the dense path:
//   * optLhs / optRhs must be AF_MAT_NONE, AF_MAT_TRANS or AF_MAT_CTRANS,
//     otherwise AF_ERR_NOT_SUPPORTED;
//   * neither input may have more than 2 dimensions (AF_ERR_BATCH);
//   * both inputs must share the same element type (TYPE_ASSERT);
//   * the inner dimensions must agree after applying the properties
//     (DIM_ASSERT, argument index 1);
//   * only f32, c32, f64 and c64 are dispatched; other types raise TYPE_ERROR.
//
// NOTE(review): getInfo(lhs, false, true) vs getInfo(rhs, true, true)
// presumably lets a sparse lhs through while rejecting a sparse rhs —
// confirm against getInfo's flag semantics.
// Returns AF_SUCCESS; errors are presumably converted to an af_err by CATCHALL.
af_err af_matmul(af_array *out,
                 const af_array lhs, const af_array rhs,
                 const af_mat_prop optLhs, const af_mat_prop optRhs)
{
    using namespace detail;

    // The only matrix properties the dense matmul path accepts.
    auto isSupportedProp = [](const af_mat_prop prop) {
        return prop == AF_MAT_NONE || prop == AF_MAT_TRANS || prop == AF_MAT_CTRANS;
    };

    try {
        ArrayInfo lhsInfo = getInfo(lhs, false, true);
        ArrayInfo rhsInfo = getInfo(rhs, true, true);

        if (lhsInfo.isSparse()) {
            return af_sparse_matmul(out, lhs, rhs, optLhs, optRhs);
        }

        af_dtype lhs_type = lhsInfo.getType();
        af_dtype rhs_type = rhsInfo.getType();

        if (!isSupportedProp(optLhs)) {
            AF_ERROR("Using this property is not yet supported in matmul",
                     AF_ERR_NOT_SUPPORTED);
        }
        if (!isSupportedProp(optRhs)) {
            AF_ERROR("Using this property is not yet supported in matmul",
                     AF_ERR_NOT_SUPPORTED);
        }

        if (lhsInfo.ndims() > 2 || rhsInfo.ndims() > 2) {
            AF_ERROR("matmul can not be used in batch mode", AF_ERR_BATCH);
        }

        TYPE_ASSERT(lhs_type == rhs_type);

        // A transposed lhs contributes its rows (dim 0) as the inner
        // dimension; a transposed rhs contributes its columns (dim 1).
        const int aColDim = optLhs == AF_MAT_NONE ? 1 : 0;
        const int bRowDim = optRhs == AF_MAT_NONE ? 0 : 1;
        DIM_ASSERT(1, lhsInfo.dims()[aColDim] == rhsInfo.dims()[bRowDim]);

        af_array output = 0;
        switch (lhs_type) {
            case f32: output = matmul<float  >(lhs, rhs, optLhs, optRhs); break;
            case c32: output = matmul<cfloat >(lhs, rhs, optLhs, optRhs); break;
            case f64: output = matmul<double >(lhs, rhs, optLhs, optRhs); break;
            case c64: output = matmul<cdouble>(lhs, rhs, optLhs, optRhs); break;
            default:  TYPE_ERROR(1, lhs_type);
        }
        *out = output;
    }
    CATCHALL
    return AF_SUCCESS;
}
// Sparse-dense matrix multiply: computes op(lhs) * rhs into *out.
//
// Requirements enforced below:
//   * lhs must be sparse and rhs dense (ARG_ASSERT, argument index 2);
//   * lhs must use CSR storage (ARG_ASSERT, argument index 1);
//   * optLhs must be AF_MAT_NONE, AF_MAT_TRANS or AF_MAT_CTRANS;
//   * optRhs must be AF_MAT_NONE (no rhs transpose support yet);
//   * rhs may not have more than 2 dimensions (AF_ERR_BATCH);
//   * both operands must share the same element type (TYPE_ASSERT);
//   * inner dimensions must agree after applying optLhs (DIM_ASSERT);
//   * only f32, c32, f64 and c64 are dispatched; others raise TYPE_ERROR.
//
// Returns AF_SUCCESS; errors are presumably converted to an af_err by CATCHALL.
af_err af_sparse_matmul(af_array *out,
                        const af_array lhs, const af_array rhs,
                        const af_mat_prop optLhs, const af_mat_prop optRhs)
{
    using namespace detail;

    try {
        common::SparseArrayBase lhsBase = getSparseArrayBase(lhs);
        ArrayInfo rhsInfo = getInfo(rhs);

        ARG_ASSERT(2, lhsBase.isSparse() == true && rhsInfo.isSparse() == false);

        af_dtype lhs_type = lhsBase.getType();
        af_dtype rhs_type = rhsInfo.getType();

        // Only CSR storage is accepted for the sparse operand.
        ARG_ASSERT(1, lhsBase.getStorage() == AF_STORAGE_CSR);

        // Reject any lhs property other than none / transpose / conjugate-transpose.
        if (!(optLhs == AF_MAT_NONE || optLhs == AF_MAT_TRANS || optLhs == AF_MAT_CTRANS)) {
            AF_ERROR("Using this property is not yet supported in sparse matmul",
                     AF_ERR_NOT_SUPPORTED);
        }

        // No transpose options for RHS. The message names the sparse routine,
        // matching the lhs check above.
        if (optRhs != AF_MAT_NONE) {
            AF_ERROR("Using this property is not yet supported in sparse matmul",
                     AF_ERR_NOT_SUPPORTED);
        }

        if (rhsInfo.ndims() > 2) {
            AF_ERROR("Sparse matmul can not be used in batch mode", AF_ERR_BATCH);
        }

        TYPE_ASSERT(lhs_type == rhs_type);

        af::dim4 ldims = lhsBase.dims();
        // A transposed lhs contributes its rows (dim 0) as the inner dimension.
        int lColDim = (optLhs == AF_MAT_NONE) ? 1 : 0;
        // optRhs is always AF_MAT_NONE at this point; the ternary is kept so
        // the check stays correct if rhs transposes are supported later.
        int rRowDim = (optRhs == AF_MAT_NONE) ? 0 : 1;
        DIM_ASSERT(1, ldims[lColDim] == rhsInfo.dims()[rRowDim]);

        af_array output = 0;
        switch(lhs_type) {
            case f32: output = sparseMatmul<float  >(lhs, rhs, optLhs, optRhs); break;
            case c32: output = sparseMatmul<cfloat >(lhs, rhs, optLhs, optRhs); break;
            case f64: output = sparseMatmul<double >(lhs, rhs, optLhs, optRhs); break;
            case c64: output = sparseMatmul<cdouble>(lhs, rhs, optLhs, optRhs); break;
            default:  TYPE_ERROR(1, lhs_type);
        }
        std::swap(*out, output);
    }
    CATCHALL
    return AF_SUCCESS;
}
// Releases the array handle `arr`, dispatching on storage kind and element
// type to releaseSparseHandle<T> / releaseHandle<T>.
//
// For dense arrays the active device is switched to info.getDevId() for the
// release and then restored to the device that was active on entry. The
// restore also happens when the release throws, including the TYPE_ERROR
// for an unknown dtype. Otherwise a failed release would leave the caller on
// a different device.
//
// NOTE(review): getInfo(arr, false, false) presumably skips getInfo's
// sparse/device validation so any handle can be released — confirm.
// Returns AF_SUCCESS; errors are presumably converted to an af_err by CATCHALL.
af_err af_release_array(af_array arr)
{
    try {
        int dev = getActiveDeviceId();

        ArrayInfo info = getInfo(arr, false, false);
        af_dtype type = info.getType();

        if(info.isSparse()) {
            switch(type) {
                case f32: releaseSparseHandle<float  >(arr); break;
                case f64: releaseSparseHandle<double >(arr); break;
                case c32: releaseSparseHandle<cfloat >(arr); break;
                case c64: releaseSparseHandle<cdouble>(arr); break;
                default : TYPE_ERROR(0, type);
            }
        } else {
            setDevice(info.getDevId());

            try {
                switch(type) {
                    case f32: releaseHandle<float   >(arr); break;
                    case c32: releaseHandle<cfloat  >(arr); break;
                    case f64: releaseHandle<double  >(arr); break;
                    case c64: releaseHandle<cdouble >(arr); break;
                    case b8:  releaseHandle<char    >(arr); break;
                    case s32: releaseHandle<int     >(arr); break;
                    case u32: releaseHandle<uint    >(arr); break;
                    case u8:  releaseHandle<uchar   >(arr); break;
                    case s64: releaseHandle<intl    >(arr); break;
                    case u64: releaseHandle<uintl   >(arr); break;
                    case s16: releaseHandle<short   >(arr); break;
                    case u16: releaseHandle<ushort  >(arr); break;
                    default: TYPE_ERROR(0, type);
                }
            } catch (...) {
                // Put the caller's device back before the error reaches CATCHALL.
                setDevice(dev);
                throw;
            }

            setDevice(dev);
        }
    }
    CATCHALL

    return AF_SUCCESS;
}