/*
 * In-place, row-by-row exponentiate-and-normalize over a dense float buffer.
 *
 * y     : m consecutive rows of n floats each; overwritten with the result.
 * shift : one float per row; subtracted from every element of that row
 *         before fastexp/vfastexp is applied.
 * m     : number of rows.
 * n     : number of elements per row.
 *
 * For each row, every element becomes fastexp(y[i] - *shift), the row total
 * is accumulated, and the row is then multiplied by 1/total through sscal
 * (presumably the BLAS single-precision scaling routine -- confirm against
 * the header this file includes).
 *
 * Each row is walked from its END towards its start in three phases:
 *   1. scalar steps until the address y+left is 16-byte aligned,
 *   2. aligned SSE steps of 8 floats (two 4-wide vectors) at a time,
 *   3. scalar steps for whatever remains at the front of the row.
 */
static void mexsoftmax(float* y, float* shift, mwSize m, mwSize n)
{
  for (; m > 0; --m, ++shift, y += n) {
    mwSize left = n;       /* elements of this row not yet processed */
    float total = 0.0f;    /* running sum of the exponentiated row */
    declconst128(zero, 0.0f);

    /* Phase 1: peel trailing elements until y+left sits on a 16-byte
       boundary, so the aligned loads/stores below are legal. */
    while (left > 0 && ((unsigned long)(y + left) & 15) != 0) {
      float* cell;

      --left;
      cell = &y[left];
      *cell = fastexp(*cell - *shift);
      total += *cell;
    }

    __m128 vshift = _mm_load1_ps(shift);  /* row shift in all four lanes */
    __m128 vtotal = zero;                 /* four partial sums */

    /* Phase 2: eight floats per iteration, upper four first. */
    for (; left > 7; left -= 8) {
      float* hi = y + left - 4;
      float* lo = y + left - 8;
      __m128 ehi = vfastexp(_mm_sub_ps(_mm_load_ps(hi), vshift));
      __m128 elo = vfastexp(_mm_sub_ps(_mm_load_ps(lo), vshift));

      _mm_store_ps(hi, ehi);
      vtotal = _mm_add_ps(vtotal, ehi);
      _mm_store_ps(lo, elo);
      vtotal = _mm_add_ps(vtotal, elo);
    }

    /* Two horizontal adds fold the four lanes into lane 0. */
    vtotal = _mm_hadd_ps(vtotal, vtotal);
    vtotal = _mm_hadd_ps(vtotal, vtotal);
    total += _mm_cvtss_f32(vtotal);

    /* Phase 3: leftover elements at the front of the row. */
    while (left > 0) {
      float* cell;

      --left;
      cell = &y[left];
      *cell = fastexp(*cell - *shift);
      total += *cell;
    }

    /* Scale the whole row by the reciprocal of its sum. */
    float scale = 1.0f / total;
    ptrdiff_t count = n;
    ptrdiff_t stride = 1;

    sscal(&count, &scale, y, &stride);
  }
}
int main (int argc, char *argv[]) { char buf[4096]; (void) argc; float x; for (x = -50; x > -1000; x -= 100) { assert (fastexp (x) >= 0); assert (fasterexp (x) >= 0); #ifdef __SSE2__ v4sf vx = v4sfl (x); assert (v4sf_index (vfastexp (vx), 0) >= 0); assert (v4sf_index (vfasterexp (vx), 0) >= 0); #endif } srand48 (69); strncpy (buf, argv[0], sizeof (buf) - 5); strncat (buf, ".out", 5); fclose (stderr); stderr = fopen (buf, "w"); test_fastexp (); test_fasterexp (); test_vfastexp (); test_vfasterexp (); time_fastexp (); time_fasterexp (); time_vfastexp (); time_vfasterexp (); return 0; }