示例#1
0
/*
 * In-place shifted softmax over m consecutive runs of n floats.
 *
 * For run k (k = 0 .. m-1) every element y[i] of the run is replaced by
 * fastexp (y[i] - shift[k]) (vfastexp for the vectorised part), and the
 * whole run is then scaled by the reciprocal of the sum of those values.
 *
 *   y      m*n floats, read and overwritten in place
 *   shift  m floats, one offset per run (only read)
 *   m      number of runs
 *   n      number of elements in each run
 *
 * NOTE(review): fastexp / vfastexp are defined elsewhere; presumably the
 * scalar and 4-wide approximate exponentials -- confirm.
 */
static void
mexsoftmax(float* y, float* shift, mwSize m, mwSize n) {
  __m128 i1, i2;
  __m128 o1, o2;

  while (m>0)
    {
      mwSize curn = n;   /* elements of this run still to be processed */
      float sum = 0.0f;
      declconst128(zero, 0.0f);

      /* The run is walked from its end towards its start.  Peel single
         elements until y+curn sits on a 16-byte boundary so the vector
         loop below can use aligned loads and stores.
         The address is converted through size_t rather than unsigned
         long: unsigned long is only 32 bits wide on LLP64 targets
         (64-bit Windows), where the narrower cast draws a truncation
         diagnostic (an error under some C++ compilers). */
      while (curn>0 && ((size_t)(y+curn) & 15) != 0)
        {
          --curn;
          y[curn]=fastexp(y[curn]-*shift);
          sum += y[curn];
        }

      __m128 s1 = _mm_load1_ps (shift);   /* shift[k] broadcast to 4 lanes */
      __m128 sum1 = zero;                 /* per-lane partial sums */

      /* Main loop: eight floats (two aligned 4-wide vectors) per pass. */
      while (curn>7) {
        i1 = _mm_load_ps (y+curn-4);
        i2 = _mm_load_ps (y+curn-8);
        i1 = _mm_sub_ps (i1, s1);
        i2 = _mm_sub_ps (i2, s1);
        o1 = vfastexp(i1);
        o2 = vfastexp(i2);
        _mm_store_ps (y+curn-4, o1);
        sum1 = _mm_add_ps (sum1, o1);
        _mm_store_ps (y+curn-8, o2);
        sum1 = _mm_add_ps (sum1, o2);
        curn-=8;
      }

      /* Two horizontal adds leave the total of all four lanes in lane 0. */
      sum1 = _mm_hadd_ps (sum1, sum1);
      sum1 = _mm_hadd_ps (sum1, sum1);
      sum += _mm_cvtss_f32 (sum1);

      /* Scalar tail: the at most seven elements left at the run's start. */
      while(curn>0) {
        --curn;
        y[curn]=fastexp(y[curn]-*shift);
        sum += y[curn];
      }

      /* From here on sum holds the normalisation factor 1/total. */
      sum = 1.0f / sum;

      /* sscal takes its count and stride by pointer to ptrdiff_t. */
      ptrdiff_t n_pdt = n;
      ptrdiff_t one_pdt = 1;

      /* Scale the n floats of this run by sum with stride 1
         (presumably the BLAS sscal routine -- confirm). */
      sscal (&n_pdt, &sum, y, &one_pdt);

      /* Advance to the next run and its shift value. */
      ++shift;
      y+=n;
      --m;
    }
}
示例#2
0
/*
 * Test driver.
 *
 * Asserts that fastexp / fasterexp (and, when __SSE2__ is defined, lane 0
 * of vfastexp / vfasterexp) are non-negative for x = -50, -150, ..., -950,
 * seeds the generator with srand48, redirects stderr to the file
 * "<argv[0]>.out", then calls the test_* routines followed by the time_*
 * routines.
 *
 * Returns 0 on success, 1 when the output file cannot be opened.
 */
int
main (int   argc,
      char *argv[])
{
  char buf[4096];

  (void) argc;

  float x;
  for (x = -50; x > -1000; x -= 100)
    {
      assert (fastexp (x) >= 0);
      assert (fasterexp (x) >= 0);
#ifdef __SSE2__
      v4sf vx = v4sfl (x);
      assert (v4sf_index (vfastexp (vx), 0) >= 0);
      assert (v4sf_index (vfasterexp (vx), 0) >= 0);
#endif
    }

  srand48 (69);

  /* Build "<argv[0]>.out".  The precision caps the program name at
     sizeof (buf) - 5 characters so the 4-character suffix and the
     terminating NUL always fit.  Unlike strncpy, snprintf always
     NUL-terminates, so an over-long argv[0] can no longer leave buf
     unterminated. */
  snprintf (buf, sizeof (buf), "%.*s.out", (int) (sizeof (buf) - 5), argv[0]);

  /* Rebind the existing stderr stream to the file.  Assigning to stderr
     is not portable (it need not be a modifiable lvalue), and an
     unchecked failure would leave every later write to stderr with a
     null stream. */
  if (freopen (buf, "w", stderr) == NULL)
    {
      /* stderr is closed at this point, so report on stdout instead. */
      fprintf (stdout, "%s: cannot open for writing\n", buf);
      return 1;
    }

  test_fastexp ();
  test_fasterexp ();
  test_vfastexp ();
  test_vfasterexp ();

  time_fastexp ();
  time_fasterexp ();
  time_vfastexp ();
  time_vfasterexp ();

  return 0;
}