/*
 * transpose: mT <- alpha * mA^T
 *
 * Writes the scaled transpose of the square n-by-n matrix mA into mT, where
 * n = MatN(mA). The element size of mA selects the precision: 4 bytes is
 * handled as float, 8 bytes as double; any other size is silently ignored.
 * MatDev(mA) selects the cuBLAS (device) path, otherwise the CBLAS
 * omatcopy (host) path is used. Both paths use column-major layout with a
 * leading dimension of n.
 *
 * The status returned by the cuBLAS calls is not checked.
 *
 * NOTE(review): mT is assumed to have the same size, element type and
 * residency (host/device) as mA, and to be distinct from mA - confirm
 * against callers.
 */
void transpose(double alpha, Mat mA, Mat mT) {
    const int n = MatN(mA);
    const void* const a = MatElems(mA);
    void* const t = MatElems(mT);
    const bool dev = MatDev(mA);

    switch (MatElemSize(mA)) {
    case 4: {
        const float* const a32 = (const float*)a;
        float* const t32 = (float*)t;
        if (dev) {
            /* cuBLAS reads the scalars through float pointers, so both must
             * really be floats (a double zero must not be reinterpreted). */
            const float alpha32 = (float)alpha;
            const float beta32 = 0.0f;
            /* t = alpha32 * a^T + 0 * t, with t doubling as the B operand. */
            cublasSgeam(g_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, n, n,
                        &alpha32, a32, n, &beta32, t32, n, t32, n);
        } else {
            cblas_somatcopy(CblasColMajor, CblasTrans, n, n, (float)alpha,
                            a32, n, t32, n);
        }
        break;
    }
    case 8: {
        const double* const a64 = (const double*)a;
        double* const t64 = (double*)t;
        if (dev) {
            const double beta64 = 0.0;
            /* t = alpha * a^T + 0 * t, with t doubling as the B operand. */
            cublasDgeam(g_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, n, n,
                        &alpha, a64, n, &beta64, t64, n, t64, n);
        } else {
            cblas_domatcopy(CblasColMajor, CblasTrans, n, n, alpha,
                            a64, n, t64, n);
        }
        break;
    }
    }
}
/**
 * @brief Single-precision matrix add/transpose on the GPU via cublasSgeam.
 *
 * A, B and C are treated as row-major; C is M x N. Because cuBLAS is
 * column-major, the call below describes C as an N x M matrix with leading
 * dimension N, and the leading dimension of each input is N when it is not
 * transposed and M when it is.
 *
 * NOTE(review): the call hands B to cuBLAS as its first operand (scaled by
 * alpha) and A as its second operand (scaled by beta), so the value written
 * is alpha * op(B) + beta * op(A). Matrix addition does not need the operand
 * swap that the row-major GEMM wrapper uses; confirm this pairing is what
 * callers expect before changing it.
 *
 * @param TransA  whether A is transposed
 * @param TransB  whether B is transposed
 * @param M       number of rows of C
 * @param N       number of columns of C
 * @param alpha   scalar applied to the first cuBLAS operand (B)
 * @param A       device pointer to A
 * @param B       device pointer to B
 * @param beta    scalar applied to the second cuBLAS operand (A)
 * @param C       device pointer to the output
 */
template <>
void caffe_gpu_geam<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  // Note that cublas follows fortran order.
  const int lda = (TransA == CblasNoTrans) ? N : M;
  const int ldb = (TransB == CblasNoTrans) ? N : M;
  const cublasOperation_t cuTransA =
      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t cuTransB =
      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  CUBLAS_CHECK(cublasSgeam(Caffe::get_current_cublas_handle(),
      cuTransB, cuTransA, N, M, &alpha, B, ldb, &beta, A, lda, C, N));
}
/*
 * geam: C <- alpha*A + beta*B
 *
 * When mB and mC share storage this is the in-place update
 * C <- alpha*A + beta*C. The element size of mA selects the precision:
 * 4 bytes is handled as float, 8 bytes as double; any other size is
 * silently ignored. MatDev(mA) selects the cuBLAS (device) path, which
 * treats the operands as n-by-n column-major matrices with n = MatN(mA).
 * Otherwise the host path works on the operands as flat vectors of
 * n2 = MatN2(mA) elements (presumably n*n - confirm).
 *
 * The host path also handles the output aliasing A. Clearing C first would
 * destroy A when A and C share storage, and scaling C by beta before adding
 * alpha*A would use an already-scaled A when all three share storage, so
 * those cases get their own update sequences.
 *
 * The status returned by the cuBLAS calls is not checked.
 *
 * NOTE(review): mB and mC are assumed to match mA in size, element type and
 * residency (host/device) - confirm against callers.
 */
void geam(double alpha, Mat mA, double beta, Mat mB, Mat mC) {
    const int n = MatN(mA);
    const int n2 = MatN2(mA);
    const void* const a = MatElems(mA);
    const void* const b = MatElems(mB);
    void* const c = MatElems(mC);
    const bool dev = MatDev(mA);
    const bool a_is_c = (a == c);
    const bool b_is_c = (b == c);

    switch (MatElemSize(mA)) {
    case 4: {
        const float* const a32 = (const float*)a;
        const float* const b32 = (const float*)b;
        float* const c32 = (float*)c;
        const float alpha32 = (float)alpha;
        const float beta32 = (float)beta;
        if (dev) {
            cublasSgeam(g_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, n, n,
                        &alpha32, a32, n, &beta32, b32, n, c32, n);
        } else if (a_is_c && b_is_c) {
            /* C <- (alpha + beta) * C */
            cblas_sscal(n2, alpha32 + beta32, c32, 1);
        } else if (b_is_c) {
            /* C <- beta * C, then C += alpha * A */
            cblas_sscal(n2, beta32, c32, 1);
            cblas_saxpy(n2, alpha32, a32, 1, c32, 1);
        } else if (a_is_c) {
            /* C <- alpha * C (C holds A), then C += beta * B */
            cblas_sscal(n2, alpha32, c32, 1);
            cblas_saxpy(n2, beta32, b32, 1, c32, 1);
        } else {
            /* No aliasing: start from zero and accumulate both terms. */
            memset(c, 0, MatSize(mC));
            cblas_saxpy(n2, beta32, b32, 1, c32, 1);
            cblas_saxpy(n2, alpha32, a32, 1, c32, 1);
        }
        break;
    }
    case 8: {
        const double* const a64 = (const double*)a;
        const double* const b64 = (const double*)b;
        double* const c64 = (double*)c;
        if (dev) {
            cublasDgeam(g_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, n, n,
                        &alpha, a64, n, &beta, b64, n, c64, n);
        } else if (a_is_c && b_is_c) {
            /* C <- (alpha + beta) * C */
            cblas_dscal(n2, alpha + beta, c64, 1);
        } else if (b_is_c) {
            /* C <- beta * C, then C += alpha * A */
            cblas_dscal(n2, beta, c64, 1);
            cblas_daxpy(n2, alpha, a64, 1, c64, 1);
        } else if (a_is_c) {
            /* C <- alpha * C (C holds A), then C += beta * B */
            cblas_dscal(n2, alpha, c64, 1);
            cblas_daxpy(n2, beta, b64, 1, c64, 1);
        } else {
            /* No aliasing: start from zero and accumulate both terms. */
            memset(c, 0, MatSize(mC));
            cblas_daxpy(n2, beta, b64, 1, c64, 1);
            cblas_daxpy(n2, alpha, a64, 1, c64, 1);
        }
        break;
    }
    }
}
/*
 * Single-precision overload-style wrapper around cublasSgeam.
 *
 * Forwards every argument unchanged, supplying the cuBLAS handle stored in
 * g_context->cublasHandle, and returns the status reported by cublasSgeam.
 */
cublasStatus_t cublasXgeam(cublasOperation_t transa, cublasOperation_t transb,
                           int m, int n,
                           const float *alpha, const float *A, int lda,
                           const float *beta, const float *B, int ldb,
                           float *C, int ldc)
{
    const cublasStatus_t status =
        cublasSgeam(g_context->cublasHandle,
                    transa, transb,
                    m, n,
                    alpha, A, lda,
                    beta, B, ldb,
                    C, ldc);
    return status;
}