// ================================================ // square: C = A * op(A) or C = 0.5*(A*B'+B*A') // ================================================ void squareMatMat(double* C, double* A, double* B, const int rA, const int cA, const int rB, const int cB, const char *mod) { // can't pass consts to BLAS ptrdiff_t rA0 = rA, cA0 = cA, rB0 = rB; // rows(Op(A)), columns(Op(A)), columns(Op(B)), rows(C) ptrdiff_t copA, rC; int i,j; double temp; if ( (mod[0] == 'N') ){ copA = cA; rC = rA; } else { copA = rA; rC = cA; } #ifndef USE_BLAS // naive C implementations if ((rB == 0) || (cB == 0)){ // one input C = A*A' if ( (mod[0] == 'N') ) multABt(C, A, A, rA, cA, rA); else multAtB(C, A, A, rA, cA, cA); }else{ if ( (mod[0] == 'N') ) multABt(C, A, B, rA, cA, rB); else multAtB(C, A, B, rA, cA, cB); // symmetrize for( i=0; i<rC; i++ ) for( j=i; j<rC; j++ ){ temp = C[i*rC+j] + C[j*rC+i]; C[i*rC+j] = C[j*rC+i] = 0.5*temp; } } #else char modA = mod[0], modB = mod[1], uplo = 'U'; double one = 1.0, zero = 0.0, half = 0.5; if ((!rB) && (!cB)) // one input C = A*A' dsyrk(&uplo, &modA, &rC, &copA, &one, A, &rA0, &zero, C, &rC); else // two inputs C = 0.5*(A*B'+B*A') dsyr2k(&uplo, &modA, &rC, &copA, &half, A, &rA0, B, &rB0, &zero, C, &rC); // symmetrize for( i=0; i<rC; i++ ) for( j=i+1; j<rC; j++ ) C[i*rC+j] = C[j*rC+i]; #endif }
void offload_dsyrk(const char *uplo, const char *trans, const MKL_INT *n, const MKL_INT *k, const double *alpha, const double *a, const MKL_INT *lda, const double *beta, double *c, const MKL_INT *ldc){ /* * perform dsyrk on the device. a,c pre-exist on the device */ intptr_t aptr = (intptr_t)a; intptr_t cptr = (intptr_t)c; #pragma offload target(mic:MYDEVICE) in(uplo,trans,n,k:length(1)) \ in(alpha,lda,beta,ldc:length(1)) in(aptr,cptr) { dsyrk(uplo,trans,n,k,alpha,(double*)aptr,lda,beta,(double*)cptr,ldc); } }
// Forwards every argument, untouched, to the Fortran-style dsyrk
// (BLAS symmetric rank-k update). All arguments are passed by pointer,
// following the Fortran calling convention.
void toast::lapack::syrk(char * UPLO, char * TRANS, int * N, int * K,
                         double * ALPHA, double * A, int * LDA,
                         double * BETA, double * C, int * LDC) {
    dsyrk(UPLO, TRANS, N, K,
          ALPHA, A, LDA,
          BETA, C, LDC);
}