/*
 * MEX gateway: dis = f(a, b)
 *
 *   a   : d x na single-precision matrix (one d-dimensional vector per column)
 *   b   : d x nb single-precision matrix (one d-dimensional vector per column)
 *   dis : na x nb single-precision matrix filled by compute_cross_distances
 *         (presumably dis(i,j) = distance between a(:,i) and b(:,j) -- confirm
 *         against compute_cross_distances)
 *
 * Two or three right-hand-side arguments are accepted, but only the first two
 * are read here. NOTE(review): the optional third argument is ignored --
 * confirm whether it was meant to carry e.g. a thread count.
 */
void mexFunction (int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray*prhs[])

{
  if (nrhs < 2 || nrhs > 3) 
    mexErrMsgTxt ("Invalid number of input arguments");
  
  if (nlhs != 1)
    mexErrMsgTxt ("1 output arguments required");

  int d = mxGetM (prhs[0]);
  int na = mxGetN (prhs[0]);
  int nb = mxGetN (prhs[1]);

  if (mxGetM (prhs[1]) != d)
      mexErrMsgTxt("Dimension of base and query vectors are not consistent");
  
  if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS 
      || mxGetClassID(prhs[1]) != mxSINGLE_CLASS )
    mexErrMsgTxt ("need single precision array"); 

  /* mxGetPr is the accessor for double arrays only (and errors on other
     classes under the interleaved-complex API); mxGetData returns the raw
     element buffer for any numeric class, which here is float. */
  const float *a = (const float *) mxGetData (prhs[0]);
  const float *b = (const float *) mxGetData (prhs[1]);

  /* output: na x nb single-precision distance matrix */
  plhs[0] = mxCreateNumericMatrix (na, nb, mxSINGLE_CLASS, mxREAL);
  float *dis = (float *) mxGetData (plhs[0]);

  compute_cross_distances (d, na, nb, a, b, dis);
}
Exemplo n.º 2
0
/*
 * Threaded front-end to compute_cross_distances.
 *
 * It takes the same d, na, nb, a, b and dist2 arguments, plus nt, the
 * task/thread count handed to compute_tasks.
 *
 * If both na and nb are smaller than nt, the job is too small to split and the
 * single-threaded routine is called directly. Otherwise the work is dispatched
 * to compute_cross_distances_task, with split_a set when a has more vectors
 * than b. Presumably this makes the larger side the one partitioned across
 * tasks; the partitioning itself happens in the task, which is not shown here.
 */
void compute_cross_distances_thread (int d, int na, int nb,
                                     const float *a,
                                     const float *b, float *dist2,
                                     int nt) 
{
  int largest = MAX (na, nb);

  if (largest < nt) {
    /* too small to be worth threading */
    compute_cross_distances (d, na, nb, a, b, dist2);
    return;
  }

  cross_distances_params_t params = {-1, d, na, nb, a, b, dist2, nt};
  params.split_a = (na > nb);

  compute_tasks (nt, nt, &compute_cross_distances_task, &params);
}
Exemplo n.º 3
0
/*
 * Exhaustive k-nearest-neighbour search, computed block by block.
 *
 * For each of the n1 vectors of mat1 (d floats each, stored contiguously),
 * find the k vectors among the n2 vectors of mat2 (d floats each) with the
 * smallest distance.
 *
 *   distance_type : 2 -> compute_cross_distances, any other value is passed to
 *                   compute_cross_distances_alt (presumably 2 = squared L2;
 *                   confirm against those helpers)
 *   vw_weights    : NULL, or n2 multiplicative weights, one per mat2 vector,
 *                   applied to every distance to that vector
 *   vw, vwdis     : outputs of n1*k entries. Entries [i*k, i*k+k) hold the
 *                   mat2 indices and distances produced by fbinheap_sort for
 *                   query i (presumably in increasing distance -- confirm).
 *
 * Requires k <= n2. The k == 1 case is delegated to nn_single_full.
 * Aborts if the heap buffer cannot be allocated.
 */
void knn_full (int distance_type,int n1, int n2, int d, int k,
	       const float *mat2, const float *mat1,
	       const float *vw_weights,
	       int *vw, float *vwdis)
{
  assert (k <= n2);

  if(k==1) {
    nn_single_full(distance_type, n1, n2, d, mat2, mat1, vw_weights, vw, vwdis);
    return;
  }

  /* Work on step1 x step2 sub-blocks so that the dists scratch array stays
     bounded regardless of n1 and n2. */
  int step1 = MIN (n1, BLOCK_N1), step2 = MIN (n2, BLOCK_N2);

  float *dists = fvec_new (step1 * step2);

  /* Allocate all step1 heaps in one buffer, one slot of oneh bytes each.
     Each slot is rounded up to a multiple of 8 bytes so that every fbinheap_t
     reached through MINS() starts on an 8-byte boundary, whatever
     fbinheap_sizeof(k) returns. This assumes fbinheap_t needs at most
     8-byte alignment. */
  long oneh = fbinheap_sizeof(k);
  oneh = (oneh + 7) & ~7L;
  long minbuf_size = oneh * step1;
  char *minbuf = malloc (minbuf_size);
  /* malloc(0) may legitimately return NULL (n1 == 0), so only treat NULL
     as a failure when memory was actually requested. */
  if (!minbuf && minbuf_size > 0)
    abort ();

#define MINS(i) ((fbinheap_t*)(minbuf + oneh * (i)))
  
  long i1,i2,j1,j2;
  for (i1 = 0; i1 < n1; i1 += step1) {  

    int m1 = MIN (step1, n1 - i1);

    /* reset the k-best heap of every query in this block */
    for (j1 = 0; j1 < m1; j1++) 
      fbinheap_init(MINS(j1),k);

    for (i2 = 0; i2 < n2 ; i2 += step2) {     
      
      int m2 = MIN (step2, n2 - i2);
      
      /* dists[j1 * m2 + j2]: distance between mat2[i2+j2] and mat1[i1+j1] */
      if(distance_type==2)       
        compute_cross_distances (d, m2, m1, mat2+i2*d, mat1+i1*d, dists);
      else 
        compute_cross_distances_alt (distance_type, d, m2, m1, mat2+i2*d, mat1+i1*d, dists);    

      if(vw_weights) {
        for(j1=0;j1<m1;j1++) for (j2 = 0; j2 < m2; j2++)
          dists[j1 * m2 + j2] *= vw_weights[j2 + i2];        
      }

      /* Push this block's m2 distances into each query's heap. The labels
         presumably start at i2, so they are global mat2 indices. */
      for(j1=0;j1<m1;j1++) {
        float *dline=dists+j1*m2; 
        fbinheap_addn_label_range(MINS(j1),m2,i2,dline);
      }      

    }  

    /* extract the k results of each query into its row of vw / vwdis */
    for (j1 = 0; j1 < m1; j1++) {
      fbinheap_t *mh = MINS(j1);
      assert (mh->k == k);
      fbinheap_sort(mh, vw + (i1+j1) * k, vwdis + (i1+j1) * k);
    }
  }

#undef MINS
  free (minbuf);
  free(dists);
}
Exemplo n.º 4
0
/*
 * Single nearest neighbour: the k == 1 case of knn_full.
 *
 * n1 is the number of query points (mat1) and n2 the number of base vectors
 * (mat2). Both are stored as contiguous rows of d floats. For each query i:
 *   vw[i]    = index in mat2 of the base vector with the smallest distance,
 *              optionally scaled by vw_weights[index]
 *   vwdis[i] = that distance
 * If no comparable distance exists (n2 == 0, or every distance is NaN), the
 * result is vw[i] = -1 and vwdis[i] = 1e30.
 *
 * distance_type 2 uses compute_cross_distances; any other value is passed to
 * compute_cross_distances_alt (presumably 2 = squared L2 -- confirm).
 */
static void nn_single_full (int distance_type,
			    int n1, int n2, int d,
			    const float *mat2, const float *mat1, 
			    const float *vw_weights,                             
			    int *vw, float *vwdis)
{
  int step1 = MIN (n1, BLOCK_N1), step2 = MIN (n2, BLOCK_N2);

  /* Process the data in step1 x step2 sub-blocks so that the dists scratch
     array holding the cross-distances stays small. */
  float *dists = fvec_new (step1 * step2);

  long i1,i2,j1,j2;
  for (i1 = 0; i1 < n1; i1 += step1) {  

    int m1 = MIN (step1, n1 - i1);

    /* reset the running minimum of every query in this block */
    for (j1 = 0; j1 < m1; j1++) {
      vw[j1+i1]=-1;
      vwdis[j1+i1]=1e30;
    }

    for (i2 = 0; i2 < n2 ; i2 += step2) {     
      
      int m2 = MIN (step2, n2 - i2);
      
      /* dists[j1 * m2 + j2]: distance between mat2[i2+j2] and mat1[i1+j1] */
      if(distance_type==2)       
        compute_cross_distances (d, m2, m1, mat2+i2*d, mat1+i1*d, dists);
      else
        compute_cross_distances_alt (distance_type, d, m2, m1, mat2+i2*d, mat1+i1*d, dists);

      if(vw_weights) {
        for(j1=0;j1<m1;j1++) for (j2 = 0; j2 < m2; j2++)
          dists[j1 * m2 + j2] *= vw_weights[j2 + i2];        
      }

      /* fold this block of base vectors into each query's running minimum */
      for(j1=0;j1<m1;j1++) {
        const float *dline=dists+j1*m2;
        
        int imin=vw[i1+j1];
        float dmin=vwdis[i1+j1];

        for(j2=0;j2<m2;j2++) {
          float dist = dline[j2];
          /* 1e30 is only a placeholder. While no candidate has been taken
             (imin < 0), accept any non-NaN distance, including ones >= 1e30
             or +inf, so that a query is never left at -1 when it has a
             comparable neighbour. A NaN fails both comparisons and is
             never selected. */
          if (dist < dmin || (imin < 0 && dist >= dmin)) {
            imin=j2+i2;
            dmin=dist;
          }
        }
          
        vw[i1+j1]=imin;
        vwdis[i1+j1]=dmin;

      }      

    }  
  }

  free (dists);
}