double GetCleanCase(char pre, enum CLEAN_WHICH which, int icase, int imul,
                    int mb, int nb, int kb)
/*
 * Runs user case icase through ummcase0 as a cleanup kernel for the
 * dimension selected by which, on an mb x nb x kb problem.
 * which indexes 3-entry arrays here, so it must select M, N or K
 * (presumably CleanM/CleanN/CleanK are 0/1/2 -- confirm against the enum).
 * imul is used only to build the output file name.
 * RETURNS: -1.0 if ATL_MMNoClean is set for the case's flag, otherwise
 *          whatever ummcase0 returns.
 */
{
   char cwh[3] = {'M', 'N', 'K'};
   char outf[ROUTLEN], fnam[ROUTLEN], *MCC, *MMFLAGS;
   int ld=kb, NB[3], NB1[3], NBs[3], nb0;
   int iflag, muladd, lat, mu, nu, ku;
   int ok;

/*
 * The lookup must not live inside assert(): with NDEBUG defined the macro
 * discards its argument, and every output below would be uninitialized.
 */
   ok = GetUserCase(pre, icase, &iflag, NB1, NB1+1, NB1+2, &muladd, &lat,
                    &mu, &nu, &ku, fnam, outf, &MCC, &MMFLAGS);
   assert(ok);
   if (ATL_MMNoClean(iflag))
      return(-1.0);

   NBs[0] = mb;
   NBs[1] = nb;
   NBs[2] = kb;
/*
 * nb0 is the size passed for the two dimensions not being cleaned: kb in
 * general, nb when K itself is the cleanup dimension.
 */
   nb0 = kb;
   if (which == CleanK)
   {
      nb0 = nb;
      if (ATL_MMVarLda(iflag))
         ld = 0;
   }
   NB[0] = NB[1] = NB[2] = nb0;
/*
 * For the cleaned dimension pass the requested size if the case stores a
 * nonzero size for it, and 0 otherwise.
 */
   if (NB1[which])
      NB[which] = NBs[which];
   else
      NB[which] = 0;
   sprintf(outf, "res/%cup%cB%d_%d_%dx%dx%d", pre, cwh[which], icase, imul, mb, nb, kb);
   return(ummcase0(pre, mb, nb, kb, NB[0], NB[1], NB[2], ld, ld, 0, muladd, lat, mu, nu, ku, fnam, MCC, MMFLAGS, outf));
}
int GetPNB(char pre, enum CLEAN_WHICH which, int icase, int NB, int imul,
           int *pNB)
/*
 * Fills pNB (which must have room for 3 ints) with the partial blocking
 * factors to be timed for user case icase; unused entries are set to 0.
 *
 * If the case stores a negative size for dimension which, the only
 * candidate is its magnitude.  Otherwise the candidates are derived from
 * NB-NB/8, NB/2 and NB/8 (the last raised to at least 16 when NB >= 32 and
 * rounded up to a multiple of imul); the first two go through MakeMult
 * (presumably rounding to a multiple of imul -- confirm).  Duplicate or
 * zero candidates are dropped.
 *
 * RETURNS: number of leading entries of pNB to be timed (1, 2 or 3)
 */
{
   int i=1, j;
   int iflag, NBs[3], muladd, lat, mu, nu, ku;
   char fnam[ROUTLEN], *MCC, *MMFLAGS;
   int ok;

   pNB[0] = pNB[1] = pNB[2] = 0;
/*
 * Keep the call out of assert() so it still happens when NDEBUG is defined
 */
   ok = GetUserCase(pre, icase, &iflag, NBs, NBs+1, NBs+2, &muladd, &lat,
                    &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS);
   assert(ok);
   if (NBs[which] < 0)
      pNB[0] = -NBs[which];
   else
   {
      j = pNB[0] = MakeMult(NB-NB/8, imul);
      if (!j)
         pNB[0] = imul;   /* fall back to the smallest multiple */
      j = MakeMult(NB/2, imul);
      if (j && j != pNB[0])
      {
         pNB[1] = j;
         i = 2;
         j = NB/8;
         if (NB >= 32)
            j = Mmax(j, 16);
         j = pNB[2] = ((j+imul-1)/imul)*imul;
         if (j && j != pNB[1] && j != pNB[0])
            i = 3;
         else
            pNB[2] = 0;   /* third candidate duplicates an earlier one */
      }
   }
   return(i);
}
int GetIflag(char pre, int icase)
/*
 * RETURNS: the iflag value that GetUserCase reports for user case icase
 *          of precision/type prefix pre.
 */
{
   int iflag, mb, nb, kb, muladd, lat, mu, nu, ku;
   char fnam[ROUTLEN], *MCC, *MMFLAGS;
   int ok;

/*
 * Do the lookup outside assert(); with NDEBUG defined assert() drops its
 * argument and iflag would be returned uninitialized.
 */
   ok = GetUserCase(pre, icase, &iflag, &mb, &nb, &kb, &muladd, &lat,
                    &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS);
   assert(ok);
   return(iflag);
}
MULTHEAD *BuildTable(char pre, enum CLEAN_WHICH which, int nb) /* * Builds table of possible cleanup codes, depending on which: * 0 : pMB * 1 : pNB * 2 : pKB */ { ROUTNODE *rn; int i, n, ID, NB[3]; int iin, io1, io2, iflag, muladd, lat, mu, nu, ku; char *MCC, *MMFLAGS; char rout[ROUTLEN], auth[AUTHLEN]; switch(which) { case CleanM: iin = 0; io1 = 1; io2 = 2; break; case CleanN: iin = 1; io1 = 0; io2 = 2; break; case CleanK: iin = 2; io1 = 0; io2 = 1; break; case CleanNot: exit(-1); } n = NumUserCases(pre); for (i=0; i < n; i++) { rn = NULL; ID = GetUserCase(pre, -i, &iflag, NB, NB+1, NB+2, &muladd, &lat, &mu, &nu, &ku, rout, auth, &MCC, &MMFLAGS); if (ATL_MMNoClean(iflag)) continue; if (NB[io1] < 0 && NB[io1] != -nb) continue; if (NB[io2] < 0 && NB[io2] != -nb) continue; if (NB[io1] && (nb % NB[io1])) continue; if (NB[io2] && (nb % NB[io2])) continue; if (NB[iin] < 0) { if (-NB[iin] < nb) rn = GetRoutNode(-NB[iin], rout, ID, NOTIMED); } else if (NB[iin] == 0) rn = GetRoutNode(1, rout, ID, NOTIMED); else if (NB[iin] < nb) rn = GetRoutNode(NB[iin], rout, ID, NOTIMED); if (rn) rn->fixed = IsCaseFixed(pre, ID, which); } return(imhead); }
int utstmmcase(char pre, int ifile, int NB)
/*
 * Looks up user case ifile and passes its parameters to ummtstcase0 with
 * every size argument set to NB.
 * RETURNS: whatever ummtstcase0 returns (FindBestUser treats nonzero as
 *          acceptance of the kernel).
 */
{
   char outnam[256], fnam[256];
   char *MCC, *MMFLAGS;
   int iflag, mb, nb, kb, muladd, lat, mu, nu, ku;
   int ok;

/*
 * The lookup has side effects, so it must stay outside assert(), which
 * discards its argument when NDEBUG is defined.
 */
   ok = GetUserCase(pre, ifile, &iflag, &mb, &nb, &kb, &muladd, &lat,
                    &mu, &nu, &ku, fnam, outnam, &MCC, &MMFLAGS);
   assert(ok);
   return(ummtstcase0(pre, NB, NB, NB, NB, NB, NB, NB, NB, 0, muladd, lat, mu, nu, ku, fnam, MCC, MMFLAGS));
}
double ummcase(char pre, int ifile, int NB)
/*
 * Looks up user case ifile and passes its parameters to ummcase0 with every
 * size argument set to NB; the output file name comes from GetUserOutFile.
 * RETURNS: 0.0 without running anything if ATL_MMCleanOnly is set for the
 *          case, otherwise whatever ummcase0 returns (FindBestUser reports
 *          it as MFLOP).
 */
{
   char outnam[256], fnam[256];
   char *MCC, *MMFLAGS;
   int iflag, mb, nb, kb, muladd, lat, mu, nu, ku;
   int ok;

/*
 * The lookup has side effects, so it must stay outside assert(), which
 * discards its argument when NDEBUG is defined.
 */
   ok = GetUserCase(pre, ifile, &iflag, &mb, &nb, &kb, &muladd, &lat,
                    &mu, &nu, &ku, fnam, outnam, &MCC, &MMFLAGS);
   assert(ok);
   if (ATL_MMCleanOnly(iflag))
      return(0.0); /* don't run if for cleanup only */
   return(ummcase0(pre, NB, NB, NB, NB, NB, NB, NB, NB, 0, muladd, lat, mu, nu, ku, fnam, MCC, MMFLAGS, GetUserOutFile(pre, ifile, NB, NB, NB)));
}
int IsCaseFixed(char pre, int icase, enum CLEAN_WHICH which)
/*
 * Classifies how user case icase treats dimension which, based on the size
 * GetUserCase stores for that dimension (NB[which]) and on the case's flag:
 *    NB[which] > 0 : 0 if the flag marks that dimension as variable
 *                    (ATL_MMVarM / ATL_MMVarN; for K both ATL_MMVarK and
 *                    ATL_MMVarLda are required), else 1
 *    NB[which] < 0 : 2
 *    NB[which] == 0: 1 if which is CleanK and ATL_MMVarLda is not set,
 *                    else 0
 * RETURNS: 0, 1 or 2 as described above
 */
{
   char stmp[ROUTLEN], *MCC, *MMFLAGS;
   int iflag, ma, lat, mu, nu, ku, NB[3];
   int ok;

/*
 * Keep the call out of assert() so it still happens when NDEBUG is defined
 */
   ok = GetUserCase(pre, icase, &iflag, NB, NB+1, NB+2, &ma, &lat,
                    &mu, &nu, &ku, stmp, stmp, &MCC, &MMFLAGS);
   assert(ok);
/*
 * ma was only a scratch output of the lookup; from here on it holds the
 * return value.
 */
   if (NB[which] > 0)
   {
      if (which == CleanM && ATL_MMVarM(iflag))
         ma = 0;
      else if (which == CleanN && ATL_MMVarN(iflag))
         ma = 0;
      else if (which == CleanK && ATL_MMVarK(iflag) && ATL_MMVarLda(iflag))
         ma = 0;
      else
         ma = 1;
   }
   else if (NB[which] < 0)
      ma = 2;
   else
   {
      if (which == CleanK && !ATL_MMVarLda(iflag))
         ma = 1;
      else
         ma = 0;
   }
   return(ma);
}
int FindBestUser(char pre, int nb0)
/*
 * Scans every user-supplied GEMM case, times each one that GetUserNB gives
 * a nonzero blocking factor for nb0, and prints one result line per timed
 * case.  A case becomes the new best only if it beats the best speed so far
 * and then also passes utstmmcase.
 * RETURNS: ID of the best accepted case, or -1 if none was accepted
 */
{
   char *MCC, *MMFLAGS;
   char outbuf[256], rout[256];
   double mflop, best_mflop=0.0;
   int best_id=(-1);
   int k, ntotal;
   int id, flag, blk, mb, nb, kb, muladd, lat, mu, nu, ku;

   ntotal = NumUserCases(pre);
   for (k=0; k < ntotal; k++)
   {
      id = GetUserCase(pre, -k, &flag, &mb, &nb, &kb, &muladd, &lat,
                       &mu, &nu, &ku, rout, outbuf, &MCC, &MMFLAGS);
      assert(id > 0);
      blk = GetUserNB(pre, nb0, mb, nb, kb);
      if (!blk)
         continue;   /* no usable blocking factor for this case */

      mflop = ummcase(pre, id, blk);
/*
 *    Only test the kernel when it would actually become the new best
 */
      if (mflop > best_mflop && utstmmcase(pre, id, blk))
      {
         best_id = id;
         best_mflop = mflop;
      }
      fprintf(stdout, "%3d. NB=%3d, rout=%40s, MFLOP=%.2f\n", k, blk, rout, mflop);
   }
   return(best_id);
}