/*
 * Apply the blocked orthogonal factor Q = I - Y.T*T*Y to C from the right
 * (RQ update, judging by the name). Y holds nb reflector rows, Y = ( Y1 Y2 ),
 * and C = ( C1 C2 ) is split conformally by columns.
 *
 * The sequence of calls below computes
 *
 *     C := C - (C*Y.T)*op(T)*Y
 *
 *   transpose == 0 :  op(T) = T    ->  C := C*Q
 *   transpose != 0 :  op(T) = T.T  ->  C := C*Q.T
 *
 * Shapes:
 *   C1 : K x nb          C2 : K x P
 *   Y1 : nb x nb, read as unit lower triangular (ARMAS_LOWER|ARMAS_UNIT)
 *   Y2 : nb x P
 *   T  : nb x nb, only the lower triangle is used (ARMAS_LOWER)
 *   W  : K x nb workspace, overwritten
 *
 * Always returns 0; the return values of the armas_x_* calls are not checked.
 */
int __update_rq_right(armas_x_dense_t *C1, armas_x_dense_t *C2,
                      armas_x_dense_t *Y1, armas_x_dense_t *Y2,
                      armas_x_dense_t *T, armas_x_dense_t *W,
                      int transpose, armas_conf_t *conf)
{
    // W = C1
    armas_x_scale_plus(__ZERO, W, __ONE, C1, ARMAS_NONE, conf);
    // W = C1*Y1.T = W*Y1.T
    armas_x_mult_trm(W, __ONE, Y1, ARMAS_LOWER|ARMAS_UNIT|ARMAS_RIGHT|ARMAS_TRANSA, conf);
    // W = W + C2*Y2.T
    armas_x_mult(__ONE, W, __ONE, C2, Y2, ARMAS_TRANSB, conf);
    // here: W = C*Y.T
    int bits = ARMAS_LOWER|ARMAS_RIGHT;
    if (transpose)
        bits |= ARMAS_TRANSA;
    // W = W*T (transpose == 0) or W*T.T (transpose != 0)
    armas_x_mult_trm(W, __ONE, T, bits, conf);
    // here: W = C*Y.T*op(T)
    // C2 = C2 - W*Y2
    armas_x_mult(__ONE, C2, -__ONE, W, Y2, ARMAS_NONE, conf);
    // C1 = C1 - W*Y1, done in two steps reusing W as scratch:
    // W = W*Y1
    armas_x_mult_trm(W, __ONE, Y1, ARMAS_LOWER|ARMAS_UNIT|ARMAS_RIGHT, conf);
    // C1 = C1 - W
    armas_x_scale_plus(__ONE, C1, -__ONE, W, ARMAS_NONE, conf);
    // here: C = C*Q (transpose == 0) or C = C*Q.T (transpose != 0)
    return 0;
}
/*
 * Apply the blocked orthogonal factor Q = I - Y.T*T*Y to C from the left
 * (RQ update, judging by the name). Y holds nb reflector rows, Y = ( Y1 Y2 ),
 * and C = ( C1 ) is split conformally by rows, with C1 on top of C2.
 *         ( C2 )
 *
 * The sequence of calls below computes
 *
 *     C := C - Y.T*(C.T*Y.T*op(T)).T  ==  C - Y.T*op(T).T*Y*C
 *
 *   transpose == 0 :  op(T) = T.T  ->  C := Q*C
 *   transpose != 0 :  op(T) = T    ->  C := Q.T*C
 *
 * Shapes:
 *   C1 : nb x K          C2 : P x K
 *   Y1 : nb x nb, read as unit lower triangular (ARMAS_LOWER|ARMAS_UNIT)
 *   Y2 : nb x P
 *   T  : nb x nb, only the lower triangle is used (ARMAS_LOWER)
 *   W  : K x nb workspace, overwritten
 *
 * Always returns 0; the return values of the armas_x_* calls are not checked.
 */
int __update_rq_left(armas_x_dense_t *C1, armas_x_dense_t *C2,
                     armas_x_dense_t *Y1, armas_x_dense_t *Y2,
                     armas_x_dense_t *T, armas_x_dense_t *W,
                     int transpose, armas_conf_t *conf)
{
    // W = C1.T
    armas_x_scale_plus(__ZERO, W, __ONE, C1, ARMAS_TRANSB, conf);
    // W = C1.T*Y1.T = W*Y1.T
    armas_x_mult_trm(W, __ONE, Y1, ARMAS_LOWER|ARMAS_UNIT|ARMAS_RIGHT|ARMAS_TRANSA, conf);
    // W = W + C2.T*Y2.T
    armas_x_mult(__ONE, W, __ONE, C2, Y2, ARMAS_TRANSA|ARMAS_TRANSB, conf);
    // here: W = C.T*Y.T
    int bits = ARMAS_LOWER|ARMAS_RIGHT;
    if (! transpose)
        bits |= ARMAS_TRANSA;
    // W = W*T.T (transpose == 0) or W*T (transpose != 0)
    armas_x_mult_trm(W, __ONE, T, bits, conf);
    // here: W = C.T*Y.T*op(T)
    // C2 = C2 - Y2.T*W.T
    armas_x_mult(__ONE, C2, -__ONE, Y2, W, ARMAS_TRANSA|ARMAS_TRANSB, conf);
    // C1 = C1 - Y1.T*W.T, done in two steps reusing W as scratch:
    // W = W*Y1, i.e. W.T = Y1.T*W.T
    armas_x_mult_trm(W, __ONE, Y1, ARMAS_LOWER|ARMAS_UNIT|ARMAS_RIGHT, conf);
    // C1 = C1 - W.T
    armas_x_scale_plus(__ONE, C1, -__ONE, W, ARMAS_TRANSB, conf);
    // here: C = Q*C (transpose == 0) or C = Q.T*C (transpose != 0)
    return 0;
}
/*
 * Fortran-callable xTRMM front end; all arguments are passed by reference:
 *
 *   SIDE = 'R':   B := alpha*B*op(A),   A is n x n
 *   otherwise:    B := alpha*op(A)*B,   A is m x m
 *
 * B is m x n with leading dimension *ldb; A has leading dimension *lda.
 * UPLO = 'L' selects the lower triangle of A, anything else the upper one.
 * TRANSA = 'T' or 'C' selects op(A) = A.T, anything else op(A) = A. As in
 * the reference BLAS, 'C' must not be treated as 'N'.
 * NOTE(review): this assumes DTYPE is a real type, where the conjugate
 * transpose equals the transpose; confirm for any complex build.
 * DIAG = 'U' marks A as unit triangular.
 */
void __trmmf(char *side, char *uplo, char *transa, char *diag, int *m, int *n,
             DTYPE *alpha, DTYPE *A, int *lda, DTYPE *B, int *ldb)
{
    armas_conf_t *conf = armas_conf_default();
    armas_x_dense_t a, b;
    int flags = 0;
    int ta;

    armas_x_make(&b, *m, *n, *ldb, B);
    /* The casts keep toupper() defined for negative plain-char values. */
    switch (toupper((unsigned char)*side)) {
    case 'R':
        flags |= ARMAS_RIGHT;
        armas_x_make(&a, *n, *n, *lda, A);
        break;
    case 'L':
    default:
        flags |= ARMAS_LEFT;
        armas_x_make(&a, *m, *m, *lda, A);
        break;
    }
    flags |= toupper((unsigned char)*uplo) == 'L' ? ARMAS_LOWER : ARMAS_UPPER;

    ta = toupper((unsigned char)*transa);
    if (ta == 'T' || ta == 'C')
        flags |= ARMAS_TRANS;
    if (toupper((unsigned char)*diag) == 'U')
        flags |= ARMAS_UNIT;

    /* NOTE(review): argument order (B, A, alpha) kept as-is; confirm it matches
       the armas_x_mult_trm prototype this file is compiled against. */
    armas_x_mult_trm(&b, &a, *alpha, flags, conf);
}
/*
 * CBLAS xTRMM front end:
 *
 *   CblasLeft:   B := alpha*op(A)*B,   A is M x M
 *   CblasRight:  B := alpha*B*op(A),   A is N x N
 *
 * B is M x N. op(A) = A.T for CblasTrans and CblasConjTrans (real data),
 * A otherwise.
 *
 * Row-major input is handled by reading it as its column-major transpose:
 * B becomes B.T (N x M) and A becomes A.T. Transposing the whole product,
 * e.g. B.T := alpha*B.T*op(A).T, shows that the side and the triangle flip
 * while the transpose flag stays the same. This is the same mapping as the
 * reference CBLAS. The size of A always follows the caller's side argument.
 *
 * An unknown `order` makes the call a no-op.
 */
void __cblas_trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side,
                  const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE transa,
                  const enum CBLAS_DIAG diag, int M, int N, DTYPE alpha,
                  DTYPE *A, int lda, DTYPE *B, int ldb)
{
    armas_conf_t conf = *armas_conf_default();
    armas_x_dense_t Aa, Ba;
    int flags = 0;
    /* A is square, sized by the dimension of B it multiplies. */
    int adim = side == CblasRight ? N : M;

    switch (order) {
    case CblasColMajor:
        flags |= side == CblasRight ? ARMAS_RIGHT : ARMAS_LEFT;
        flags |= uplo == CblasUpper ? ARMAS_UPPER : ARMAS_LOWER;
        // M > ldb --> error
        armas_x_make(&Ba, M, N, ldb, B);
        break;
    case CblasRowMajor:
        /* Column-major view of row-major data is the transpose. */
        flags |= side == CblasRight ? ARMAS_LEFT : ARMAS_RIGHT;
        flags |= uplo == CblasUpper ? ARMAS_LOWER : ARMAS_UPPER;
        // N > ldb --> error
        armas_x_make(&Ba, N, M, ldb, B);
        break;
    default:
        return;
    }

    /* The transpose flag is order independent; for real data the conjugate
       transpose is the plain transpose. */
    if (transa == CblasTrans || transa == CblasConjTrans)
        flags |= ARMAS_TRANSA;
    if (diag == CblasUnit)
        flags |= ARMAS_UNIT;

    // adim > lda --> error
    armas_x_make(&Aa, adim, adim, lda, A);

    /* NOTE(review): argument order (B, A, alpha) kept as-is; confirm it matches
       the armas_x_mult_trm prototype this file is compiled against. */
    armas_x_mult_trm(&Ba, &Aa, alpha, flags, &conf);
}
int test_left_right(int N, int verbose) { int ok; armas_x_dense_t A, At, B, Bt; DTYPE n0, nrmB; armas_conf_t conf = *armas_conf_default(); armas_x_init(&A, N, N); armas_x_init(&At, N, N); armas_x_init(&B, N, N); armas_x_init(&Bt, N, N); armas_x_set_values(&A, one, ARMAS_SYMM); armas_x_transpose(&At, &A); armas_x_set_values(&B, one, ARMAS_ANY); armas_x_mult_trm(&B, 1.0, &A, ARMAS_UPPER, ARMAS_ANY); armas_x_transpose(&Bt, &B); if (N < 10) { printf("A\n"); armas_x_printf(stdout, "%6.3f", &A); printf("At\n"); armas_x_printf(stdout, "%6.3f", &At); } nrmB = armas_x_mnorm(&B, ARMAS_NORM_INF, &conf); // ||k*A.-1*B + (B.T*-k*A.-T).T|| ~ eps armas_x_solve_trm(&B, 2.0, &A, ARMAS_LEFT|ARMAS_UPPER, &conf); armas_x_solve_trm(&Bt, -2.0, &At, ARMAS_RIGHT|ARMAS_UPPER|ARMAS_TRANS, &conf); if (N < 10) { printf("B\n"); armas_x_printf(stdout, "%6.3f", &B); printf("Bt\n"); armas_x_printf(stdout, "%6.3f", &Bt); } armas_x_scale_plus(1.0, &B, 1.0, &Bt, ARMAS_TRANSB, &conf); if (N < 10) { printf("B + B.T\n"); armas_x_printf(stdout, "%6.3f", &B); } n0 = armas_x_mnorm(&B, ARMAS_NORM_INF, &conf) / nrmB; ok = isOK(n0, N) || n0 == 0.0; printf("%4s : || k*A.-1*B + (-k*B.T*A.-T).T|| : %e\n", PASS(ok), n0); return 1 - ok; }
int main(int argc, char **argv) { armas_conf_t conf; armas_x_dense_t B0, A, B; int ok, opt; int N = 301; int verbose = 1; int fails = 0; DTYPE alpha = 1.0; DTYPE n0, n1; while ((opt = getopt(argc, argv, "vC:")) != -1) { switch (opt) { case 'v': verbose++; break; case 'C': Aconstant = STRTOF(optarg); break; default: fprintf(stderr, "usage: trsm [-P nproc] [size]\n"); exit(1); } } if (optind < argc) N = atoi(argv[optind]); conf = *armas_conf_default(); armas_x_init(&A, N, N); armas_x_init(&B, N, N); armas_x_init(&B0, N, N); // Upper triangular matrix armas_x_set_values(&A, one, ARMAS_UPPER); if (N < 10) { printf("A:\n"); armas_x_printf(stdout, "%8.1e", &A); } armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_LEFT, &conf); if (N < 10) { printf("A*B:\n"); armas_x_printf(stdout, "%8.1e", &B); } armas_x_solve_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_LEFT, &conf); if (N < 10) { printf("A.-1*B:\n"); armas_x_printf(stdout, "%8.1e", &B); } n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, L|U|N), A, L|U|N)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_RIGHT, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_RIGHT, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 
1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, R|U|N), A, R|U|N)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_LEFT|ARMAS_TRANSA, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_LEFT|ARMAS_TRANSA, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, L|U|T), A, L|U|T)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_RIGHT|ARMAS_TRANSA, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_UPPER|ARMAS_RIGHT|ARMAS_TRANSA, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, R|U|T), A, R|U|T)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; // Lower triangular matrix armas_x_set_values(&A, one, ARMAS_LOWER); armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_LEFT, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_LEFT, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 
1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, L|L|N), A, L|L|N)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_RIGHT, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_RIGHT, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, R|L|N), A, R|L|N)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_LEFT|ARMAS_TRANSA, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_LEFT|ARMAS_TRANSA, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, L|L|T), A, L|L|T)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; armas_x_set_values(&B, one, ARMAS_NULL); armas_x_mcopy(&B0, &B); armas_x_mult_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_RIGHT|ARMAS_TRANSA, &conf); armas_x_solve_trm(&B, alpha, &A, ARMAS_LOWER|ARMAS_RIGHT|ARMAS_TRANSA, &conf); n0 = rel_error(&n1, &B, &B0, ARMAS_NORM_ONE, ARMAS_NONE, &conf); ok = n0 == 0.0 || isOK(n0, N) ? 1 : 0; printf("%6s: B = solve_trm(mult_trm(B, A, R|L|T), A, R|L|T)\n", PASS(ok)); if (verbose > 0) { printf(" || rel error || : %e, [%d]\n", n0, ndigits(n0)); } fails += 1 - ok; test_left_right(N, verbose); exit(fails); }
int main(int argc, char **argv) { int ok, opt, i; int count = 5; int nproc = 2; int trans = 0; int right = 0; int lower = 0; int algo = 'B'; int flags = 0; int N = 600; int verbose = 0; double rt, min, max, avg; armas_conf_t conf; armas_x_dense_t C, A, B; while ((opt = getopt(argc, argv, "vc:P:a:s:t:T:")) != -1) { switch (opt) { case 's': right = *optarg == 'R' || *optarg == 'r'; break; case 't': lower = *optarg == 'L' || *optarg == 'l'; break; case 'T': trans = *optarg == 'T' || *optarg == 't'; break; case 'a': algo = *optarg; break; case 'v': verbose = 1; break; case 'c': count = atoi(optarg); break; case 'P': nproc = atoi(optarg); break; default: /* ? */ fprintf(stderr, "Usage: time_symm [-v] [-c numtest] [-P nproc] size"); break; } } if (optind < argc) N = atoi(argv[optind]); flags |= right ? ARMAS_RIGHT : ARMAS_LEFT; flags |= lower ? ARMAS_LOWER : ARMAS_UPPER; if (trans) flags |= ARMAS_TRANSA; //armas_init(); conf = *armas_conf_default(); /*conf.mb = 64; conf.nb = 128; conf.kb = 160; conf.maxproc = nproc; */ if (algo == 'N' || algo == 'n') { conf.optflags |= ARMAS_ONAIVE; } else if (algo == 'R' || algo == 'r') { conf.optflags |= ARMAS_ORECURSIVE; } if (verbose) { printf(".mb=%d, .nb=%d, .kb=%d, .wb=%d\n", conf.mb, conf.nb, conf.kb,conf.wb); printf(".maxproc=%d\n", conf.maxproc); } armas_x_init(&A, N, N); armas_x_init(&B, N, N); armas_x_set_values(&A, one, flags); armas_x_set_values(&B, one, ARMAS_NULL); // C = A*B min = max = avg = 0.0; for (i = 0; i < count; i++) { flush(); rt = time_msec(); armas_x_mult_trm(&B, 1.0, &A, flags, &conf); rt = time_msec() - rt; armas_x_set_values(&B, one, ARMAS_NULL); if (i == 0) { min = max = avg = rt; } else { if (rt < min) min = rt; if (rt > max) max = rt; avg += (rt - avg)/(i+1); } if (verbose) printf("%2d: %.4f, %.4f, %.4f msec\n", i, min, avg, max); } // forward/backward multiplication N^2 ops, and for N columns int64_t nops = (int64_t)N*N*N; printf("N: %4d, %8.4f, %8.4f, %8.4f Gflops\n", N, gflops(max, nops), gflops(avg, 
nops), gflops(min, nops)); }