/* Function:  p7_MSVFilter_avx()
 * Synopsis:  Calculates MSV score, vewy vewy fast, in limited precision (AVX2).
 *
 * Purpose:   Calculates an approximation of the MSV score for sequence
 *            <dsq> of length <L> residues, using optimized profile <om>,
 *            and the one-row DP matrix <ox>. Return the
 *            estimated MSV score (in nats) in <ret_sc>.
 *
 *            The SSV filter <p7_SSVFilter_avx()> is tried first. Unless it
 *            returns <eslENORESULT>, its status is returned directly
 *            (with whatever it stored in <*ret_sc>) and <ox> is not used.
 *
 *            Score may overflow (and will, on high-scoring
 *            sequences), but will not underflow.
 *
 *            <ox> will be resized if needed. It's fine if it was
 *            just <_Reuse()'d> from a previous, smaller profile.
 *
 *            The model may be in any mode, because only its match
 *            emission scores will be used. The MSV filter inherently
 *            assumes a multihit local mode, and uses its own special
 *            state transition scores, not the scores in the profile.
 *
 *            When built without <HAVE_AVX2>, this is a link stub that
 *            returns <eslENORESULT> and leaves <*ret_sc> unset.
 *
 * Args:      dsq    - digital target sequence, 1..L
 *            L      - length of dsq in residues
 *            om     - optimized profile
 *            ox     - filter DP matrix (one row)
 *            ret_sc - RETURN: MSV score (in nats)
 *
 * Returns:   <eslOK> on success.
 *            <eslERANGE> if the score overflows the limited range; in
 *            this case, this is a high-scoring hit, and <*ret_sc> is
 *            set to <eslINFINITY>.
 *            Any other non-<eslENORESULT> status from the SSV filter
 *            is passed through unchanged.
 *            <ox> may have been resized.
 *
 * Throws:    The error status of <p7_filtermx_GrowTo()> (presumably
 *            <eslEMEM>) if <ox> reallocation fails.
 */
int
p7_MSVFilter_avx(const ESL_DSQ *dsq, int L, const P7_OPROFILE *om, P7_FILTERMX *ox, float *ret_sc)
{
#ifdef HAVE_AVX2
  uint8_t  xJ;                     /* special states' scores                                   */
  register __m256i mpv_AVX;        /* previous row values                                      */
  register __m256i xEv_AVX;        /* E state: keeps max for Mk->E as we go                    */
  register __m256i xBv_AVX;        /* B state: splatted vector of B[i-1] for B->Mk calculations */
  register __m256i sv_AVX;         /* temp storage of 1 curr row value in progress             */
  register __m256i biasv_AVX;      /* emission bias in a vector                                */
  __m256i *dp_AVX;                 /* the dp row memory                                        */
  __m256i *rsc_AVX;                /* will point at om->rbv_AVX[x] for residue x[i]            */
  __m256i  xJv_AVX;                /* vector for states score                                  */
  __m256i  tjbmv_AVX;              /* vector for cost of moving {JN}->B->M                     */
  __m256i  tecv_AVX;               /* vector for E->C cost                                     */
  __m256i  basev_AVX;              /* offset for scores                                        */
  __m256i  ceilingv_AVX;           /* saturated simd value used to test for overflow           */
  __m256i  tempv_AVX;              /* work vector                                              */
  int Q_AVX = P7_NVB_AVX(om->M);   /* segment length: # of vectors                             */
  int q_AVX;                       /* counter over vectors 0..Q_AVX-1                          */
  int i;                           /* counter over sequence positions 1..L                     */
  int cmp;                         /* movemask of the overflow test                            */
  int status;

  /* Contract checks */
  ESL_DASSERT1(( om->mode == p7_LOCAL ));  /* Production code assumes multilocal mode w/ length model <L> */
  ESL_DASSERT1(( om->L    == L ));         /* ... and it's easy to forget to set <om> that way           */
  ESL_DASSERT1(( om->nj   == 1.0f ));      /* ... hence the check                                        */
  /* ... which you can disable, if you're playing w/ config.
   * Note however that it makes no sense to run MSV w/ a model in glocal mode.
   */

  /* Try highly optimized Knudsen SSV filter first.
   * Note that SSV doesn't use any main memory (from <ox>) at all!
   */
  status = p7_SSVFilter_avx(dsq, L, om, ret_sc);
  if (status != eslENORESULT) return status;

  /* Profiling counter, defined elsewhere.
   * NOTE(review): plain non-atomic increment; racy if this filter runs
   * concurrently on several threads.
   */
  extern uint64_t full_MSV_calls;
  full_MSV_calls++;

  /* Resize the filter mx as needed */
  if (( status = p7_filtermx_GrowTo(ox, om->M)) != eslOK)
    ESL_EXCEPTION(status, "Reallocation of MSV filter matrix failed");
  dp_AVX = ox->dp_AVX;

  /* Matrix type and size must be set early, not late: debugging dump functions need this information. */
  ox->M    = om->M;
  ox->type = p7F_MSVFILTER;

  /* Initialization. In offset unsigned arithmetic, -infinity is 0, and 0 is om->base. */
  biasv_AVX = _mm256_set1_epi8((int8_t) om->bias_b); /* yes, you can set1() an unsigned char vector this way */
  for (q_AVX = 0; q_AVX < Q_AVX; q_AVX++)
    dp_AVX[q_AVX] = _mm256_setzero_si256();

  /* saturate simd register for overflow test: every byte 0xff */
  ceilingv_AVX = _mm256_cmpeq_epi8(biasv_AVX, biasv_AVX);
  basev_AVX    = _mm256_set1_epi8((int8_t) om->base_b);
  tjbmv_AVX    = _mm256_set1_epi8((int8_t) om->tjb_b + (int8_t) om->tbm_b);
  tecv_AVX     = _mm256_set1_epi8((int8_t) om->tec_b);
  xJv_AVX      = _mm256_subs_epu8(biasv_AVX, biasv_AVX);   /* all zeros */
  xBv_AVX      = _mm256_subs_epu8(basev_AVX, tjbmv_AVX);

#ifdef p7_DEBUGGING
  if (ox->do_dumping)
    {
      /* xBv_AVX and xJv_AVX are splatted, so byte 0 holds the value. */
      uint8_t xB = (uint8_t) _mm256_extract_epi8(xBv_AVX, 0);
      xJ         = (uint8_t) _mm256_extract_epi8(xJv_AVX, 0);
      p7_filtermx_DumpMFRow(ox, 0, 0, 0, xJ, xB, xJ);
    }
#endif

  for (i = 1; i <= L; i++)  /* Outer loop over residues */
    {
      rsc_AVX = om->rbv_AVX[dsq[i]];
      xEv_AVX = _mm256_setzero_si256();

      /* Right shifts by 1 byte. 4,8,12,x becomes x,4,8,12.
       * Because ia32 is littlendian, this means a left bit shift.
       * Zeros shift on automatically, which is our -infinity.
       */
      mpv_AVX = esl_avx_leftshift_one(dp_AVX[Q_AVX - 1]);

      for (q_AVX = 0; q_AVX < Q_AVX; q_AVX++)
        {
          /* Calculate new MMXo(i,q); don't store it yet, hold it in sv. */
          sv_AVX  = _mm256_max_epu8(mpv_AVX, xBv_AVX);
          sv_AVX  = _mm256_adds_epu8(sv_AVX, biasv_AVX);
          sv_AVX  = _mm256_subs_epu8(sv_AVX, *rsc_AVX);   rsc_AVX++;
          xEv_AVX = _mm256_max_epu8(xEv_AVX, sv_AVX);

          mpv_AVX       = dp_AVX[q_AVX];  /* Load {MDI}(i-1,q) into mpv */
          dp_AVX[q_AVX] = sv_AVX;         /* Do delayed store of M(i,q) now that memory is usable */
        }

      /* test for the overflow condition: any lane of xE + bias saturated at 0xff */
      tempv_AVX = _mm256_adds_epu8(xEv_AVX, biasv_AVX);
      tempv_AVX = _mm256_cmpeq_epi8(tempv_AVX, ceilingv_AVX);
      cmp       = _mm256_movemask_epi8(tempv_AVX);

      /* Now the "special" states, which start from Mk->E (->C, ->J->B).
       * Take the horizontal max byte of xEv_AVX and broadcast it to all lanes.
       */
      xEv_AVX = _mm256_set1_epi8(esl_avx_hmax_epu8(xEv_AVX));

      /* immediately detect overflow */
      if (cmp != 0x0000)
        {
          *ret_sc = eslINFINITY;
          return eslERANGE;
        }

      xEv_AVX = _mm256_subs_epu8(xEv_AVX, tecv_AVX);
      xJv_AVX = _mm256_max_epu8(xJv_AVX, xEv_AVX);
      xBv_AVX = _mm256_max_epu8(basev_AVX, xJv_AVX);
      xBv_AVX = _mm256_subs_epu8(xBv_AVX, tjbmv_AVX);

#ifdef p7_DEBUGGING
      if (ox->do_dumping)
        {
          /* All three vectors are splatted here, so byte 0 holds the value. */
          uint8_t xB = (uint8_t) _mm256_extract_epi8(xBv_AVX, 0);
          uint8_t xE = (uint8_t) _mm256_extract_epi8(xEv_AVX, 0);
          xJ         = (uint8_t) _mm256_extract_epi8(xJv_AVX, 0);
          p7_filtermx_DumpMFRow(ox, i, xE, 0, xJ, xB, xJ);
        }
#endif
    } /* end loop over sequence residues 1..L */

  /* finally C->T, and add our missing precision on the NN,CC,JJ back */
  xJ = (uint8_t) _mm256_extract_epi8(xJv_AVX, 0);
  *ret_sc  = ((float) (xJ - om->tjb_b) - (float) om->base_b);
  *ret_sc /= om->scale_b;
  *ret_sc -= 3.0; /* that's ~ L \log \frac{L}{L+3}, for our NN,CC,JJ */
  return eslOK;

#else
  /* Stub so we have something to link if we build without AVX2 support. */
  (void) dsq; (void) L; (void) om; (void) ox; (void) ret_sc;
  return eslENORESULT;
#endif
}
//Prints the 5x2 YMM registers given as input to the function. void print_state_as_hex(YMM(*state)[2]) { //Print u64 index for help printf("\n"); for (int i = 0; i < 8; i++) printf("i0 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i1 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i2 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i3 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i0 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i1 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i2 "); printf("\t"); for (int i = 0; i < 8; i++) printf("i3 "); printf("\t"); printf("\n"); //_mm256_extract_epi8 requires a constant as its second parameters, so did some loop unrolling here. for (int reg_no = 0; reg_no < 5; reg_no++) { //Print first reg section printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 0)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 1)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 2)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 3)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 4)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 5)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 6)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 7)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 8)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 9)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 10)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 11)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 12)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 13)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 14)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 15)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 16)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 17)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 18)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 
19)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 20)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 21)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 22)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 23)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 24)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 25)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 26)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 27)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 28)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 29)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 30)); printf("%02x ", _mm256_extract_epi8(state[reg_no][0], 31)); printf("\t"); //Print second reg section printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 0)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 1)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 2)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 3)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 4)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 5)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 6)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 7)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 8)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 9)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 10)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 11)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 12)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 13)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 14)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 15)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 16)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 17)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 18)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 19)); 
printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 20)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 21)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 22)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 23)); printf("\t"); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 24)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 25)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 26)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 27)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 28)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 29)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 30)); printf("%02x ", _mm256_extract_epi8(state[reg_no][1], 31)); printf("\n"); } }