/** @brief MEX entry point: nearest-neighbor queries against a KD-forest.
 **
 ** Inputs (@a in): the forest structure, the data it indexes, and the
 ** query matrix, followed by optional name/value options parsed with
 ** ::vlmxNextOption. The query matrix must be real, have the same
 ** storage class as the data and @c forest->dimension rows; each
 ** column is one query.
 **
 ** Outputs (@a out): a UINT32 matrix of neighbor indexes and,
 ** if requested, a matrix of distances with the same class as the
 ** data. Both are @c numNeighbors x @c numQueries. Indexes are stored
 ** as the value reported by the forest plus one (MATLAB indexes start
 ** at one).
 **
 ** @param nout number of requested output arguments (at most two).
 ** @param out  output argument array.
 ** @param nin  number of input arguments (at least three).
 ** @param in   input argument array.
 **/
void
mexFunction(int nout, mxArray *out[], int nin, const mxArray *in[])
{
  enum {IN_FOREST = 0, IN_DATA, IN_QUERY, IN_END} ;
  enum {OUT_INDEX = 0, OUT_DISTANCE} ;

  int verbose = 0 ;
  int opt ;
  int next = IN_END ;
  mxArray const *optarg ;
  double optValue ;

  VlKDForest * forest ;
  mxArray const * forest_array ;
  mxArray const * data_array ;
  mxArray const * query_array ;
  mxArray * index_array ;
  mxArray * distance_array ;
  void * query ;
  vl_uint32 * index ;
  void * distance ;
  vl_size numNeighbors = 1 ;
  vl_size numQueries ;
  vl_uindex qi, ni;
  /* vl_size rather than unsigned int: the total can exceed 32 bits
     when many queries are processed */
  vl_size numComparisons = 0 ;
  unsigned int maxNumComparisons = 0 ;
  VlKDForestNeighbor * neighbors ;
  mxClassID dataClass ;

  VL_USE_MATLAB_ENV ;

  /* -----------------------------------------------------------------
   *                                               Check the arguments
   * -------------------------------------------------------------- */

  if (nin < 3) {
    vlmxError(vlmxErrNotEnoughInputArguments, NULL) ;
  }
  if (nout > 2) {
    vlmxError(vlmxErrTooManyOutputArguments, NULL) ;
  }

  /* Read the input slots only after the count has been validated. */
  forest_array = in[IN_FOREST] ;
  data_array = in[IN_DATA] ;
  query_array = in[IN_QUERY] ;

  forest = new_kdforest_from_array (forest_array, data_array) ;

  dataClass = mxGetClassID (data_array) ;
  if (mxGetClassID (query_array) != dataClass) {
    vlmxError(vlmxErrInvalidArgument,
              "QUERY must have the same storage class as DATA.") ;
  }
  if (! vlmxIsReal (query_array)) {
    vlmxError(vlmxErrInvalidArgument,
              "QUERY must be real.") ;
  }
  if (! vlmxIsMatrix (query_array, forest->dimension, -1)) {
    vlmxError(vlmxErrInvalidArgument,
              "QUERY must be a matrix with TREE.NUMDIMENSIONS rows.") ;
  }

  while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
    switch (opt) {
      case opt_num_neighs :
        /* Validate as a double first: converting a negative or NaN
           value to the unsigned vl_size is undefined and could slip
           past a "< 1" test done after the conversion. */
        if (! vlmxIsScalar(optarg) ||
            ! ((optValue = mxGetScalar(optarg)) >= 1)) {
          vlmxError(vlmxErrInvalidArgument,
                    "NUMNEIGHBORS must be a scalar not smaller than one.") ;
        }
        numNeighbors = (vl_size) optValue ;
        break;

      case opt_max_num_comparisons :
        if (! vlmxIsScalar(optarg)) {
          vlmxError(vlmxErrInvalidArgument,
                    "MAXNUMCOMPARISONS must be a scalar.") ;
        }
        optValue = mxGetScalar(optarg) ;
        /* Same reasoning as above: reject negatives and NaN before the
           conversion to an unsigned type. */
        if (! (optValue >= 0)) {
          vlmxError(vlmxErrInvalidArgument,
                    "MAXNUMCOMPARISONS must be a non-negative scalar.") ;
        }
        maxNumComparisons = (unsigned int) optValue ;
        break;

      case opt_verbose :
        ++ verbose ;
        break ;
    }
  }

  vl_kdforest_set_max_num_comparisons (forest, maxNumComparisons) ;

  /* -----------------------------------------------------------------
   *                                                   Allocate buffers
   * -------------------------------------------------------------- */

  neighbors = vl_malloc (sizeof(VlKDForestNeighbor) * numNeighbors) ;

  query = mxGetData (query_array) ;
  numQueries = mxGetN (query_array) ;

  index_array = mxCreateNumericMatrix
    (numNeighbors, numQueries, mxUINT32_CLASS, mxREAL) ;
  distance_array = mxCreateNumericMatrix
    (numNeighbors, numQueries, dataClass, mxREAL) ;

  out[OUT_INDEX] = index_array ;
  /* The second output slot may only be written when the caller asked
     for it; otherwise the distance matrix is a scratch buffer that is
     released before returning. */
  if (nout > OUT_DISTANCE) {
    out[OUT_DISTANCE] = distance_array ;
  }

  index = mxGetData (index_array) ;
  distance = mxGetData (distance_array) ;

  if (verbose) {
    /* Explicit casts keep the arguments in agreement with the format
       specifiers regardless of how wide vl_size is. */
    VL_PRINTF ("vl_kdforestquery: number of queries: %llu\n",
               (unsigned long long) numQueries) ;
    VL_PRINTF ("vl_kdforestquery: number of neighbors per query: %llu\n",
               (unsigned long long) numNeighbors) ;
    VL_PRINTF ("vl_kdforestquery: max num of comparisons per query: %llu\n",
               (unsigned long long) vl_kdforest_get_max_num_comparisons (forest)) ;
  }

  /* -----------------------------------------------------------------
   *                                                     Run the queries
   * -------------------------------------------------------------- */

  for (qi = 0 ; qi < numQueries ; ++ qi) {
    numComparisons += vl_kdforest_query (forest, neighbors, numNeighbors,
                                         query) ;
    switch (dataClass) {
      case mxSINGLE_CLASS:
      {
        float * distance_ = (float*) distance ;
        for (ni = 0 ; ni < numNeighbors ; ++ni) {
          *index++ = neighbors[ni].index + 1 ;
          *distance_++ = neighbors[ni].distance ;
        }
        /* advance to the next query column and output column */
        query = (float*)query + vl_kdforest_get_data_dimension (forest) ;
        distance = distance_ ;
        break ;
      }
      case mxDOUBLE_CLASS:
      {
        double * distance_ = (double*) distance ;
        for (ni = 0 ; ni < numNeighbors ; ++ni) {
          *index++ = neighbors[ni].index + 1 ;
          *distance_++ = neighbors[ni].distance ;
        }
        query = (double*)query + vl_kdforest_get_data_dimension (forest) ;
        distance = distance_ ;
        break ;
      }
      default:
        abort() ;
    }
  }

  if (verbose) {
    VL_PRINTF ("vl_kdforestquery: number of comparisons per query: %.3f\n",
               ((double) numComparisons) / numQueries) ;
    VL_PRINTF ("vl_kdforestquery: number of comparisons per neighbor: %.3f\n",
               ((double) numComparisons) / (numQueries * numNeighbors)) ;
  }

  if (nout <= OUT_DISTANCE) {
    mxDestroyArray (distance_array) ;
  }
  vl_kdforest_delete (forest) ;
  vl_free (neighbors) ;
}
/** @brief Run a batch of nearest-neighbor queries on a KD-forest.
 **
 ** @param self         forest to query.
 ** @param indexes      output buffer of @a numNeighbors * @a numQueries
 **                     entries; entry <code>qi * numNeighbors + ni</code>
 **                     receives the index field of the @c ni-th neighbor
 **                     of query @c qi, cast to ::vl_uint32.
 ** @param numNeighbors number of neighbors requested per query.
 ** @param numQueries   number of queries.
 ** @param distances    optional output buffer with the same layout as
 **                     @a indexes, holding @c float or @c double values
 **                     according to the forest data type; may be @c NULL,
 **                     in which case distances are not stored.
 ** @param queries      query points stored back to back; query @c qi
 **                     starts <code>qi * dimension</code> elements into
 **                     the buffer, where the element type and dimension
 **                     are those reported by the forest.
 ** @return sum of the values returned by ::vl_kdforestsearcher_query
 **         over all queries.
 **
 ** The function calls @c abort() if the forest data type is neither
 ** ::VL_TYPE_FLOAT nor ::VL_TYPE_DOUBLE. When compiled with OpenMP the
 ** queries are distributed over a team of threads, each with its own
 ** searcher and neighbor buffer.
 **/
vl_size
vl_kdforest_query_with_array (VlKDForest * self,
                              vl_uint32 * indexes,
                              vl_size numNeighbors,
                              vl_size numQueries,
                              void * distances,
                              void const * queries)
{
  vl_size numComparisons = 0;
  vl_type dataType = vl_kdforest_get_data_type(self) ;
  vl_size dimension = vl_kdforest_get_data_dimension(self) ;

#ifdef _OPENMP
#pragma omp parallel default(shared) num_threads(vl_get_max_threads())
#endif
  {
    vl_index qi ;
    vl_size thisNumComparisons = 0 ;
    VlKDForestSearcher * searcher ;
    VlKDForestNeighbor * neighbors ;

    /* Per-thread resources are created one thread at a time.
       NOTE(review): presumably because creating a searcher updates
       state shared through the forest -- confirm against
       vl_kdforest_new_searcher. */
#ifdef _OPENMP
#pragma omp critical
#endif
    {
      searcher = vl_kdforest_new_searcher(self) ;
      neighbors = vl_calloc (sizeof(VlKDForestNeighbor), numNeighbors) ;
    }

    /* The bound is cast to vl_index, the type of the loop counter.
       A plain (signed) cast means int, which truncates batches larger
       than INT_MAX on 64-bit builds and would silently skip queries. */
#ifdef _OPENMP
#pragma omp for
#endif
    for (qi = 0 ; qi < (vl_index)numQueries ; ++ qi) {
      switch (dataType) {
        case VL_TYPE_FLOAT: {
          vl_size ni;
          thisNumComparisons += vl_kdforestsearcher_query
            (searcher, neighbors, numNeighbors,
             (float const *) (queries) + qi * dimension) ;
          for (ni = 0 ; ni < numNeighbors ; ++ni) {
            indexes [qi*numNeighbors + ni] = (vl_uint32) neighbors[ni].index ;
            if (distances) {
              *((float*)distances + qi*numNeighbors + ni) = neighbors[ni].distance ;
            }
          }
          break ;
        }
        case VL_TYPE_DOUBLE: {
          vl_size ni;
          thisNumComparisons += vl_kdforestsearcher_query
            (searcher, neighbors, numNeighbors,
             (double const *) (queries) + qi * dimension) ;
          for (ni = 0 ; ni < numNeighbors ; ++ni) {
            indexes [qi*numNeighbors + ni] = (vl_uint32) neighbors[ni].index ;
            if (distances) {
              *((double*)distances + qi*numNeighbors + ni) = neighbors[ni].distance ;
            }
          }
          break ;
        }
        default:
          abort() ;
      }
    }

    /* Fold this thread's count into the shared total and release its
       resources; the critical section protects the shared counter. */
#ifdef _OPENMP
#pragma omp critical
#endif
    {
      numComparisons += thisNumComparisons ;
      vl_kdforestsearcher_delete (searcher) ;
      vl_free (neighbors) ;
    }
  }

  return numComparisons ;
}